
refactor: align argument names with PyTorch #65


Merged: 33 commits, Aug 29, 2022
Changes from 1 commit

Commits (33):
5e4676e  refactor: refactor optimizer bases (XuehaiPan, Aug 25, 2022)
20d5377  refactor: align Adam options with PyTorch (XuehaiPan, Aug 25, 2022)
bbc3bdd  refactor: align RMSProp options with PyTorch (XuehaiPan, Aug 25, 2022)
a45111a  refactor: align SGD options with PyTorch (XuehaiPan, Aug 25, 2022)
0c30794  feat(alias): check value range (XuehaiPan, Aug 25, 2022)
4a985ea  feat: add `params` to `update_fn`'s signature (XuehaiPan, Aug 25, 2022)
a3cb948  feat: add weight decay (XuehaiPan, Aug 26, 2022)
cbf1a52  test: add weight decay tests (XuehaiPan, Aug 26, 2022)
a0be5d9  lint: pass lint (XuehaiPan, Aug 26, 2022)
71a4a63  docs: update docstring (XuehaiPan, Aug 26, 2022)
2ab1da0  chore: update type hints (XuehaiPan, Aug 26, 2022)
a7d7643  fix: fix grad tracing for weight decay (XuehaiPan, Aug 26, 2022)
63d77cc  test: reorganize tests (XuehaiPan, Aug 26, 2022)
b31e245  chore: add RMSprop aliases for PyTorch compatibility (XuehaiPan, Aug 26, 2022)
a929c51  test: add module buffers (XuehaiPan, Aug 26, 2022)
f082d6c  test: update test parameters (XuehaiPan, Aug 26, 2022)
1c8421d  chore: update .gitignore (XuehaiPan, Aug 26, 2022)
2723c01  test: update test parameters (XuehaiPan, Aug 27, 2022)
1ec3d0f  refactor: refactor transform (XuehaiPan, Aug 27, 2022)
e8bd609  refactor: chain (XuehaiPan, Aug 27, 2022)
353b628  refactor: identity (XuehaiPan, Aug 27, 2022)
bee81bb  feat: add with_flattened_tree (XuehaiPan, Aug 27, 2022)
a8b6dc0  test: update test parameters (XuehaiPan, Aug 27, 2022)
7d1d20a  feat: add dampening (XuehaiPan, Aug 27, 2022)
58dcc56  docs: update docstring (XuehaiPan, Aug 28, 2022)
acde3fd  lint: fix mypy (XuehaiPan, Aug 28, 2022)
b1e1521  fix: fix grad tracing for initial value (XuehaiPan, Aug 29, 2022)
6331cdc  test: update test parameters (XuehaiPan, Aug 29, 2022)
1b39ad5  docs: update docstrings (XuehaiPan, Aug 29, 2022)
9976f96  chore: rename variables (XuehaiPan, Aug 29, 2022)
4c726f2  test: update test parameters (XuehaiPan, Aug 29, 2022)
3eee243  docs(CHANGELOG): update CHANGELOG.md (XuehaiPan, Aug 29, 2022)
1fded5d  test: test with pre-release of PyTorch (XuehaiPan, Aug 29, 2022)
test: update test parameters
XuehaiPan committed Aug 27, 2022
commit 2723c01587a1cba4571dc73fc773c4cd84299ea7
10 changes: 5 additions & 5 deletions tests/test_alias.py
@@ -26,7 +26,7 @@

 @helpers.parametrize(
     dtype=[torch.float64, torch.float32],
-    lr=[1e-3, 1e-4],
+    lr=[1e-1, 1e-3, 1e-4],
     momentum=[0.0, 0.1],
     nesterov=[False, True],
     inplace=[True, False],
@@ -86,7 +86,7 @@ def test_sgd(

 @helpers.parametrize(
     dtype=[torch.float64, torch.float32],
-    lr=[1e-3, 1e-4],
+    lr=[1e-1, 1e-3, 1e-4],
     betas=[(0.9, 0.999), (0.95, 0.9995)],
     eps=[1e-8],
     inplace=[True, False],
@@ -144,7 +144,7 @@ def test_adam(

 @helpers.parametrize(
     dtype=[torch.float64, torch.float32],
-    lr=[1e-3, 1e-4],
+    lr=[1e-1, 1e-3, 1e-4],
     betas=[(0.9, 0.999), (0.95, 0.9995)],
     eps=[1e-8],
     inplace=[True, False],
@@ -204,7 +204,7 @@ def test_adam_accelerated_cpu(
 @pytest.mark.skipif(not torch.cuda.is_available(), reason='No CUDA device available.')
 @helpers.parametrize(
     dtype=[torch.float64, torch.float32],
-    lr=[1e-3, 1e-4],
+    lr=[1e-1, 1e-3, 1e-4],
     betas=[(0.9, 0.999), (0.95, 0.9995)],
     eps=[1e-8],
     inplace=[True, False],
@@ -265,7 +265,7 @@ def test_adam_accelerated_cuda(

 @helpers.parametrize(
     dtype=[torch.float64, torch.float32],
-    lr=[1e-3, 1e-4],
+    lr=[1e-1, 1e-3, 1e-4],
     alpha=[0.9, 0.99],
     eps=[1e-8],
     momentum=[0.0, 0.1],
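
Note on the pattern above: `helpers.parametrize` is the repository's own test helper. The sketch below is a hypothetical stand-in (the real helper lives in the repo's tests/helpers.py and may differ in detail) that shows why widening the `lr` grid multiplies the number of generated cases: each keyword grid is expanded into the Cartesian product of its values.

import itertools

import pytest
import torch


def parametrize(**argvalues):
    """Hypothetical stand-in: expand keyword value grids into one
    pytest.mark.parametrize call over their Cartesian product."""
    argnames = list(argvalues)
    cases = list(itertools.product(*argvalues.values()))
    return pytest.mark.parametrize(argnames, cases)


@parametrize(dtype=[torch.float64, torch.float32], lr=[1e-1, 1e-3, 1e-4])
def test_lr_grid(dtype: torch.dtype, lr: float) -> None:
    # 2 dtypes x 3 learning rates = 6 generated cases; adding 1e-1 to
    # the grid is what grew each parametrized test in this commit.
    assert lr > 0.0

Presumably the larger 1e-1 learning rate makes any behavioral mismatch between the torchopt and torch.optim implementations show up well above the tests' numerical tolerances.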
23 changes: 19 additions & 4 deletions tests/test_clip.py
@@ -25,12 +25,20 @@
 @helpers.parametrize(
     dtype=[torch.float64, torch.float32],
     max_norm=[1.0, 10.0],
-    lr=[1e-3, 1e-4],
+    lr=[1e-1, 1e-3, 1e-4],
     momentum=[0.0, 0.1],
     nesterov=[False, True],
+    weight_decay=[0.0, 1e-3],
+    maximize=[False, True],
 )
 def test_sgd(
-    dtype: torch.dtype, max_norm: float, lr: float, momentum: float, nesterov: bool
+    dtype: torch.dtype,
+    max_norm: float,
+    lr: float,
+    momentum: float,
+    nesterov: bool,
+    weight_decay: float,
+    maximize: bool,
 ) -> None:
     if nesterov and momentum <= 0.0:
         pytest.skip('Nesterov momentum requires a momentum and zero dampening.')
@@ -39,7 +47,13 @@ def test_sgd(

     chain = torchopt.combine.chain(
         torchopt.clip.clip_grad_norm(max_norm=max_norm),
-        torchopt.sgd(lr=lr, momentum=momentum, nesterov=nesterov),
+        torchopt.sgd(
+            lr=lr,
+            momentum=momentum,
+            nesterov=nesterov,
+            weight_decay=weight_decay,
+            maximize=maximize,
+        ),
     )
     optim = torchopt.Optimizer(model.parameters(), chain)
     optim_ref = torch.optim.SGD(
@@ -48,7 +62,8 @@ def test_sgd(
         momentum=momentum,
         dampening=0.0,
         nesterov=nesterov,
-        weight_decay=0.0,
+        weight_decay=weight_decay,
+        maximize=maximize,
     )

     for xs, ys in loader:
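
The diff above composes gradient clipping with SGD via `torchopt.combine.chain` and checks the result against `torch.optim.SGD`. Below is a minimal self-contained sketch of that pattern; the real test pulls `model` and `loader` from the repository's helpers, and the loop here assumes the classic `zero_grad()`/`backward()`/`step()` protocol of `torchopt.Optimizer`.

import torch
import torchopt

# A minimal sketch of the chained-transform pattern exercised above.
# The model and data are stand-ins for the repository's test fixtures.
model = torch.nn.Linear(4, 1)

chain = torchopt.combine.chain(
    torchopt.clip.clip_grad_norm(max_norm=1.0),  # clip gradients first ...
    torchopt.sgd(                                # ... then apply the SGD update
        lr=1e-1,
        momentum=0.1,
        nesterov=True,
        weight_decay=1e-3,
        maximize=False,
    ),
)
optim = torchopt.Optimizer(model.parameters(), chain)

xs, ys = torch.randn(8, 4), torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(model(xs), ys)

optim.zero_grad()
loss.backward()
optim.step()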
10 changes: 5 additions & 5 deletions tests/test_optimizer.py
@@ -25,7 +25,7 @@

 @helpers.parametrize(
     dtype=[torch.float64, torch.float32],
-    lr=[1e-3, 1e-4],
+    lr=[1e-1, 1e-3, 1e-4],
     momentum=[0.0, 0.1],
     nesterov=[False, True],
     weight_decay=[0.0, 1e-3],
@@ -82,7 +82,7 @@ def test_SGD(

 @helpers.parametrize(
     dtype=[torch.float64, torch.float32],
-    lr=[1e-3, 1e-4],
+    lr=[1e-1, 1e-3, 1e-4],
     betas=[(0.9, 0.999), (0.95, 0.9995)],
     eps=[1e-8],
     weight_decay=[0.0, 1e-3],
@@ -137,7 +137,7 @@ def test_Adam(

 @helpers.parametrize(
     dtype=[torch.float64, torch.float32],
-    lr=[1e-3, 1e-4],
+    lr=[1e-1, 1e-3, 1e-4],
     betas=[(0.9, 0.999), (0.95, 0.9995)],
     eps=[1e-8],
     weight_decay=[0.0, 1e-3],
@@ -194,7 +194,7 @@ def test_Adam_accelerated_cpu(
 @pytest.mark.skipif(not torch.cuda.is_available(), reason='No CUDA device available.')
 @helpers.parametrize(
     dtype=[torch.float64, torch.float32],
-    lr=[1e-3, 1e-4],
+    lr=[1e-1, 1e-3, 1e-4],
     betas=[(0.9, 0.999), (0.95, 0.9995)],
     eps=[1e-8],
     weight_decay=[0.0, 1e-3],
@@ -252,7 +252,7 @@ def test_Adam_accelerated_cuda(

 @helpers.parametrize(
     dtype=[torch.float64, torch.float32],
-    lr=[1e-3, 1e-4],
+    lr=[1e-1, 1e-3, 1e-4],
     alpha=[0.9, 0.99],
     eps=[1e-8],
     momentum=[0.0, 0.1],
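
Taken together, these tests exercise the point of the PR: after this change the functional aliases accept the same keyword names as their `torch.optim` counterparts. A small illustration of the aligned signatures follows; the values are arbitrary, and the `dampening` keyword for `torchopt.sgd` is inferred from commit 7d1d20a ("feat: add dampening") rather than shown in this diff.

import torch
import torchopt

# Illustration only: the same keyword names configure both optimizers.
params = [torch.nn.Parameter(torch.zeros(3))]

torchopt_sgd = torchopt.sgd(
    lr=1e-1,
    momentum=0.1,
    dampening=0.0,
    nesterov=True,
    weight_decay=1e-3,
    maximize=False,
)
torch_sgd = torch.optim.SGD(
    params,
    lr=1e-1,
    momentum=0.1,
    dampening=0.0,
    nesterov=True,
    weight_decay=1e-3,
    maximize=False,
)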