
feat: add AdamW optimizer #44


Merged
33 commits merged on Sep 5, 2022
70b8a47
feat(torchopt): init adamw optimizer
Benjamin-eecs Jul 27, 2022
3300142
Merge remote-tracking branch 'upstream/main' into feature/adamw
Benjamin-eecs Aug 4, 2022
253cc2a
fix(torchopt): pass adamw tests
Benjamin-eecs Aug 4, 2022
17d5784
fix: force add adamw.py
Benjamin-eecs Aug 4, 2022
cdc3836
feat: add MetaAdamW test and pass lint
Benjamin-eecs Aug 5, 2022
cc3a3c7
feat: add MetaAdamW test and pass lint
Benjamin-eecs Aug 5, 2022
a071550
fix: pass lint and pass MetaAdamW tests
Benjamin-eecs Aug 5, 2022
89fac53
fix: rewrite MetaOptimizer test, pass MetaAdamW tests with error tol
Benjamin-eecs Aug 5, 2022
b50abe0
merge: resolve conflicts
Benjamin-eecs Aug 24, 2022
47ff9f3
merge: resolve conflicts
Benjamin-eecs Aug 24, 2022
476332e
fix: update adamw low level test
Benjamin-eecs Aug 26, 2022
8175181
merge: resolve conflicts
Benjamin-eecs Sep 1, 2022
bb82209
fix(tests): use new test
Benjamin-eecs Sep 4, 2022
4b01c7e
Merge remote-tracking branch 'upstream/main' into feature/adamw
Benjamin-eecs Sep 4, 2022
d935014
fix: pass lint
Benjamin-eecs Sep 4, 2022
47cfa45
fix: pass test
Benjamin-eecs Sep 4, 2022
9b32e7b
Merge remote-tracking branch 'upstream/main' into feature/adamw
Benjamin-eecs Sep 4, 2022
42ed8a5
fix: pass test
Benjamin-eecs Sep 4, 2022
1e64877
fix: pass test
Benjamin-eecs Sep 4, 2022
872b8d4
fix: update docstring
Benjamin-eecs Sep 4, 2022
824d1c5
fix: update docstring
Benjamin-eecs Sep 4, 2022
e920c74
fix: update docstring
Benjamin-eecs Sep 4, 2022
8ee3c41
fix: correct already_flattened
Benjamin-eecs Sep 4, 2022
0f129c0
fix: correct weight_decay range check
Benjamin-eecs Sep 4, 2022
e75671e
fix: already_flattened of mask
Benjamin-eecs Sep 4, 2022
c791bba
style: format code
XuehaiPan Sep 5, 2022
24690a0
feat: add shortcut
XuehaiPan Sep 5, 2022
fec6f99
chore: reorganize code structure
XuehaiPan Sep 5, 2022
d3ad838
feat: inplace support for AdamW
XuehaiPan Sep 5, 2022
c685954
docs: update docstrings
XuehaiPan Sep 5, 2022
8114286
docs(CHANGELOG): update CHANGELOG.md
XuehaiPan Sep 5, 2022
c075533
docs: update docstrings
XuehaiPan Sep 5, 2022
0f5c90a
docs: update docstrings
XuehaiPan Sep 5, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Implement AdamW optimizer with masking by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#44](https://github.com/metaopt/torchopt/pull/44).
- Add half float support for accelerated OPs by [@XuehaiPan](https://github.com/XuehaiPan) in [#67](https://github.com/metaopt/torchopt/pull/67).
- Add MAML example with TorchRL integration by [@vmoens](https://github.com/vmoens) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#12](https://github.com/metaopt/TorchOpt/pull/12).
- Add optional argument `params` to update function in gradient transformations by [@XuehaiPan](https://github.com/XuehaiPan) in [#65](https://github.com/metaopt/torchopt/pull/65).
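The CHANGELOG entry above mentions masking. Below is a minimal sketch of how the mask option might be used to exclude selected parameters (for example, biases) from weight decay. It assumes the `mask` keyword accepts a pytree of booleans mirroring the parameter structure, as in Optax's `adamw`; the exact contract is in the `adamw` docstring, not this diff.

```python
# Hedged sketch of the masking feature named in the CHANGELOG entry above.
# Assumption: `mask` accepts a pytree of booleans with the same structure as
# the parameters, where True means "apply weight decay to this leaf"
# (Optax-style); the exact contract lives in the adamw docstring.
import torch
import torchopt

weight = torch.randn(4, 3, requires_grad=True)
bias = torch.zeros(3, requires_grad=True)
params = (weight, bias)

# Decay the weight matrix, leave the bias untouched.
optim = torchopt.adamw(lr=1e-3, weight_decay=1e-2, mask=(True, False))
state = optim.init(params)
```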
18 changes: 18 additions & 0 deletions docs/source/api/api.rst
@@ -32,12 +32,18 @@ Functional Optimizers
adam
sgd
rmsprop
adamw

Functional Adam Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: adam

Functional AdamW Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: adamw

Functional SGD Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~

@@ -60,12 +66,18 @@ Classic Optimizers
Adam
SGD
RMSProp
AdamW

Classic Adam Optimizer
~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: Adam

Classic AdamW Optimizer
~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: AdamW

Classic SGD Optimizer
~~~~~~~~~~~~~~~~~~~~~

@@ -88,12 +100,18 @@ Differentiable Meta-Optimizers
MetaAdam
MetaSGD
MetaRMSProp
MetaAdamW

Differentiable Meta-Adam Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: MetaAdam

Differentiable Meta-AdamW Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: MetaAdamW

Differentiable Meta-SGD Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

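For orientation, the three AdamW entry points added to the API reference correspond roughly as follows. This is a sketch with abbreviated argument lists; `net` is a placeholder module, and the assumption that `MetaAdamW` takes the module as its first argument mirrors the other Meta-optimizers rather than anything shown in this diff.

```python
# Rough correspondence of the three documented AdamW interfaces (sketch only;
# argument lists abbreviated, `net` is a placeholder module).
import torch
import torchopt

net = torch.nn.Linear(4, 2)

func = torchopt.adamw(lr=1e-3)                       # functional: init()/update() transformation
classic = torchopt.AdamW(net.parameters(), lr=1e-3)  # torch.optim-style wrapper
meta = torchopt.MetaAdamW(net, lr=1e-3)              # differentiable meta-optimizer
```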
5 changes: 5 additions & 0 deletions docs/source/spelling_wordlist.txt
@@ -73,3 +73,8 @@ CPython
nn
Vincent
Moens
AdamW
Loshchilov
pytree
booleans
subtrees
58 changes: 58 additions & 0 deletions tests/test_alias.py
@@ -154,6 +154,64 @@ def test_adam(
weight_decay=[0.0, 1e-2],
maximize=[False, True],
)
def test_adamw(
dtype: torch.dtype,
lr: float,
betas: Tuple[float, float],
eps: float,
inplace: bool,
weight_decay: float,
maximize: bool,
) -> None:
model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)

fmodel, params, buffers = functorch.make_functional_with_buffers(model)
optim = torchopt.adamw(
lr,
betas=betas,
eps=eps,
eps_root=0.0,
weight_decay=weight_decay,
maximize=maximize,
)
optim_state = optim.init(params)
optim_ref = torch.optim.AdamW(
model_ref.parameters(),
lr,
betas=betas,
eps=eps,
amsgrad=False,
weight_decay=weight_decay,
maximize=maximize,
)

for xs, ys in loader:
xs = xs.to(dtype=dtype)
pred = fmodel(params, buffers, xs)
pred_ref = model_ref(xs)
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

grads = torch.autograd.grad(loss, params)
updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
params = torchopt.apply_updates(params, updates, inplace=inplace)

optim_ref.zero_grad()
loss_ref.backward()
optim_ref.step()

helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype)


@helpers.parametrize(
dtype=[torch.float64],
lr=[1e-2, 1e-3, 1e-4],
betas=[(0.9, 0.999), (0.95, 0.9995)],
eps=[1e-8],
inplace=[True, False],
weight_decay=[1e-2, 1e-1],
maximize=[False, True],
)
def test_adam_accelerated_cpu(
dtype: torch.dtype,
lr: float,
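The functional test above relies on the repository's test helpers. For readers skimming the diff, here is a stripped-down version of the same update cycle with stand-in model and data; only the torchopt calls mirror the test, everything else is illustrative.

```python
# Stripped-down version of the functional update cycle exercised by test_adamw.
# The model and data below are stand-ins; only the torchopt calls mirror the test.
import functorch
import torch
import torch.nn.functional as F
import torchopt

model = torch.nn.Linear(8, 2)
fmodel, params = functorch.make_functional(model)

optim = torchopt.adamw(lr=1e-3, weight_decay=1e-2)
state = optim.init(params)

xs, ys = torch.randn(16, 8), torch.randint(0, 2, (16,))
loss = F.cross_entropy(fmodel(params, xs), ys)

grads = torch.autograd.grad(loss, params)
updates, state = optim.update(grads, state, params=params)
params = torchopt.apply_updates(params, updates)
```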
55 changes: 55 additions & 0 deletions tests/test_optimizer.py
@@ -138,6 +138,61 @@ def test_Adam(
helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype)


@helpers.parametrize(
dtype=[torch.float64],
lr=[1e-2, 1e-3, 1e-4],
betas=[(0.9, 0.999), (0.95, 0.9995)],
eps=[1e-8],
weight_decay=[1e-2, 1e-1],
maximize=[False, True],
)
def test_AdamW(
dtype: torch.dtype,
lr: float,
betas: Tuple[float, float],
eps: float,
weight_decay: float,
maximize: bool,
) -> None:
model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)

optim = torchopt.AdamW(
model.parameters(),
lr,
betas=betas,
eps=eps,
eps_root=0.0,
weight_decay=weight_decay,
maximize=maximize,
)
optim_ref = torch.optim.AdamW(
model_ref.parameters(),
lr,
betas=betas,
eps=eps,
amsgrad=False,
weight_decay=weight_decay,
maximize=maximize,
)

for xs, ys in loader:
xs = xs.to(dtype=dtype)
pred = model(xs)
pred_ref = model_ref(xs)
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

optim.zero_grad()
loss.backward()
optim.step()

optim_ref.zero_grad()
loss_ref.backward()
optim_ref.step()

helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype)


@helpers.parametrize(
dtype=[torch.float64],
lr=[1e-2, 1e-3, 1e-4],
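What distinguishes AdamW from Adam-plus-L2, and what the comparison against `torch.optim.AdamW` implicitly checks, is the decoupled weight-decay rule of Loshchilov and Hutter: the decay is applied to the parameters directly rather than folded into the gradient before the moment estimates. A tiny single-step illustration with made-up numbers (not test code) follows; it uses only `torch.optim.AdamW`, which both optimizers in the test are expected to match.

```python
# One-step check of the decoupled rule
#   p <- p * (1 - lr * wd) - lr * m_hat / (sqrt(v_hat) + eps).
# Uses torch.optim.AdamW only; the values are illustrative.
import math
import torch

p = torch.tensor([1.0], requires_grad=True)
lr, wd, b1, b2, eps = 1e-2, 1e-1, 0.9, 0.999, 1e-8

opt = torch.optim.AdamW([p], lr=lr, betas=(b1, b2), eps=eps, weight_decay=wd)
p.grad = torch.tensor([0.5])
opt.step()

g = 0.5
m_hat, v_hat = g, g * g  # bias-corrected moments after the first step
manual = 1.0 * (1 - lr * wd) - lr * m_hat / (math.sqrt(v_hat) + eps)
print(p.item(), manual)  # both should be ~0.989, agreeing to floating-point error
```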
16 changes: 13 additions & 3 deletions torchopt/__init__.py
@@ -15,11 +15,18 @@
"""TorchOpt: a high-performance optimizer library built upon PyTorch."""

from torchopt._src import accelerated_op_available, clip, combine, hook, schedule, visual
from torchopt._src.alias import adam, rmsprop, sgd
from torchopt._src.alias import adam, adamw, rmsprop, sgd
from torchopt._src.clip import clip_grad_norm
from torchopt._src.combine import chain
from torchopt._src.optimizer import SGD, Adam, Optimizer, RMSProp, RMSprop, meta
from torchopt._src.optimizer.meta import MetaAdam, MetaOptimizer, MetaRMSProp, MetaRMSprop, MetaSGD
from torchopt._src.optimizer import SGD, Adam, AdamW, Optimizer, RMSProp, RMSprop, meta
from torchopt._src.optimizer.meta import (
MetaAdam,
MetaAdamW,
MetaOptimizer,
MetaRMSProp,
MetaRMSprop,
MetaSGD,
)
from torchopt._src.update import apply_updates
from torchopt._src.utils import extract_state_dict, recover_state_dict, stop_gradient
from torchopt.version import __version__
@@ -33,18 +40,21 @@
'schedule',
'visual',
'adam',
'adamw',
'rmsprop',
'sgd',
'clip_grad_norm',
'chain',
'Optimizer',
'SGD',
'Adam',
'AdamW',
'RMSProp',
'RMSprop',
'MetaOptimizer',
'MetaSGD',
'MetaAdam',
'MetaAdamW',
'MetaRMSProp',
'MetaRMSprop',
'apply_updates',
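Among the new exports, `MetaAdamW` is the differentiable variant. The sketch below shows an inner-loop step, assuming it follows the same module-first constructor and loss-based `step()` as the existing Meta-optimizers; that signature is an assumption, since it is not part of this diff.

```python
# Sketch of a differentiable inner step with the newly exported MetaAdamW.
# Assumes the MetaOptimizer interface: the constructor takes an nn.Module, and
# step(loss) updates the module's parameters while staying on the autograd graph.
import torch
import torchopt

net = torch.nn.Linear(4, 1)
meta_optim = torchopt.MetaAdamW(net, lr=1e-2)

xs = torch.randn(8, 4)
inner_loss = net(xs).pow(2).mean()
meta_optim.step(inner_loss)  # differentiable parameter update

outer_loss = net(xs).mean()  # can be backpropagated through the inner update
outer_loss.backward()
```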