@@ -218,6 +218,45 @@ def cg(
218
218
219
219
220
220
def ns (
221
+ A : Union [torch .Tensor , Callable [[TensorTree ], TensorTree ]],
222
+ b : TensorTree ,
223
+ maxiter : Optional [int ] = None ,
224
+ * ,
225
+ alpha : Optional [float ] = None ,
226
+ ) -> TensorTree :
227
+ """Use Neumann Series Inverse Matric Approximation to solve ``Ax = b``.
228
+
229
+ Args:
230
+ A: (tensor or tree of tensors or function)
231
+ 2D array or function that calculates the linear map (matrix-vector product) ``Ax`` when
232
+ called like ``A(x)``. ``A`` must represent a hermitian, positive definite matrix, and
233
+ must return array(s) with the same structure and shape as its argument.
234
+ b: (tensor or tree of tensors)
235
+ Right hand side of the linear system representing a single vector. Can be stored as an
236
+ array or Python container of array(s) with any shape.
237
+ maxiter: (integer, optional)
238
+ Maximum number of iterations. Iteration will stop after maxiter steps even if the
239
+ specified tolerance has not been achieved.
240
+ alpha: (float, optional)
241
+ Decay coefficient.
242
+
243
+ Returns:
244
+ the Neumann Series (NS) matrix inversion approximation
245
+ """
246
+ if maxiter is None :
247
+ size = sum (_shapes (b ))
248
+ maxiter = 10 * size
249
+ inv_A_hat_b = b
250
+ for rank in range (maxiter ):
251
+ if alpha is not None :
252
+ b = pytree .tree_sub (b , pytree .tree_mul (alpha , A (b )))
253
+ else :
254
+ b = pytree .tree_sub (b , A (b ))
255
+ inv_A_hat_b = pytree .tree_sub (inv_A_hat_b , b )
256
+ return inv_A_hat_b
257
+
258
+
259
+ def ns_inv (
221
260
A : TensorTree ,
222
261
n : int ,
223
262
maxiter : Optional [int ] = None ,
@@ -229,7 +268,7 @@ def ns(
229
268
Args:
230
269
A: (tensor or tree of tensors or function)
231
270
2D array or function that calculates the linear map (matrix-vector product) ``Ax`` when
232
- called like ``A(x)``. ``A`` must represent a Hermitian , positive definite matrix, and
271
+ called like ``A(x)``. ``A`` must represent a hermitian , positive definite matrix, and
233
272
must return array(s) with the same structure and shape as its argument.
234
273
n: (integer)
235
274
Number of rows (and columns) in n x n output.
0 commit comments