# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+import math

import torch
import jax
@@ -79,39 +80,44 @@ def _normalize_matvec(f):

def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
  # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg
-  bs = _vdot_real_tree(b, b)
-  # atol2 = max(tol ** 2 * bs, atol ** 2)
-  atol2 = jax.tree_util.tree_map(lambda bs: max(tol ** 2 * bs, atol ** 2), bs)
+  bs = sum(_vdot_real_tree(b, b))
+  atol2 = max(tol ** 2 * bs, atol ** 2)
+  # atol2 = jax.tree_util.tree_map(lambda bs: max(tol ** 2 * bs, atol ** 2), bs)

  # https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method

-  def cond_fun(value):
+  min_rs = math.inf
+  def cond_fun(value, min_rs):
    _, r, gamma, _, k = value
-    rs = gamma if M is _identity else _vdot_real_tree(r, r)
-    return (rs > atol2) & (k < maxiter)
+    rs = gamma if M is _identity else sum(_vdot_real_tree(r, r))
+    return (rs > atol2) & (k < maxiter) & (rs <= min_rs), rs

  def body_fun(value):
    x, r, gamma, p, k = value
    Ap = A(p)
    # alpha = gamma / _vdot_real_tree(p, Ap)
-    alpha = jax.tree_util.tree_map(lambda gamma, inner_product: gamma / inner_product, gamma, _vdot_real_tree(p, Ap))
-    x_ = jax.tree_util.tree_map(torch.add, x, jax.tree_util.tree_map(torch.mul, alpha, p))
-    r_ = jax.tree_util.tree_map(torch.sub, r, jax.tree_util.tree_map(torch.mul, alpha, Ap))
+    alpha = gamma / sum(_vdot_real_tree(p, Ap))
+    # alpha = jax.tree_util.tree_map(lambda gamma, inner_product: gamma / inner_product, gamma, _vdot_real_tree(p, Ap))
+    x_ = jax.tree_util.tree_map(lambda a, b: a.add(b, alpha=alpha), x, p)
+    r_ = jax.tree_util.tree_map(lambda a, b: a.sub(b, alpha=alpha), r, Ap)
    z_ = M(r_)
-    gamma_ = _vdot_real_tree(r_, z_)
+    gamma_ = sum(_vdot_real_tree(r_, z_))
    # beta_ = gamma_ / gamma
-    beta_ = jax.tree_util.tree_map(torch.div, gamma_, gamma)
-    p_ = jax.tree_util.tree_map(torch.add, z_, jax.tree_util.tree_map(torch.mul, beta_, p))
+    beta_ = gamma_ / gamma
+    p_ = jax.tree_util.tree_map(lambda a, b: a.add(b, alpha=beta_), z_, p)
    return x_, r_, gamma_, p_, k + 1

  r0 = jax.tree_util.tree_map(torch.sub, b, A(x0))
  p0 = z0 = M(r0)
-  gamma0 = _vdot_real_tree(r0, z0)
+  gamma0 = sum(_vdot_real_tree(r0, z0))
  initial_value = (x0, r0, gamma0, p0, 0)

  value = initial_value
-  while cond_fun(value):
+  not_stop, min_rs = cond_fun(value, min_rs)
+  while not_stop:
    value = body_fun(value)
+    not_stop, rs = cond_fun(value, min_rs)
+    min_rs = min(rs, min_rs)

  x_final, *_ = value

@@ -130,7 +136,7 @@ def _isolve(_isolve_solve, A, b, x0=None, *, tol=1e-5, atol=0.0,

  if maxiter is None:
    size = sum(_shapes(b))
-    maxiter = 10 * size  # copied from scipy
+    maxiter = size  # scipy's default is 10 * size

  if M is None:
    M = _identity
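The diff keeps the preconditioned CG recurrence but makes two changes: the per-leaf dot products from _vdot_real_tree are reduced to one scalar with sum(...), so alpha, beta_ and the tolerance test operate on plain scalars, and the loop gains an extra stopping condition that exits as soon as the squared residual rs stops decreasing (rs > min_rs). The sketch below is a minimal, self-contained illustration of that driver pattern on a 2x2 system, assuming plain Python lists stand in for tensor pytrees and Python `and` replaces the tensor `&`; the names vdot_real_parts, matvec and cg_solve are illustrative stand-ins, not the module's actual API.

import math

def vdot_real_parts(xs, ys):
  # stand-in for _vdot_real_tree: one real dot product per leaf of the "tree"
  return [sum(xi * yi for xi, yi in zip(x, y)) for x, y in zip(xs, ys)]

def cg_solve(matvec, b, x0, *, maxiter, tol=1e-5, atol=0.0):
  bs = sum(vdot_real_parts(b, b))
  atol2 = max(tol ** 2 * bs, atol ** 2)

  def cond_fun(value, min_rs):
    _, r, gamma, _, k = value
    rs = gamma  # no preconditioner here, so gamma is already ||r||^2
    # continue only while above tolerance, within budget, and not diverging
    return (rs > atol2) and (k < maxiter) and (rs <= min_rs), rs

  def body_fun(value):
    x, r, gamma, p, k = value
    Ap = matvec(p)
    alpha = gamma / sum(vdot_real_parts(p, Ap))
    x_ = [[xi + alpha * pi for xi, pi in zip(xl, pl)] for xl, pl in zip(x, p)]
    r_ = [[ri - alpha * ai for ri, ai in zip(rl, al)] for rl, al in zip(r, Ap)]
    gamma_ = sum(vdot_real_parts(r_, r_))
    beta_ = gamma_ / gamma
    p_ = [[ri + beta_ * pi for ri, pi in zip(rl, pl)] for rl, pl in zip(r_, p)]
    return x_, r_, gamma_, p_, k + 1

  r0 = [[bi - ai for bi, ai in zip(bl, al)] for bl, al in zip(b, matvec(x0))]
  gamma0 = sum(vdot_real_parts(r0, r0))
  value = (x0, r0, gamma0, r0, 0)

  min_rs = math.inf
  not_stop, min_rs = cond_fun(value, min_rs)
  while not_stop:
    value = body_fun(value)
    not_stop, rs = cond_fun(value, min_rs)
    min_rs = min(rs, min_rs)  # best squared residual seen so far
  x_final, *_ = value
  return x_final

# usage: one-leaf "tree", 2x2 symmetric positive-definite system
A = [[4.0, 1.0], [1.0, 3.0]]
matvec = lambda v: [[sum(A[i][j] * v[0][j] for j in range(2)) for i in range(2)]]
print(cg_solve(matvec, b=[[1.0, 2.0]], x0=[[0.0, 0.0]], maxiter=20))
# -> [[0.0909..., 0.6363...]]  (i.e. [1/11, 7/11])

With this structure the loop stops at the first iteration whose residual exceeds the best residual seen so far, in addition to the usual tolerance and maxiter checks.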