feat(linear_solve): matrix inversion linear solver with neumann series approximation #98
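
For context, here is a minimal sketch of the Neumann-series approximation the PR title refers to. This is a generic illustration, not code from this PR, and the helper name neumann_inv_matvec is hypothetical: if the spectral radius of (I - A) is below 1, then A^{-1} = sum_{k>=0} (I - A)^k, so A^{-1} b can be approximated by truncating the series.

import torch

def neumann_inv_matvec(A: torch.Tensor, b: torch.Tensor, maxiter: int = 20) -> torch.Tensor:
    """Approximate A^{-1} b with a truncated Neumann series (assumes rho(I - A) < 1)."""
    x = b.clone()     # running sum of the series terms
    term = b.clone()  # current term (I - A)^k b, starting at k = 0
    for _ in range(maxiter):
        term = term - A @ term  # multiply the previous term by (I - A)
        x = x + term
    return x

# Example: a matrix close to the identity, so the series converges quickly.
A = torch.eye(3, dtype=torch.float64) + 0.05 * torch.randn(3, 3, dtype=torch.float64)
b = torch.randn(3, dtype=torch.float64)
print(torch.allclose(neumann_inv_matvec(A, b, maxiter=50), torch.linalg.solve(A, b)))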

Merged
Changes from 1 commit
Commits (38)
d158462
feat(linear_solve): matrix inversion linear solver with neumann serie…
XuehaiPan Oct 13, 2022
e171f38
fix: fix ns
XuehaiPan Nov 6, 2022
db88ee3
fix: add test
Benjamin-eecs Nov 6, 2022
0946dbe
test: update tests
XuehaiPan Nov 6, 2022
b3cef2f
fix: add test
Benjamin-eecs Nov 6, 2022
e3a38a7
fix: add test
Benjamin-eecs Nov 7, 2022
de27c02
fix: add test
Benjamin-eecs Nov 7, 2022
183e769
fix: pass test
Benjamin-eecs Nov 7, 2022
155ab8d
fix: pass lint
Benjamin-eecs Nov 7, 2022
081d0b1
fix: pass lint
Benjamin-eecs Nov 7, 2022
7fc64b7
fix: pass lint
Benjamin-eecs Nov 7, 2022
1199814
fix: pass lint
Benjamin-eecs Nov 7, 2022
ede0cba
fix: update test
Benjamin-eecs Nov 7, 2022
54f8e01
chore: update CHANGELOG
Benjamin-eecs Nov 7, 2022
dc49c4f
merge: resolve conflicts
Benjamin-eecs Nov 7, 2022
d1d316c
fix: resolve comments
Benjamin-eecs Nov 7, 2022
31cb010
docs: add solve_inv
Benjamin-eecs Nov 7, 2022
909bd81
chore: update Makefile
Benjamin-eecs Nov 7, 2022
ff54c48
docs: update
Benjamin-eecs Nov 7, 2022
7cb77c8
wip
XuehaiPan Nov 7, 2022
03d6972
wip
XuehaiPan Nov 7, 2022
e8c4b38
wip
XuehaiPan Nov 7, 2022
fcf5148
wip
XuehaiPan Nov 7, 2022
8cd84bb
fix: pass test
Benjamin-eecs Nov 7, 2022
4a7db3c
wip
XuehaiPan Nov 8, 2022
8865eb6
wip
XuehaiPan Nov 8, 2022
afbcf11
wip
XuehaiPan Nov 8, 2022
77cd2ee
wip
XuehaiPan Nov 8, 2022
6753929
wip
XuehaiPan Nov 8, 2022
70ec405
feat: support normalize matvec with tensortree
XuehaiPan Nov 8, 2022
2bd55ba
feat: support implicit matvec
XuehaiPan Nov 8, 2022
564c7ec
feat: support implicit matvec
XuehaiPan Nov 8, 2022
4e7c1f9
fix: fix jacobian tree compose
XuehaiPan Nov 9, 2022
071026a
feat: multi-tensor support for solve_inv
XuehaiPan Nov 9, 2022
08de4bc
docs: update linear_solve docs
XuehaiPan Nov 9, 2022
5a0e458
chore: update ns_inv
XuehaiPan Nov 9, 2022
71ddf58
chore: add shortcuts
XuehaiPan Nov 9, 2022
67b774d
docs: update dictionary
XuehaiPan Nov 9, 2022
test: update tests
XuehaiPan committed Nov 6, 2022
commit 0946dbe48c8833660c2fcab8e14efbc1ce3cd0ce
130 changes: 19 additions & 111 deletions tests/test_implicit.py
@@ -16,7 +16,7 @@
import copy
from collections import OrderedDict
from types import FunctionType
from typing import Tuple
from typing import Callable, Tuple

import functorch
import jax
@@ -340,15 +340,22 @@ def outer_level(p, xs, ys):
@helpers.parametrize(
dtype=[torch.float64, torch.float32],
lr=[1e-3, 1e-4],
solvers=[
(torchopt.linear_solve.solve_cg, jaxopt.linear_solve.solve_cg),
(torchopt.linear_solve.solve_inv, jaxopt.linear_solve.solve_inv),
],
)
def test_rr_solve_cg(
def test_rr(
dtype: torch.dtype,
lr: float,
solvers: Tuple[Callable, Callable],
) -> None:
helpers.seed_everything(42)
np_dtype = helpers.dtype_torch2numpy(dtype)
input_size = 10

torchopt_solve, jaxopt_solve = solvers

init_params_torch = torch.randn(input_size, dtype=dtype)
l2reg_torch = torch.rand(1, dtype=dtype).squeeze_().requires_grad_(True)

@@ -367,18 +374,19 @@ def ridge_objective_torch(params, l2reg, data):
"""Ridge objective function."""
X_tr, y_tr = data
residuals = X_tr @ params - y_tr
loss = 0.5 * torch.mean(torch.square(residuals))
regularization_loss = 0.5 * l2reg * torch.sum(torch.square(params))
return 0.5 * torch.mean(torch.square(residuals)) + regularization_loss
return loss + regularization_loss

@torchopt.diff.implicit.custom_root(functorch.grad(ridge_objective_torch, argnums=0), argnums=1)
def ridge_solver_torch_cg(params, l2reg, data):
def ridge_solver_torch(params, l2reg, data):
"""Solve ridge regression by conjugate gradient."""
X_tr, y_tr = data

def matvec(u):
return X_tr.T @ (X_tr @ u)

solve = torchopt.linear_solve.solve_cg(
solve = torchopt_solve(
ridge=len(y_tr) * l2reg.item(),
init=params,
maxiter=20,
@@ -389,17 +397,18 @@ def matvec(u):
def ridge_objective_jax(params, l2reg, X_tr, y_tr):
"""Ridge objective function."""
residuals = X_tr @ params - y_tr
loss = 0.5 * jnp.mean(jnp.square(residuals))
regularization_loss = 0.5 * l2reg * jnp.sum(jnp.square(params))
return 0.5 * jnp.mean(jnp.square(residuals)) + regularization_loss
return loss + regularization_loss

@jaxopt.implicit_diff.custom_root(jax.grad(ridge_objective_jax, argnums=0))
def ridge_solver_jax_cg(params, l2reg, X_tr, y_tr):
def ridge_solver_jax(params, l2reg, X_tr, y_tr):
"""Solve ridge regression by conjugate gradient."""

def matvec(u):
return X_tr.T @ ((X_tr @ u))

return jaxopt.linear_solve.solve_cg(
return jaxopt_solve(
matvec=matvec,
b=X_tr.T @ y_tr,
ridge=len(y_tr) * l2reg.item(),
@@ -413,108 +422,7 @@ def matvec(u):
xq = xq.to(dtype=dtype)
yq = yq.to(dtype=dtype)

w_fit = ridge_solver_torch_cg(init_params_torch, l2reg_torch, (xs, ys))
outer_loss = F.mse_loss(xq @ w_fit, yq)

grads, *_ = torch.autograd.grad(outer_loss, l2reg_torch)
updates, optim_state = optim.update(grads, optim_state)
l2reg_torch = torchopt.apply_updates(l2reg_torch, updates)

xs = jnp.array(xs.numpy(), dtype=np_dtype)
ys = jnp.array(ys.numpy(), dtype=np_dtype)
xq = jnp.array(xq.numpy(), dtype=np_dtype)
yq = jnp.array(yq.numpy(), dtype=np_dtype)

def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq):
w_fit = ridge_solver_jax_cg(params_jax, l2reg_jax, xs, ys)
y_pred = xq @ w_fit
loss_value = jnp.mean(jnp.square(y_pred - yq))
return loss_value

grads_jax = jax.grad(outer_level, argnums=1)(init_params_jax, l2reg_jax, xs, ys, xq, yq)
updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax) # get updates
l2reg_jax = optax.apply_updates(l2reg_jax, updates_jax)

l2reg_jax_as_tensor = torch.tensor(np.asarray(l2reg_jax), dtype=dtype)
helpers.assert_all_close(l2reg_torch, l2reg_jax_as_tensor)


@helpers.parametrize(
dtype=[torch.float64, torch.float32],
lr=[1e-3, 1e-4],
)
def test_rr_solve_inv(
dtype: torch.dtype,
lr: float,
) -> None:
helpers.seed_everything(42)
np_dtype = helpers.dtype_torch2numpy(dtype)
input_size = 10

init_params_torch = torch.randn(input_size, dtype=dtype)
l2reg_torch = torch.rand(1, dtype=dtype).squeeze_().requires_grad_(True)

init_params_jax = jnp.array(init_params_torch.detach().numpy(), dtype=np_dtype)
l2reg_jax = jnp.array(l2reg_torch.detach().numpy(), dtype=np_dtype)

loader = get_rr_dataset_torch()

optim = torchopt.sgd(lr)
optim_state = optim.init(l2reg_torch)

optim_jax = optax.sgd(lr)
optim_state_jax = optim_jax.init(l2reg_jax)

def ridge_objective_torch(params, l2reg, data):
"""Ridge objective function."""
X_tr, y_tr = data
residuals = X_tr @ params - y_tr
regularization_loss = 0.5 * l2reg * torch.sum(torch.square(params))
return 0.5 * torch.mean(torch.square(residuals)) + regularization_loss

@torchopt.diff.implicit.custom_root(functorch.grad(ridge_objective_torch, argnums=0), argnums=1)
def ridge_solver_torch_inv(params, l2reg, data):
"""Solve ridge regression by conjugate gradient."""
X_tr, y_tr = data

def matvec(u):
return X_tr.T @ (X_tr @ u)

solve = torchopt.linear_solve.solve_inv(
matvec=matvec,
b=X_tr.T @ y_tr,
ridge=len(y_tr) * l2reg.item(),
ns=True,
)

return solve(matvec=matvec, b=X_tr.T @ y_tr)

def ridge_objective_jax(params, l2reg, X_tr, y_tr):
"""Ridge objective function."""
residuals = X_tr @ params - y_tr
regularization_loss = 0.5 * l2reg * jnp.sum(jnp.square(params))
return 0.5 * jnp.mean(jnp.square(residuals)) + regularization_loss

@jaxopt.implicit_diff.custom_root(jax.grad(ridge_objective_jax, argnums=0))
def ridge_solver_jax_inv(params, l2reg, X_tr, y_tr):
"""Solve ridge regression by conjugate gradient."""

def matvec(u):
return X_tr.T @ ((X_tr @ u))

return jaxopt.linear_solve.solve_inv(
matvec=matvec,
b=X_tr.T @ y_tr,
ridge=len(y_tr) * l2reg.item(),
)

for xs, ys, xq, yq in loader:
xs = xs.to(dtype=dtype)
ys = ys.to(dtype=dtype)
xq = xq.to(dtype=dtype)
yq = yq.to(dtype=dtype)

w_fit = ridge_solver_torch_inv(init_params_torch, l2reg_torch, (xs, ys))
w_fit = ridge_solver_torch(init_params_torch, l2reg_torch, (xs, ys))
outer_loss = F.mse_loss(xq @ w_fit, yq)

grads, *_ = torch.autograd.grad(outer_loss, l2reg_torch)
@@ -527,7 +435,7 @@ def matvec(u):
yq = jnp.array(yq.numpy(), dtype=np_dtype)

def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq):
w_fit = ridge_solver_jax_inv(params_jax, l2reg_jax, xs, ys)
w_fit = ridge_solver_jax(params_jax, l2reg_jax, xs, ys)
y_pred = xq @ w_fit
loss_value = jnp.mean(jnp.square(y_pred - yq))
return loss_value
10 changes: 7 additions & 3 deletions torchopt/linear_solve/inv.py
@@ -34,7 +34,7 @@
# pylint: disable=invalid-name

import functools
from typing import Callable, Optional
from typing import Callable, Optional, Tuple

import functorch
import torch
@@ -49,9 +49,13 @@
__all__ = ['solve_inv']


def materialize_array(matvec, shape, dtype=None):
def materialize_array(
matvec: Callable[[TensorTree], TensorTree],
shape: Tuple[int, ...],
dtype: Optional[torch.dtype] = None,
) -> TensorTree:
"""Materializes the matrix ``A`` used in ``matvec(x) = A x``."""
x = torch.zeros(shape, dtype)
x = torch.zeros(shape, dtype=dtype)
return functorch.jacfwd(matvec)(x)


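
As a side note on the materialize_array change above: the helper turns a linear matvec into its dense matrix by evaluating the forward-mode Jacobian at a zero input. A minimal sketch of that idea (the names X and matvec here are illustrative, not taken from the TorchOpt test suite):

import functorch
import torch

X = torch.randn(5, 3, dtype=torch.float64)

def matvec(u):
    return X.T @ (X @ u)  # implicitly multiplies by A = X.T @ X

# jacfwd of a linear map, evaluated at zeros, materializes the matrix A.
A = functorch.jacfwd(matvec)(torch.zeros(3, dtype=torch.float64))
print(torch.allclose(A, X.T @ X))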