Skip to content

Commit 07a4e1e

Browse files
committed
fix(implicit): fix cg convergence problem
1 parent addce03 commit 07a4e1e

File tree

1 file changed

+21
-15
lines changed

1 file changed

+21
-15
lines changed

TorchOpt/_src/linalg.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
# See the License for the specific language governing permissions and
3030
# limitations under the License.
3131
# ==============================================================================
32+
import math
3233

3334
import torch
3435
import jax
@@ -79,39 +80,44 @@ def _normalize_matvec(f):
7980

8081
def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
8182
# tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg
82-
bs = _vdot_real_tree(b, b)
83-
# atol2 = max(tol ** 2 * bs, atol ** 2)
84-
atol2 = jax.tree_util.tree_map(lambda bs: max(tol ** 2 * bs, atol ** 2), bs)
83+
bs = sum(_vdot_real_tree(b, b))
84+
atol2 = max(tol ** 2 * bs, atol ** 2)
85+
# atol2 = jax.tree_util.tree_map(lambda bs: max(tol ** 2 * bs, atol ** 2), bs)
8586

8687
# https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method
8788

88-
def cond_fun(value):
89+
min_rs = math.inf
90+
def cond_fun(value, min_rs):
8991
_, r, gamma, _, k = value
90-
rs = gamma if M is _identity else _vdot_real_tree(r, r)
91-
return (rs > atol2) & (k < maxiter)
92+
rs = gamma if M is _identity else sum(_vdot_real_tree(r, r))
93+
return (rs > atol2) & (k < maxiter) & (rs <= min_rs), rs
9294

9395
def body_fun(value):
9496
x, r, gamma, p, k = value
9597
Ap = A(p)
9698
# alpha = gamma / _vdot_real_tree(p, Ap)
97-
alpha = jax.tree_util.tree_map(lambda gamma, inner_product: gamma / inner_product, gamma, _vdot_real_tree(p, Ap))
98-
x_ = jax.tree_util.tree_map(torch.add, x, jax.tree_util.tree_map(torch.mul, alpha, p))
99-
r_ = jax.tree_util.tree_map(torch.sub, r, jax.tree_util.tree_map(torch.mul, alpha, Ap))
99+
alpha = gamma / sum(_vdot_real_tree(p, Ap))
100+
# alpha = jax.tree_util.tree_map(lambda gamma, inner_product: gamma / inner_product, gamma, _vdot_real_tree(p, Ap))
101+
x_ = jax.tree_util.tree_map(lambda a, b: a.add(b, alpha=alpha), x, p)
102+
r_ = jax.tree_util.tree_map(lambda a, b: a.sub(b, alpha=alpha), r, Ap)
100103
z_ = M(r_)
101-
gamma_ = _vdot_real_tree(r_, z_)
104+
gamma_ = sum(_vdot_real_tree(r_, z_))
102105
# beta_ = gamma_ / gamma
103-
beta_ = jax.tree_util.tree_map(torch.div, gamma_, gamma)
104-
p_ = jax.tree_util.tree_map(torch.add, z_, jax.tree_util.tree_map(torch.mul, beta_, p))
106+
beta_ = gamma_ / gamma
107+
p_ = jax.tree_util.tree_map(lambda a, b: a.add(b, alpha=beta_), z_, p)
105108
return x_, r_, gamma_, p_, k + 1
106109

107110
r0 = jax.tree_util.tree_map(torch.sub, b, A(x0))
108111
p0 = z0 = M(r0)
109-
gamma0 = _vdot_real_tree(r0, z0)
112+
gamma0 = sum(_vdot_real_tree(r0, z0))
110113
initial_value = (x0, r0, gamma0, p0, 0)
111114

112115
value = initial_value
113-
while cond_fun(value):
116+
not_stop, min_rs = cond_fun(value, min_rs)
117+
while not_stop:
114118
value = body_fun(value)
119+
not_stop, rs = cond_fun(value, min_rs)
120+
min_rs = min(rs, min_rs)
115121

116122
x_final, *_ = value
117123

@@ -130,7 +136,7 @@ def _isolve(_isolve_solve, A, b, x0=None, *, tol=1e-5, atol=0.0,
130136

131137
if maxiter is None:
132138
size = sum(_shapes(b))
133-
maxiter = 10 * size # copied from scipy
139+
maxiter = size # copied from scipy
134140

135141
if M is None:
136142
M = _identity

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy