Skip to content

Commit 180f62f

Browse files
committed
feat: another ns implementation to accelerate A^-1b
1 parent 425140d commit 180f62f

File tree

2 files changed

+58
-2
lines changed

2 files changed

+58
-2
lines changed

torchopt/linalg.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,45 @@ def cg(
218218

219219

220220
def ns(
221+
A: Union[torch.Tensor, Callable[[TensorTree], TensorTree]],
222+
b: TensorTree,
223+
maxiter: Optional[int] = None,
224+
*,
225+
alpha: Optional[float] = None,
226+
) -> TensorTree:
227+
"""Use Neumann Series Inverse Matric Approximation to solve ``Ax = b``.
228+
229+
Args:
230+
A: (tensor or tree of tensors or function)
231+
2D array or function that calculates the linear map (matrix-vector product) ``Ax`` when
232+
called like ``A(x)``. ``A`` must represent a hermitian, positive definite matrix, and
233+
must return array(s) with the same structure and shape as its argument.
234+
b: (tensor or tree of tensors)
235+
Right hand side of the linear system representing a single vector. Can be stored as an
236+
array or Python container of array(s) with any shape.
237+
maxiter: (integer, optional)
238+
Maximum number of iterations. Iteration will stop after maxiter steps even if the
239+
specified tolerance has not been achieved.
240+
alpha: (float, optional)
241+
Decay coefficient.
242+
243+
Returns:
244+
the Neumann Series (NS) matrix inversion approximation
245+
"""
246+
if maxiter is None:
247+
size = sum(_shapes(b))
248+
maxiter = 10 * size
249+
inv_A_hat_b = b
250+
for rank in range(maxiter):
251+
if alpha is not None:
252+
b = pytree.tree_sub(b, pytree.tree_mul(alpha, A(b)))
253+
else:
254+
b = pytree.tree_sub(b, A(b))
255+
inv_A_hat_b = pytree.tree_sub(inv_A_hat_b, b)
256+
return inv_A_hat_b
257+
258+
259+
def ns_inv(
221260
A: TensorTree,
222261
n: int,
223262
maxiter: Optional[int] = None,
@@ -229,7 +268,7 @@ def ns(
229268
Args:
230269
A: (tensor or tree of tensors or function)
231270
2D array or function that calculates the linear map (matrix-vector product) ``Ax`` when
232-
called like ``A(x)``. ``A`` must represent a Hermitian, positive definite matrix, and
271+
called like ``A(x)``. ``A`` must represent a hermitian, positive definite matrix, and
233272
must return array(s) with the same structure and shape as its argument.
234273
n: (integer)
235274
Number of rows (and columns) in n x n output.

torchopt/pytree.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
"""The PyTree utilities."""
16+
import functools
17+
import itertools
18+
import operator
1619

1720
import optree
1821
import optree.typing as typing # pylint: disable=unused-import
@@ -22,7 +25,21 @@
2225
from torchopt.typing import Future, PyTree, RRef, T
2326

2427

25-
__all__ = [*optree.__all__, 'tree_wait']
28+
__all__ = [*optree.__all__, 'tree_wait', 'tree_add', 'tree_sub', 'tree_mul', 'tree_div']
29+
30+
tree_map = optree.tree_map
31+
32+
tree_add = functools.partial(tree_map, operator.add)
33+
tree_add.__doc__ = "Tree addition."
34+
35+
tree_sub = functools.partial(tree_map, operator.sub)
36+
tree_sub.__doc__ = "Tree subtraction."
37+
38+
tree_mul = functools.partial(tree_map, operator.mul)
39+
tree_mul.__doc__ = "Tree multiplication."
40+
41+
tree_div = functools.partial(tree_map, operator.truediv)
42+
tree_div.__doc__ = "Tree division."
2643

2744

2845
def tree_wait(future_tree: PyTree[Future[T]]) -> PyTree[T]:

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy