
feat(linear_solve): matrix inversion linear solver with neumann series approximation #98
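For background on the feature named in the title: a Neumann series approximates a matrix inverse by the truncated power series A^{-1} ≈ sum_{k=0}^{K} (I - A)^k, which converges whenever the spectral radius of I - A is below 1. A minimal PyTorch sketch of the idea follows; it is illustrative only (`neumann_solve` is a hypothetical helper, not the torchopt implementation added in this PR).

import torch

def neumann_solve(A: torch.Tensor, b: torch.Tensor, maxiter: int = 100) -> torch.Tensor:
    """Approximate the solution of A x = b with a truncated Neumann series.

    Uses A^{-1} b ~= sum_{k=0}^{K} (I - A)^k b, valid when the spectral
    radius of (I - A) is below 1. Hypothetical helper for illustration;
    not torchopt's implementation.
    """
    x = b.clone()     # k = 0 term: (I - A)^0 b = b
    term = b.clone()
    for _ in range(maxiter):
        term = term - A @ term  # term <- (I - A) @ term
        x = x + term            # accumulate the next series term
    return x

# Sanity check on a well-conditioned system where the series converges.
A = torch.eye(4) + 0.05 * torch.randn(4, 4)
b = torch.randn(4)
assert torch.allclose(neumann_solve(A, b), torch.linalg.solve(A, b), atol=1e-5)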


Merged

Changes from 28 commits (38 commits total).

Commits
d158462  feat(linear_solve): matrix inversion linear solver with neumann serie… (XuehaiPan, Oct 13, 2022)
e171f38  fix: fix ns (XuehaiPan, Nov 6, 2022)
db88ee3  fix: add test (Benjamin-eecs, Nov 6, 2022)
0946dbe  test: update tests (XuehaiPan, Nov 6, 2022)
b3cef2f  fix: add test (Benjamin-eecs, Nov 6, 2022)
e3a38a7  fix: add test (Benjamin-eecs, Nov 7, 2022)
de27c02  fix: add test (Benjamin-eecs, Nov 7, 2022)
183e769  fix: pass test (Benjamin-eecs, Nov 7, 2022)
155ab8d  fix: pass lint (Benjamin-eecs, Nov 7, 2022)
081d0b1  fix: pass lint (Benjamin-eecs, Nov 7, 2022)
7fc64b7  fix: pass lint (Benjamin-eecs, Nov 7, 2022)
1199814  fix: pass lint (Benjamin-eecs, Nov 7, 2022)
ede0cba  fix: update test (Benjamin-eecs, Nov 7, 2022)
54f8e01  chore: update CHANGELOG (Benjamin-eecs, Nov 7, 2022)
dc49c4f  merge: resolve conflicts (Benjamin-eecs, Nov 7, 2022)
d1d316c  fix: resolve comments (Benjamin-eecs, Nov 7, 2022)
31cb010  docs: add solve_inv (Benjamin-eecs, Nov 7, 2022)
909bd81  chore: update Makefile (Benjamin-eecs, Nov 7, 2022)
ff54c48  docs: update (Benjamin-eecs, Nov 7, 2022)
7cb77c8  wip (XuehaiPan, Nov 7, 2022)
03d6972  wip (XuehaiPan, Nov 7, 2022)
e8c4b38  wip (XuehaiPan, Nov 7, 2022)
fcf5148  wip (XuehaiPan, Nov 7, 2022)
8cd84bb  fix: pass test (Benjamin-eecs, Nov 7, 2022)
4a7db3c  wip (XuehaiPan, Nov 8, 2022)
8865eb6  wip (XuehaiPan, Nov 8, 2022)
afbcf11  wip (XuehaiPan, Nov 8, 2022)
77cd2ee  wip (XuehaiPan, Nov 8, 2022)
6753929  wip (XuehaiPan, Nov 8, 2022)
70ec405  feat: support normalize matvec with tensortree (XuehaiPan, Nov 8, 2022)
2bd55ba  feat: support implicit matvec (XuehaiPan, Nov 8, 2022)
564c7ec  feat: support implicit matvec (XuehaiPan, Nov 8, 2022)
4e7c1f9  fix: fix jacobian tree compose (XuehaiPan, Nov 9, 2022)
071026a  feat: multi-tensor support for solve_inv (XuehaiPan, Nov 9, 2022)
08de4bc  docs: update linear_solve docs (XuehaiPan, Nov 9, 2022)
5a0e458  chore: update ns_inv (XuehaiPan, Nov 9, 2022)
71ddf58  chore: add shortcuts (XuehaiPan, Nov 9, 2022)
67b774d  docs: update dictionary (XuehaiPan, Nov 9, 2022)
.pylintrc: 3 changes (2 additions, 1 deletion)

@@ -267,7 +267,8 @@ good-names=i,
            lr,
            mu,
            nu,
-           x
+           x,
+           y

 # Good variable names regexes, separated by a comma. If names match any regex,
 # they will always be accepted
CHANGELOG.md: 1 change (1 addition, 0 deletions)

@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

 ### Added

+- Add matrix inversion linear solver with Neumann series approximation by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#98](https://github.com/metaopt/torchopt/pull/98).
 - Add if condition of number of threads for CPU OPs by [@JieRen98](https://github.com/JieRen98) in [#105](https://github.com/metaopt/torchopt/pull/105).
 - Add implicit MAML omniglot few-shot classification example with OOP APIs by [@XuehaiPan](https://github.com/XuehaiPan) in [#107](https://github.com/metaopt/torchopt/pull/107).
 - Add implicit MAML omniglot few-shot classification example by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#48](https://github.com/metaopt/torchopt/pull/48).
docs/source/api/api.rst: 2 changes (2 additions, 0 deletions)

@@ -166,12 +166,14 @@ Linear system solving

     solve_cg
     solve_normal_cg
+    solve_inv

 Indirect solvers
 ~~~~~~~~~~~~~~~~

 .. autofunction:: solve_cg
 .. autofunction:: solve_normal_cg
+.. autofunction:: solve_inv

 ------
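Like solve_cg and solve_normal_cg, the new solve_inv is exported from torchopt.linear_solve. Below is a hedged usage sketch, assuming solve_inv follows the same factory pattern as the existing solvers (bind solver options first, then call the returned solver with matvec and b, as the tests further down do); the near-identity matrix is chosen so the Neumann-series variant would also converge.

import torch
import torchopt

A = torch.eye(5) + 0.05 * torch.randn(5, 5)  # near identity, well conditioned
b = torch.randn(5)

def matvec(u):
    return A @ u

# Bind options first, then call with matvec and b (the pattern in the tests below).
solve = torchopt.linear_solve.solve_inv()  # default: materialize the matrix and invert
x = solve(matvec=matvec, b=b)
print(torch.allclose(x, torch.linalg.solve(A, b), atol=1e-5))  # expected: True

solve_ns = torchopt.linear_solve.solve_inv(ns=True)  # Neumann-series approximation
x_ns = solve_ns(matvec=matvec, b=b)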
tests/test_implicit.py: 116 changes (111 additions, 5 deletions)

@@ -24,6 +24,7 @@
 import jaxopt
 import numpy as np
 import optax
+import pytest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

@@ -341,7 +342,7 @@ def outer_level(p, xs, ys):
     dtype=[torch.float64, torch.float32],
     lr=[1e-3, 1e-4],
 )
-def test_rr(
+def test_rr_solve_cg(
     dtype: torch.dtype,
     lr: float,
 ) -> None:

@@ -371,7 +372,7 @@ def ridge_objective_torch(params, l2reg, data):
         return 0.5 * torch.mean(torch.square(residuals)) + regularization_loss

     @torchopt.diff.implicit.custom_root(functorch.grad(ridge_objective_torch, argnums=0), argnums=1)
-    def ridge_solver_torch(params, l2reg, data):
+    def ridge_solver_torch_cg(params, l2reg, data):
         """Solve ridge regression by conjugate gradient."""
         X_tr, y_tr = data

@@ -393,7 +394,7 @@ def ridge_objective_jax(params, l2reg, X_tr, y_tr):
         return 0.5 * jnp.mean(jnp.square(residuals)) + regularization_loss

     @jaxopt.implicit_diff.custom_root(jax.grad(ridge_objective_jax, argnums=0))
-    def ridge_solver_jax(params, l2reg, X_tr, y_tr):
+    def ridge_solver_jax_cg(params, l2reg, X_tr, y_tr):
         """Solve ridge regression by conjugate gradient."""

         def matvec(u):

@@ -413,7 +414,7 @@ def matvec(u):
         xq = xq.to(dtype=dtype)
         yq = yq.to(dtype=dtype)

-        w_fit = ridge_solver_torch(init_params_torch, l2reg_torch, (xs, ys))
+        w_fit = ridge_solver_torch_cg(init_params_torch, l2reg_torch, (xs, ys))
         outer_loss = F.mse_loss(xq @ w_fit, yq)

         grads, *_ = torch.autograd.grad(outer_loss, l2reg_torch)

@@ -426,7 +427,112 @@ def matvec(u):
         yq = jnp.array(yq.numpy(), dtype=np_dtype)

         def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq):
-            w_fit = ridge_solver_jax(params_jax, l2reg_jax, xs, ys)
+            w_fit = ridge_solver_jax_cg(params_jax, l2reg_jax, xs, ys)
             y_pred = xq @ w_fit
             loss_value = jnp.mean(jnp.square(y_pred - yq))
             return loss_value

         grads_jax = jax.grad(outer_level, argnums=1)(init_params_jax, l2reg_jax, xs, ys, xq, yq)
         updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax)  # get updates
         l2reg_jax = optax.apply_updates(l2reg_jax, updates_jax)

         l2reg_jax_as_tensor = torch.tensor(np.asarray(l2reg_jax), dtype=dtype)
         helpers.assert_all_close(l2reg_torch, l2reg_jax_as_tensor)
+
+
+@helpers.parametrize(
+    dtype=[torch.float64, torch.float32],
+    lr=[1e-3, 1e-4],
+    ns=[True, False],
+)
+def test_rr_solve_inv(
+    dtype: torch.dtype,
+    lr: float,
+    ns: bool,
+) -> None:
+    if dtype == torch.float64 and ns:
+        pytest.skip('Neumann series test skipped for torch.float64 due to numerical stability.')
+    helpers.seed_everything(42)
+    np_dtype = helpers.dtype_torch2numpy(dtype)
+    input_size = 10
+
+    init_params_torch = torch.randn(input_size, dtype=dtype)
+    l2reg_torch = torch.rand(1, dtype=dtype).squeeze_().requires_grad_(True)
+
+    init_params_jax = jnp.array(init_params_torch.detach().numpy(), dtype=np_dtype)
+    l2reg_jax = jnp.array(l2reg_torch.detach().numpy(), dtype=np_dtype)
+
+    loader = get_rr_dataset_torch()
+
+    optim = torchopt.sgd(lr)
+    optim_state = optim.init(l2reg_torch)
+
+    optim_jax = optax.sgd(lr)
+    optim_state_jax = optim_jax.init(l2reg_jax)
+
+    def ridge_objective_torch(params, l2reg, data):
+        """Ridge objective function."""
+        X_tr, y_tr = data
+        residuals = X_tr @ params - y_tr
+        regularization_loss = 0.5 * l2reg * torch.sum(torch.square(params))
+        return 0.5 * torch.mean(torch.square(residuals)) + regularization_loss
+
+    @torchopt.diff.implicit.custom_root(functorch.grad(ridge_objective_torch, argnums=0), argnums=1)
+    def ridge_solver_torch_inv(params, l2reg, data):
+        """Solve ridge regression by matrix inversion."""
+        X_tr, y_tr = data
+
+        def matvec(u):
+            return X_tr.T @ (X_tr @ u)
+
+        solve = torchopt.linear_solve.solve_inv(
+            matvec=matvec,
+            b=X_tr.T @ y_tr,
+            ridge=len(y_tr) * l2reg.item(),
+            ns=ns,
+        )
+
+        return solve(matvec=matvec, b=X_tr.T @ y_tr)
+
+    def ridge_objective_jax(params, l2reg, X_tr, y_tr):
+        """Ridge objective function."""
+        residuals = X_tr @ params - y_tr
+        regularization_loss = 0.5 * l2reg * jnp.sum(jnp.square(params))
+        return 0.5 * jnp.mean(jnp.square(residuals)) + regularization_loss
+
+    @jaxopt.implicit_diff.custom_root(jax.grad(ridge_objective_jax, argnums=0))
+    def ridge_solver_jax_inv(params, l2reg, X_tr, y_tr):
+        """Solve ridge regression by matrix inversion."""
+
+        def matvec(u):
+            return X_tr.T @ (X_tr @ u)
+
+        return jaxopt.linear_solve.solve_inv(
+            matvec=matvec,
+            b=X_tr.T @ y_tr,
+            ridge=len(y_tr) * l2reg.item(),
+        )
+
+    for xs, ys, xq, yq in loader:
+        xs = xs.to(dtype=dtype)
+        ys = ys.to(dtype=dtype)
+        xq = xq.to(dtype=dtype)
+        yq = yq.to(dtype=dtype)
+
+        w_fit = ridge_solver_torch_inv(init_params_torch, l2reg_torch, (xs, ys))
+        outer_loss = F.mse_loss(xq @ w_fit, yq)
+
+        grads, *_ = torch.autograd.grad(outer_loss, l2reg_torch)
+        updates, optim_state = optim.update(grads, optim_state)
+        l2reg_torch = torchopt.apply_updates(l2reg_torch, updates)
+
+        xs = jnp.array(xs.numpy(), dtype=np_dtype)
+        ys = jnp.array(ys.numpy(), dtype=np_dtype)
+        xq = jnp.array(xq.numpy(), dtype=np_dtype)
+        yq = jnp.array(yq.numpy(), dtype=np_dtype)
+
+        def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq):
+            w_fit = ridge_solver_jax_inv(params_jax, l2reg_jax, xs, ys)
+            y_pred = xq @ w_fit
+            loss_value = jnp.mean(jnp.square(y_pred - yq))
+            return loss_value
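Why both solvers in these tests are handed matvec: u -> X^T X u, b = X^T y, and ridge = n * l2reg: the ridge objective's stationarity condition is exactly that linear system. With f(w, lambda) as defined in ridge_objective_torch above:

\nabla_w f(w, \lambda)
  = \tfrac{1}{n} X^\top (X w - y) + \lambda w = 0
\quad\Longleftrightarrow\quad
\left( X^\top X + n \lambda I \right) w = X^\top y .

custom_root then differentiates the returned solution through this optimality condition instead of backpropagating through the solver's internal iterations.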
torchopt/linalg/__init__.py: 3 changes (2 additions, 1 deletion)

@@ -32,6 +32,7 @@
 """Linear algebra functions."""

 from torchopt.linalg.cg import cg
+from torchopt.linalg.ns import ns, ns_inv


-__all__ = ['cg']
+__all__ = ['cg', 'ns', 'ns_inv']
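The two new exports differ in what they approximate: judging by the names and their use in this PR, ns approximates the solution A^{-1} b from a matvec, while ns_inv approximates the inverse matrix itself. A sketch of the latter under that assumption (illustrative; ns_inv_sketch is a hypothetical helper, not the torchopt code):

import torch

def ns_inv_sketch(A: torch.Tensor, maxiter: int = 100) -> torch.Tensor:
    """Truncated Neumann series for A^{-1}: sum_{k=0}^{K} (I - A)^k."""
    I = torch.eye(A.size(0), dtype=A.dtype)
    M = I - A
    inv_approx = torch.zeros_like(A)
    term = I.clone()            # (I - A)^0
    for _ in range(maxiter):
        inv_approx = inv_approx + term
        term = term @ M         # next power of (I - A)
    return inv_approx

A = torch.eye(3) + 0.1 * torch.randn(3, 3)
assert torch.allclose(ns_inv_sketch(A), torch.linalg.inv(A), atol=1e-5)

The practical trade-off: a matvec-based series (ns) only needs matrix-vector products per step, while materializing the inverse (ns_inv) costs a matrix-matrix product per step but yields a reusable A^{-1}.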
torchopt/linalg/cg.py: 32 changes (6 additions, 26 deletions)

@@ -34,34 +34,19 @@
 # pylint: disable=invalid-name

 from functools import partial
-from typing import Callable, List, Optional, Union
+from typing import Callable, Optional, Union

 import torch

 from torchopt import pytree
+from torchopt.linalg.utils import cat_shapes
+from torchopt.pytree import tree_vdot_real
 from torchopt.typing import TensorTree


 __all__ = ['cg']


-def _vdot_real_kernel(x: torch.Tensor, y: torch.Tensor) -> float:
-    """Computes dot(x.conj(), y).real."""
-    x = x.contiguous().view(-1)
-    y = y.contiguous().view(-1)
-    prod = torch.dot(x.real, y.real).item()
-    if x.is_complex() and y.is_complex():
-        prod += torch.dot(x.imag, y.imag).item()
-    return prod
-
-
-def tree_vdot_real(tree_x: TensorTree, tree_y: TensorTree) -> float:
-    """Computes dot(tree_x.conj(), tree_y).real.sum()."""
-    leaves_x, treespec = pytree.tree_flatten(tree_x)
-    leaves_y = treespec.flatten_up_to(tree_y)
-    return sum(map(_vdot_real_kernel, leaves_x, leaves_y))  # type: ignore[arg-type]
-
-
 def _identity(x: TensorTree) -> TensorTree:
     return x

@@ -126,11 +111,6 @@ def body_fn(value):
     return x_final


-def _shapes(tree: TensorTree) -> List[int]:
-    flattened = pytree.tree_leaves(tree)
-    return pytree.tree_leaves([tuple(term.shape) for term in flattened])  # type: ignore[arg-type]
-
-
 def _isolve(
     _isolve_solve: Callable,
     A: Union[torch.Tensor, Callable[[TensorTree], TensorTree]],

@@ -146,17 +126,17 @@ def _isolve(
     x0 = pytree.tree_map(torch.zeros_like, b)

     if maxiter is None:
-        size = sum(_shapes(b))
+        size = sum(cat_shapes(b))
         maxiter = 10 * size  # copied from SciPy

     if M is None:
         M = _identity
     A = _normalize_matvec(A)
     M = _normalize_matvec(M)

-    if _shapes(x0) != _shapes(b):
+    if cat_shapes(x0) != cat_shapes(b):
         raise ValueError(
-            'arrays in x0 and b must have matching shapes: ' f'{_shapes(x0)} vs {_shapes(b)}'
+            f'Tensors in x0 and b must have matching shapes: {cat_shapes(x0)} vs. {cat_shapes(b)}.'
         )

     isolve_solve = partial(_isolve_solve, x0=x0, rtol=rtol, atol=atol, maxiter=maxiter, M=M)
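The deleted helpers moved to shared modules rather than disappearing: tree_vdot_real is now imported from torchopt.pytree, and cat_shapes (from torchopt.linalg.utils) replaces the local _shapes. For reference, the semantics of the removed kernel, taken verbatim from the deleted code above, on a small pytree:

import torch

def vdot_real(x: torch.Tensor, y: torch.Tensor) -> float:
    """dot(x.conj(), y).real, as in the deleted _vdot_real_kernel above."""
    x = x.contiguous().view(-1)
    y = y.contiguous().view(-1)
    prod = torch.dot(x.real, y.real).item()
    if x.is_complex() and y.is_complex():
        prod += torch.dot(x.imag, y.imag).item()
    return prod

tree_x = {'a': torch.tensor([1.0, 2.0]), 'b': torch.tensor([3.0])}
tree_y = {'a': torch.tensor([4.0, 5.0]), 'b': torch.tensor([6.0])}
# Summed over leaves: (1*4 + 2*5) + 3*6 = 32.0
total = sum(vdot_real(tree_x[k], tree_y[k]) for k in tree_x)
assert total == 32.0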