Commit ef51cc8

feat(optim): Adadelta RAdam Adamax optimizer support (#171)

Co-authored-by: Benjamin-eecs <benjaminliu.eecs@gmail.com>
1 parent 3f68378 commit ef51cc8

30 files changed (+1775 -13 lines)

.pre-commit-config.yaml
Lines changed: 0 additions & 4 deletions

@@ -24,10 +24,6 @@ repos:
       - id: detect-private-key
       - id: debug-statements
       - id: double-quote-string-fixer
-  - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v16.0.6
-    hooks:
-      - id: clang-format
   - repo: https://github.com/astral-sh/ruff-pre-commit
     rev: v0.0.278
     hooks:

CHANGELOG.md
Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Added

--
+- Implement `Adadelta`, `RAdam`, `Adamax` optimizer by [@JieRen98](https://github.com/JieRen98) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#171](https://github.com/metaopt/torchopt/pull/171).

 ### Changed

docs/source/api/api.rst
Lines changed: 60 additions & 0 deletions

@@ -30,9 +30,12 @@ Functional Optimizers
 .. autosummary::

     FuncOptimizer
+    adadelta
     adagrad
     adam
     adamw
+    adamax
+    radam
     rmsprop
     sgd

@@ -42,6 +45,11 @@ Wrapper for Function Optimizer
 .. autoclass:: FuncOptimizer
     :members:

+Functional AdaDelta Optimizer
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: adadelta
+
 Functional AdaGrad Optimizer
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -57,6 +65,16 @@ Functional AdamW Optimizer

 .. autofunction:: adamw

+Functional AdaMax Optimizer
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: adamax
+
+Functional RAdam Optimizer
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: radam
+
 Functional RMSProp Optimizer
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -76,12 +94,23 @@ Classic Optimizers

 .. autosummary::

+    AdaDelta
+    Adadelta
     AdaGrad
+    Adagrad
     Adam
     AdamW
+    AdaMax
+    Adamax
+    RAdam
     RMSProp
     SGD

+Classic AdaDelta Optimizer
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: AdaDelta
+
 Classic AdaGrad Optimizer
 ~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -97,6 +126,16 @@ Classic AdamW Optimizer

 .. autoclass:: AdamW

+Classic AdaMax Optimizer
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: AdaMax
+
+Classic RAdam Optimizer
+~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: RAdam
+
 Classic RMSProp Optimizer
 ~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -116,12 +155,23 @@ Differentiable Meta-Optimizers

 .. autosummary::

+    MetaAdaDelta
+    MetaAdadelta
     MetaAdaGrad
+    MetaAdagrad
     MetaAdam
     MetaAdamW
+    MetaAdaMax
+    MetaAdamax
+    MetaRAdam
     MetaRMSProp
     MetaSGD

+Differentiable Meta-AdaDelta Optimizer
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: MetaAdaDelta
+
 Differentiable Meta-AdaGrad Optimizer
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -137,6 +187,16 @@ Differentiable Meta-AdamW Optimizer

 .. autoclass:: MetaAdamW

+Differentiable Meta-AdaMax Optimizer
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: MetaAdaMax
+
+Differentiable Meta-RAdam Optimizer
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: MetaRAdam
+
 Differentiable Meta-RMSProp Optimizer
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
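The new functional optimizers plug into the same init/update/apply_updates cycle as the existing ones; every torchopt call below is taken from the tests in this commit, while the toy parameter tuple and loss are illustrative only. A minimal sketch:

import torch
import torchopt

# Toy parameters; any pytree of tensors works with the functional API.
params = (torch.randn(4, 1, dtype=torch.float64, requires_grad=True),)

# Functional AdaDelta, constructed as in tests/test_alias.py below.
optim = torchopt.adadelta(1e-2, rho=0.9, eps=1e-8, weight_decay=1e-2)
opt_state = optim.init(params)

xs = torch.randn(8, 4, dtype=torch.float64)
loss = (xs @ params[0]).pow(2).mean()

grads = torch.autograd.grad(loss, params)
# Out-of-place update: returns new parameter tensors instead of mutating them.
updates, opt_state = optim.update(grads, opt_state, params=params, inplace=False)
params = torchopt.apply_updates(params, updates, inplace=False)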

docs/source/explicit_diff/explicit_diff.rst
Lines changed: 6 additions & 0 deletions

@@ -53,9 +53,15 @@ For PyTorch-like API (e.g., ``step()``), we designed a base class :class:`torcho
 .. autosummary::

     torchopt.MetaOptimizer
+    torchopt.MetaAdaDelta
+    torchopt.MetaAdadelta
     torchopt.MetaAdaGrad
+    torchopt.MetaAdagrad
     torchopt.MetaAdam
     torchopt.MetaAdamW
+    torchopt.AdaMax
+    torchopt.MetaAdamax
+    torchopt.MetaRAdam
     torchopt.MetaRMSProp
     torchopt.MetaSGD
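A minimal sketch of the new differentiable meta-optimizer, assuming the usual torchopt.MetaOptimizer pattern in which the optimizer wraps an nn.Module and ``step(loss)`` performs a differentiable in-place update; the exact MetaAdaDelta constructor arguments are not shown in this diff, and the module, data, and losses here are illustrative only:

import torch
import torch.nn as nn
import torchopt

net = nn.Linear(4, 1)
# Assumed constructor: module plus learning rate, mirroring the other Meta* classes.
meta_optim = torchopt.MetaAdaDelta(net, lr=1.0)

xs = torch.randn(8, 4)
inner_loss = net(xs).pow(2).mean()
meta_optim.step(inner_loss)  # differentiable in-place update of net's parameters

outer_loss = net(xs).pow(2).mean()
# outer_loss now depends on the AdaDelta update itself, so gradients of the
# outer objective can flow back through the inner optimization step.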

docs/source/optimizer/optim.rst
Lines changed: 9 additions & 0 deletions

@@ -18,9 +18,12 @@ Currently, TorchOpt supports 4 functional optimizers: :func:`sgd`, :func:`adam`,
 .. autosummary::

     torchopt.FuncOptimizer
+    torchopt.adadelta
     torchopt.adagrad
     torchopt.adam
     torchopt.adamw
+    torchopt.adamax
+    torchopt.radam
     torchopt.rmsprop
     torchopt.sgd

@@ -85,9 +88,15 @@ We offer original PyTorch APIs (e.g., ``zero_grad()`` or ``step()``) for traditi
 .. autosummary::

     torchopt.Optimizer
+    torchopt.AdaDelta
+    torchopt.Adadelta
     torchopt.AdaGrad
+    torchopt.Adagrad
     torchopt.Adam
     torchopt.AdamW
+    torchopt.AdaMax
+    torchopt.Adamax
+    torchopt.RAdam
     torchopt.RMSProp
     torchopt.SGD
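The classic classes keep the PyTorch-style ``zero_grad()``/``step()`` workflow the docs describe above. A minimal sketch, assuming the constructor mirrors torch.optim.Adadelta (a parameter iterable plus hyperparameters); the model and loss are illustrative only:

import torch
import torch.nn as nn
import torchopt

model = nn.Linear(4, 1)
# Assumed signature: parameter iterable plus lr, like torch.optim.Adadelta.
optim = torchopt.AdaDelta(model.parameters(), lr=1.0)

xs = torch.randn(8, 4)
loss = model(xs).pow(2).mean()

optim.zero_grad()
loss.backward()
optim.step()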

docs/source/spelling_wordlist.txt
Lines changed: 7 additions & 0 deletions

@@ -175,3 +175,10 @@ ctx
 Duchi
 invertible
 AdaGrad
+Adadelta
+Zeiler
+radam
+adamax
+RAdam
+AdaDelta
+AdaMax

tests/test_alias.py
Lines changed: 171 additions & 0 deletions

@@ -144,6 +144,63 @@ def test_sgd(
     _set_use_chain_flat(True)


+@helpers.parametrize(
+    dtype=[torch.float64],
+    lr=[1e-2, 1e-3, 1e-4],
+    rho=[0.9, 0.95],
+    eps=[1e-8],
+    inplace=[True, False],
+    weight_decay=[0.0, 1e-2],
+    use_chain_flat=[True, False],
+)
+def test_adadelta(
+    dtype: torch.dtype,
+    lr: float,
+    rho: float,
+    eps: float,
+    inplace: bool,
+    weight_decay: float,
+    use_chain_flat: bool,
+) -> None:
+    _set_use_chain_flat(use_chain_flat)
+
+    model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)
+
+    fmodel, params, buffers = functorch.make_functional_with_buffers(model)
+    optim = torchopt.adadelta(
+        lr,
+        rho=rho,
+        eps=eps,
+        weight_decay=weight_decay,
+    )
+    optim_state = optim.init(params)
+    optim_ref = torch.optim.Adadelta(
+        model_ref.parameters(),
+        lr,
+        rho=rho,
+        eps=eps,
+        weight_decay=weight_decay,
+    )
+
+    for xs, ys in loader:
+        xs = xs.to(dtype=dtype)
+        pred = fmodel(params, buffers, xs)
+        pred_ref = model_ref(xs)
+        loss = F.cross_entropy(pred, ys)
+        loss_ref = F.cross_entropy(pred_ref, ys)
+
+        grads = torch.autograd.grad(loss, params, allow_unused=True)
+        updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
+        params = torchopt.apply_updates(params, updates, inplace=inplace)
+
+        optim_ref.zero_grad()
+        loss_ref.backward()
+        optim_ref.step()
+
+    helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype)
+    _set_use_chain_flat(True)
+
+
 @helpers.parametrize(
     dtype=[torch.float64],
     lr=[1e-2, 1e-3, 1e-4],
@@ -210,6 +267,120 @@ def test_adam(
     _set_use_chain_flat(True)


+@helpers.parametrize(
+    dtype=[torch.float64],
+    lr=[1e-2, 1e-3, 1e-4],
+    betas=[(0.9, 0.999), (0.95, 0.9995)],
+    eps=[1e-8],
+    inplace=[True, False],
+    weight_decay=[0.0, 1e-2],
+    use_chain_flat=[True, False],
+)
+def test_radam(
+    dtype: torch.dtype,
+    lr: float,
+    betas: tuple[float, float],
+    eps: float,
+    inplace: bool,
+    weight_decay: float,
+    use_chain_flat: bool,
+) -> None:
+    _set_use_chain_flat(use_chain_flat)
+
+    model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)
+
+    fmodel, params, buffers = functorch.make_functional_with_buffers(model)
+    optim = torchopt.radam(
+        lr,
+        betas=betas,
+        eps=eps,
+        weight_decay=weight_decay,
+    )
+    optim_state = optim.init(params)
+    optim_ref = torch.optim.RAdam(
+        model_ref.parameters(),
+        lr,
+        betas=betas,
+        eps=eps,
+        weight_decay=weight_decay,
+    )
+
+    for xs, ys in loader:
+        xs = xs.to(dtype=dtype)
+        pred = fmodel(params, buffers, xs)
+        pred_ref = model_ref(xs)
+        loss = F.cross_entropy(pred, ys)
+        loss_ref = F.cross_entropy(pred_ref, ys)
+
+        grads = torch.autograd.grad(loss, params, allow_unused=True)
+        updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
+        params = torchopt.apply_updates(params, updates, inplace=inplace)
+
+        optim_ref.zero_grad()
+        loss_ref.backward()
+        optim_ref.step()
+
+    helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype)
+    _set_use_chain_flat(True)
+
+
+@helpers.parametrize(
+    dtype=[torch.float64],
+    lr=[1e-2, 1e-3, 1e-4],
+    betas=[(0.9, 0.999), (0.95, 0.9995)],
+    eps=[1e-8],
+    inplace=[True, False],
+    weight_decay=[0.0, 1e-2],
+    use_chain_flat=[True, False],
+)
+def test_adamax(
+    dtype: torch.dtype,
+    lr: float,
+    betas: tuple[float, float],
+    eps: float,
+    inplace: bool,
+    weight_decay: float,
+    use_chain_flat: bool,
+) -> None:
+    _set_use_chain_flat(use_chain_flat)
+
+    model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)
+
+    fmodel, params, buffers = functorch.make_functional_with_buffers(model)
+    optim = torchopt.adamax(
+        lr,
+        betas=betas,
+        eps=eps,
+        weight_decay=weight_decay,
+    )
+    optim_state = optim.init(params)
+    optim_ref = torch.optim.Adamax(
+        model_ref.parameters(),
+        lr,
+        betas=betas,
+        eps=eps,
+        weight_decay=weight_decay,
+    )
+
+    for xs, ys in loader:
+        xs = xs.to(dtype=dtype)
+        pred = fmodel(params, buffers, xs)
+        pred_ref = model_ref(xs)
+        loss = F.cross_entropy(pred, ys)
+        loss_ref = F.cross_entropy(pred_ref, ys)
+
+        grads = torch.autograd.grad(loss, params, allow_unused=True)
+        updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
+        params = torchopt.apply_updates(params, updates, inplace=inplace)
+
+        optim_ref.zero_grad()
+        loss_ref.backward()
+        optim_ref.step()
+
+    helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype)
+    _set_use_chain_flat(True)
+
+
 @helpers.parametrize(
     dtype=[torch.float64],
     outer_lr=[1e-2, 1e-3, 1e-4],
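All three new tests follow the same parity pattern: drive the functional torchopt optimizer and the corresponding torch.optim optimizer with identical losses and check that the parameters stay in agreement. A compact, self-contained restatement of that idea (illustrative only, using a single toy parameter instead of helpers.get_models):

import torch
import torchopt

w = torch.tensor([1.0, -2.0], dtype=torch.float64, requires_grad=True)
w_ref = w.clone().detach().requires_grad_(True)

optim = torchopt.adamax(1e-3, betas=(0.9, 0.999), eps=1e-8)
opt_state = optim.init((w,))
optim_ref = torch.optim.Adamax([w_ref], 1e-3, betas=(0.9, 0.999), eps=1e-8)

for _ in range(5):
    # Functional update on w.
    loss = w.pow(2).sum()
    (grad,) = torch.autograd.grad(loss, (w,))
    updates, opt_state = optim.update((grad,), opt_state, params=(w,), inplace=False)
    (w,) = torchopt.apply_updates((w,), updates, inplace=False)

    # Reference torch.optim update on w_ref.
    optim_ref.zero_grad()
    w_ref.pow(2).sum().backward()
    optim_ref.step()

# The tests above assert that the two trajectories agree.
torch.testing.assert_close(w, w_ref)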
