 from torchopt import pytree
 from torchopt.base import GradientTransformation
-from torchopt.transform.utils import tree_map_flat
+from torchopt.transform.utils import tree_map_flat, update_moment
 from torchopt.typing import OptState, Params, Updates
@@ -63,8 +63,10 @@ def scale_by_rss(
         - McMahan et al., 2010: https://arxiv.org/abs/1002.4908

     Args:
-        initial_accumulator_value: Starting value for accumulators, must be >= 0.
-        eps: A small floating point value to avoid zero denominator.
+        initial_accumulator_value (float, optional): Starting value for accumulators, must be
+            ``>= 0``. (default: :const:`0.0`)
+        eps (float, optional): A small floating point value to avoid zero denominator.
+            (default: :const:`1e-10`)

     Returns:
         An (init_fn, update_fn) tuple.
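
The reworked docstring above documents the two knobs of this transformation. As a quick orientation, here is a minimal usage sketch assuming the standard init/update interface of TorchOpt's GradientTransformation; the import path, the toy tensors, and the keyword call style are illustrative assumptions, not taken from this diff.

import torch

from torchopt.transform import scale_by_rss  # assumed public import path

# Toy parameter tree and a matching gradient tree.
params = {'w': torch.zeros(3)}
grads = {'w': torch.tensor([1.0, -2.0, 0.5])}

# Defaults match the documented values: the accumulator starts at 0.0, eps is 1e-10.
rss = scale_by_rss(initial_accumulator_value=0.0, eps=1e-10)
state = rss.init(params)

# Each call adds g**2 to the accumulator and rescales g by 1 / sqrt(accumulator).
scaled_grads, state = rss.update(grads, state, params=params, inplace=False)
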
@@ -115,32 +117,34 @@ def update_fn(
         params: Params | None = None,  # pylint: disable=unused-argument
         inplace: bool = True,
     ) -> tuple[Updates, OptState]:
-        sum_of_squares = tree_map(
-            lambda g, t: t + (g.conj() * g).real,
+        sum_of_squares = update_moment.impl(  # type: ignore[attr-defined]
             updates,
             state.sum_of_squares,
+            decay=1.0,
+            order=2,
+            inplace=inplace,
+            already_flattened=already_flattened,
         )

         if inplace:

-            def f(t: torch.Tensor) -> torch.Tensor:
+            def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor:
                 return torch.where(
-                    t > 0.0,
-                    torch.ones_like(t).div_(t.sqrt().add_(eps)),
-                    torch.tensor(0.0),
+                    sos > 0.0,
+                    g.div_(sos.sqrt().add_(eps)),
+                    0.0,
                 )

         else:

-            def f(t: torch.Tensor) -> torch.Tensor:
+            def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor:
                 return torch.where(
-                    t > 0.0,
-                    torch.ones_like(t).div(t.sqrt().add(eps)),
-                    torch.tensor(0.0),
+                    sos > 0.0,
+                    g.div(sos.sqrt().add(eps)),
+                    0.0,
                 )

-        inv_sqrt_g_square = tree_map(f, sum_of_squares)
-        updates = tree_map(lambda scale, g: g * scale, inv_sqrt_g_square, updates)
+        updates = tree_map(f, updates, sum_of_squares)
         return updates, ScaleByRssState(sum_of_squares=sum_of_squares)

     return GradientTransformation(init_fn, update_fn)
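
For reference, the refactor does not change what update_fn computes: it is still the RSS/AdaGrad-style rule that accumulates squared gradients without decay and divides each update by the square root of the accumulator. The single-tensor sketch below restates that rule; reading decay=1.0, order=2 as "add the squared gradient to the running sum" is an assumption about update_moment.impl inferred from this diff, not a documented contract.

import torch

def rss_scale_single(g: torch.Tensor, sos: torch.Tensor, eps: float = 1e-10):
    """Single-tensor sketch of the tree-wide update performed by update_fn."""
    # Assumed effect of update_moment.impl(..., decay=1.0, order=2, ...): keep the
    # full running sum of squared gradients, with no decay of past history.
    new_sos = sos + g * g  # the pre-refactor lambda used (g.conj() * g).real for complex grads
    # Mirror the torch.where branches above: scale by 1/sqrt(sos) where the
    # accumulator is positive, and emit a zero update where it is still zero.
    scaled = torch.where(new_sos > 0.0, g / (new_sos.sqrt() + eps), torch.zeros_like(g))
    return scaled, new_sos

g = torch.tensor([1.0, -2.0, 0.0])
sos = torch.zeros(3)
scaled, sos = rss_scale_single(g, sos)  # first step gives roughly sign(g), and 0.0 where g == 0

Passing inplace and already_flattened through to update_moment presumably keeps the in-place buffer reuse and the flat-tree fast path that the hand-rolled tree_map version handled on its own.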