
Commit e9611fc

refactor: refactor scale_by_rss
1 parent 9908594 commit e9611fc

File tree: 2 files changed, +51 −23 lines changed

torchopt/transform/scale_by_rss.py

Lines changed: 19 additions & 15 deletions
@@ -39,7 +39,7 @@

 from torchopt import pytree
 from torchopt.base import GradientTransformation
-from torchopt.transform.utils import tree_map_flat
+from torchopt.transform.utils import tree_map_flat, update_moment
 from torchopt.typing import OptState, Params, Updates


@@ -63,8 +63,10 @@ def scale_by_rss(

     - McMahan et al., 2010: https://arxiv.org/abs/1002.4908

     Args:
-        initial_accumulator_value: Starting value for accumulators, must be >= 0.
-        eps: A small floating point value to avoid zero denominator.
+        initial_accumulator_value (float, optional): Starting value for accumulators, must be
+            ``>= 0``. (default: :const:`0.0`)
+        eps (float, optional): A small floating point value to avoid zero denominator.
+            (default: :const:`1e-10`)

     Returns:
         An (init_fn, update_fn) tuple.

@@ -115,32 +117,34 @@ def update_fn(
         params: Params | None = None,  # pylint: disable=unused-argument
         inplace: bool = True,
     ) -> tuple[Updates, OptState]:
-        sum_of_squares = tree_map(
-            lambda g, t: t + (g.conj() * g).real,
+        sum_of_squares = update_moment.impl(  # type: ignore[attr-defined]
             updates,
             state.sum_of_squares,
+            decay=1.0,
+            order=2,
+            inplace=inplace,
+            already_flattened=already_flattened,
         )

         if inplace:

-            def f(t: torch.Tensor) -> torch.Tensor:
+            def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor:
                 return torch.where(
-                    t > 0.0,
-                    torch.ones_like(t).div_(t.sqrt().add_(eps)),
-                    torch.tensor(0.0),
+                    sos > 0.0,
+                    g.div_(sos.sqrt().add_(eps)),
+                    0.0,
                 )

         else:

-            def f(t: torch.Tensor) -> torch.Tensor:
+            def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor:
                 return torch.where(
-                    t > 0.0,
-                    torch.ones_like(t).div(t.sqrt().add(eps)),
-                    torch.tensor(0.0),
+                    sos > 0.0,
+                    g.div(sos.sqrt().add(eps)),
+                    0.0,
                 )

-        inv_sqrt_g_square = tree_map(f, sum_of_squares)
-        updates = tree_map(lambda scale, g: g * scale, inv_sqrt_g_square, updates)
+        updates = tree_map(f, updates, sum_of_squares)
         return updates, ScaleByRssState(sum_of_squares=sum_of_squares)

     return GradientTransformation(init_fn, update_fn)
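
The rewritten update_fn keeps the same AdaGrad-style behaviour as before: it accumulates the running sum of squared gradients (now via the shared update_moment helper with decay=1.0) and divides each update by the square root of that sum, guarded by eps and by the sos > 0.0 check. A minimal standalone sketch of that computation on a single tensor, outside TorchOpt's pytree machinery (the function name and defaults below are illustrative, not part of the library):

import torch

def rss_scale(update: torch.Tensor, sum_of_squares: torch.Tensor, eps: float = 1e-10):
    # Accumulate the squared gradient; decay == 1.0 means a plain running sum.
    sum_of_squares = sum_of_squares + (update.conj() * update).real
    # Scale by 1 / sqrt(sum_of_squares), leaving entries with a zero accumulator at zero,
    # mirroring the torch.where branches in the diff above.
    scaled = torch.where(
        sum_of_squares > 0.0,
        update / (sum_of_squares.sqrt() + eps),
        torch.zeros_like(update),
    )
    return scaled, sum_of_squares

g = torch.tensor([0.5, 0.0, -2.0])
state = torch.zeros_like(g)
new_update, state = rss_scale(g, state)  # first step scales every nonzero entry to roughly +/-1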

torchopt/transform/utils.py

Lines changed: 32 additions & 8 deletions
@@ -173,25 +173,49 @@ def _update_moment(

     if inplace:
         if order == 2:
+            if decay != 1.0:

-            def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
-                return t.mul_(decay).addcmul_(g, g, value=1 - decay) if g is not None else t
+                def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+                    return t.mul_(decay).addcmul_(g, g, value=1 - decay) if g is not None else t
+
+            else:
+
+                def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+                    return t.addcmul_(g, g) if g is not None else t

         else:
+            if decay != 1.0:
+
+                def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+                    return t.mul_(decay).add_(g, alpha=1 - decay) if g is not None else t
+
+            else:

-            def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
-                return t.mul_(decay).add_(g, alpha=1 - decay) if g is not None else t
+                def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+                    return t.add_(g) if g is not None else t

     else:
         if order == 2:
+            if decay != 1.0:

-            def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
-                return t.mul(decay).addcmul_(g, g, value=1 - decay) if g is not None else t
+                def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+                    return t.mul(decay).addcmul_(g, g, value=1 - decay) if g is not None else t
+
+            else:
+
+                def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+                    return t.addcmul(g, g) if g is not None else t

         else:
+            if decay != 1.0:
+
+                def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+                    return t.mul(decay).add_(g, alpha=1 - decay) if g is not None else t
+
+            else:

-            def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
-                return t.mul(decay).add_(g, alpha=1 - decay) if g is not None else t
+                def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+                    return t.add(g) if g is not None else t

     if already_flattened:
         return tree_map_flat(f, updates, moments, none_is_leaf=True)
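
The new decay != 1.0 / decay == 1.0 split exists because scale_by_rss now calls update_moment with decay=1.0: in the generic EMA form t.mul(decay).addcmul(g, g, value=1 - decay), the incoming gradient would be weighted by 1 - decay == 0 and nothing would ever accumulate, so the decay == 1.0 branches fall back to a plain running sum (and skip the redundant multiply by 1.0). A small out-of-place sketch of the order == 2 case showing the difference (the helper name is illustrative, not part of the library):

import torch

def second_moment_step(g: torch.Tensor, t: torch.Tensor, decay: float) -> torch.Tensor:
    if decay != 1.0:
        # Exponential moving average of g**2, as in the generic branch.
        return t.mul(decay).addcmul(g, g, value=1 - decay)
    # decay == 1.0: plain running sum of g**2, which is what scale_by_rss needs.
    return t.addcmul(g, g)

g = torch.tensor([1.0, 2.0])
t = torch.zeros(2)
print(second_moment_step(g, t, decay=0.9))  # tensor([0.1000, 0.4000])
print(second_moment_step(g, t, decay=1.0))  # tensor([1., 4.])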
