refactor: align argument names with PyTorch #65

Merged
merged 33 commits on Aug 29, 2022
Changes from 1 commit
Commits (33)
5e4676e
refactor: refactor optimizer bases
XuehaiPan Aug 25, 2022
20d5377
refactor: align Adam options with PyTorch
XuehaiPan Aug 25, 2022
bbc3bdd
refactor: align RMSProp options with PyTorch
XuehaiPan Aug 25, 2022
a45111a
refactor: align SGD options with PyTorch
XuehaiPan Aug 25, 2022
0c30794
feat(alias): check value range
XuehaiPan Aug 25, 2022
4a985ea
feat: add `params` to `update_fn`'s signature
XuehaiPan Aug 25, 2022
a3cb948
feat: add weight decay
XuehaiPan Aug 26, 2022
cbf1a52
test: add weight decay tests
XuehaiPan Aug 26, 2022
a0be5d9
lint: pass lint
XuehaiPan Aug 26, 2022
71a4a63
docs: update docstring
XuehaiPan Aug 26, 2022
2ab1da0
chore: update type hints
XuehaiPan Aug 26, 2022
a7d7643
fix: fix grad tracing for weight decay
XuehaiPan Aug 26, 2022
63d77cc
test: reorganize tests
XuehaiPan Aug 26, 2022
b31e245
chore: add RMSprop aliases for PyTorch compatibility
XuehaiPan Aug 26, 2022
a929c51
test: add module buffers
XuehaiPan Aug 26, 2022
f082d6c
test: update test parameters
XuehaiPan Aug 26, 2022
1c8421d
chore: update .gitignore
XuehaiPan Aug 26, 2022
2723c01
test: update test parameters
XuehaiPan Aug 27, 2022
1ec3d0f
refactor: refactor transform
XuehaiPan Aug 27, 2022
e8bd609
refactor: chain
XuehaiPan Aug 27, 2022
353b628
refactor: identity
XuehaiPan Aug 27, 2022
bee81bb
feat: add with_flattened_tree
XuehaiPan Aug 27, 2022
a8b6dc0
test: update test parameters
XuehaiPan Aug 27, 2022
7d1d20a
feat: add dampening
XuehaiPan Aug 27, 2022
58dcc56
docs: update docstring
XuehaiPan Aug 28, 2022
acde3fd
lint: fix mypy
XuehaiPan Aug 28, 2022
b1e1521
fix: fix grad tracing for initial value
XuehaiPan Aug 29, 2022
6331cdc
test: update test parameters
XuehaiPan Aug 29, 2022
1b39ad5
docs: update docstrings
XuehaiPan Aug 29, 2022
9976f96
chore: rename variables
XuehaiPan Aug 29, 2022
4c726f2
test: update test parameters
XuehaiPan Aug 29, 2022
3eee243
docs(CHANGELOG): update CHANGELOG.md
XuehaiPan Aug 29, 2022
1fded5d
test: test with pre-release of PyTorch
XuehaiPan Aug 29, 2022
feat: add weight decay
XuehaiPan committed Aug 26, 2022
commit a3cb9481a783f2f91c802304c820fa33dd0d4525
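In effect, this commit inserts a `_flip_sign_and_weight_decay` transformation at the front of each optimizer chain: every incoming gradient g for a parameter p becomes g + weight_decay * p, and with `maximize=True` the gradient is negated first, giving -g + weight_decay * p; the learning rate is applied afterwards by the separate `_scale_by_neg_lr` step. A minimal standalone sketch of that update rule (illustrative only, not the TorchOpt implementation itself):

    import torch

    def flip_sign_and_weight_decay(grad, param, weight_decay=0.0, maximize=False):
        # Negate for gradient ascent, then add the L2 penalty term weight_decay * param.
        g = -grad if maximize else grad
        return g + weight_decay * param

    p = torch.tensor(2.0)
    g = torch.tensor(0.5)
    print(flip_sign_and_weight_decay(g, p, weight_decay=0.1))                 # tensor(0.7000)
    print(flip_sign_and_weight_decay(g, p, weight_decay=0.1, maximize=True))  # tensor(-0.3000)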
128 changes: 107 additions & 21 deletions torchopt/_src/alias.py
@@ -39,25 +39,98 @@
 from torchopt._src.utils import pytree


-def _scale_by_lr(lr: ScalarOrSchedule, maximize=False):
-    sign = -1 if not maximize else 1
+def _flip_sign_and_weight_decay(weight_decay: float = 0.0, maximize=False):
+    if not 0.0 <= weight_decay:  # pylint: disable=unneeded-not
+        raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+    if not maximize and weight_decay == 0.0:
+        return base.identity()
+
+    def init_fn(_):
+        return base.EmptyState()
+
+    if not maximize:  # gradient descent
+
+        def update_fn(updates, state, *, params=None, inplace=True):
+            assert params is not None, (
+                "Parameters are required for weight decay. "
+                "Call `update(updates, state, params=params)` instead."
+            )
+
+            if inplace:
+
+                def f(g, p):
+                    return g.add_(p, alpha=weight_decay) if g is not None else None
+
+            else:
+
+                def f(g, p):
+                    return g.add(p, alpha=weight_decay) if g is not None else None
+
+            updates = pytree.tree_map(f, updates, params)
+            return updates, state
+
+    else:  # gradient ascent
+
+        if weight_decay == 0.0:
+            # pylint: disable-next=unused-argument
+            def update_fn(updates, state, *, params=None, inplace=True):
+                if inplace:
+
+                    def f(g):
+                        return g.neg_() if g is not None else None
+
+                else:
+
+                    def f(g):
+                        return g.neg() if g is not None else None
+
+                updates = pytree.tree_map(f, updates)
+                return updates, state
+
+        else:
+
+            def update_fn(updates, state, *, params=None, inplace=True):
+                assert params is not None, (
+                    "Parameters are required for weight decay. "
+                    "Call `update(updates, state, params=params)` instead."
+                )
+
+                if inplace:
+
+                    def f(g, p):
+                        return g.neg_().add_(p, alpha=weight_decay) if g is not None else None
+
+                else:
+
+                    def f(g, p):
+                        return g.neg().add_(p, alpha=weight_decay) if g is not None else None
+
+                updates = pytree.tree_map(f, updates, params)
+                return updates, state
+
+    return base.GradientTransformation(init_fn, update_fn)
+
+
+def _scale_by_neg_lr(lr: ScalarOrSchedule):
     if callable(lr):

         def schedule_wrapper(count):
             def f(scaled_lr):
-                return sign * scaled_lr
+                return -scaled_lr

             return pytree.tree_map(f, lr(count))  # type: ignore

         return transform.scale_by_schedule(schedule_wrapper)
-    return transform.scale(sign * lr)
+    return transform.scale(-lr)


 # pylint: disable-next=too-many-arguments
 def adam(
     lr: ScalarOrSchedule = 1e-3,
     betas: Tuple[float, float] = (0.9, 0.999),
     eps: float = 1e-8,
+    weight_decay: float = 0.0,
     *,
     eps_root: float = 0.0,
     moment_requires_grad: bool = False,
@@ -81,6 +154,8 @@ def adam(
         eps: (float, default: :const:`1e-8`)
             A small constant applied to denominator outside of the square root (as in the Adam
             paper) to avoid dividing by zero when rescaling.
+        weight_decay: (float, default: :const:`0.0`):
+            Weight decay, add L2 penalty to parameters.
         eps_root: (float, default: :data:`0.0`)
             A small constant applied to denominator inside the square root (as in RMSProp), to avoid
             dividing by zero when rescaling. This is needed for example when computing
@@ -106,26 +181,32 @@ def adam(
         raise ValueError(f'Invalid beta parameter at index 0: {b1}')
     if not 0.0 <= b2 < 1.0:
         raise ValueError(f'Invalid beta parameter at index 1: {b2}')
+    if not 0.0 <= weight_decay:
+        raise ValueError(f"Invalid weight_decay value: {weight_decay}")
     # pylint: enable=unneeded-not

-    adam_inst = (
-        transform.scale_by_accelerated_adam if use_accelerated_op else transform.scale_by_adam
-    )
+    if use_accelerated_op:
+        adam_scaler = transform.scale_by_accelerated_adam
+    else:
+        adam_scaler = transform.scale_by_adam

     return combine.chain(
-        adam_inst(
+        _flip_sign_and_weight_decay(weight_decay=weight_decay, maximize=maximize),
+        adam_scaler(
             b1=b1,
             b2=b2,
             eps=eps,
             eps_root=eps_root,
             moment_requires_grad=moment_requires_grad,
         ),
-        _scale_by_lr(lr, maximize=maximize),
+        _scale_by_neg_lr(lr),
     )


 def sgd(
     lr: ScalarOrSchedule,
     momentum: float = 0.0,
+    weight_decay: float = 0.0,
     nesterov: bool = False,
     *,
     moment_requires_grad: bool = False,
@@ -146,6 +227,8 @@ def sgd(
         momentum: (float, default: :const:`0.0`)
             The decay rate used by the momentum term. The momentum is not used when it is set to
             :const:`0.0`.
+        weight_decay: (float, default: :const:`0.0`):
+            Weight decay, add L2 penalty to parameters.
         nesterov: (bool, default: :data:`False`)
             Whether the nesterov momentum is used.
         moment_requires_grad: (bool, default: :data:`False`)
@@ -162,9 +245,12 @@ def sgd(
         raise ValueError(f'Invalid learning rate: {lr}')
     if not 0.0 <= momentum:
         raise ValueError(f'Invalid momentum value: {momentum}')
+    if not 0.0 <= weight_decay:
+        raise ValueError(f"Invalid weight_decay value: {weight_decay}")
     # pylint: enable=unneeded-not

     return combine.chain(
+        _flip_sign_and_weight_decay(weight_decay=weight_decay, maximize=maximize),
         (
             transform.trace(
                 decay=momentum,
@@ -174,7 +260,7 @@ def sgd(
             if momentum is not None and momentum != 0.0
             else base.identity()
         ),
-        _scale_by_lr(lr, maximize=maximize),
+        _scale_by_neg_lr(lr),
     )

@@ -183,6 +269,7 @@ def rmsprop(
     lr: ScalarOrSchedule = 1e-2,
     alpha: float = 0.9,
     eps: float = 1e-8,
+    weight_decay: float = 0.0,
     momentum: float = 0.0,
     centered: bool = False,
     *,
@@ -208,6 +295,8 @@ def rmsprop(
             Smoothing constant, the decay used to track the magnitude of previous gradients.
         eps: (float, default: :const:`1e-8`)
             A small numerical constant to avoid dividing by zero when rescaling.
+        weight_decay: (float, default: :const:`0.0`):
+            Weight decay, add L2 penalty to parameters.
         momentum: (float, default: :const:`0.0`)
             The decay rate used by the momentum term. The momentum is not used when it is set to
             :const:`0.0`.
@@ -235,25 +324,22 @@ def rmsprop(
         raise ValueError(f'Invalid epsilon value: {eps}')
     if not 0.0 <= momentum:
         raise ValueError(f'Invalid momentum value: {momentum}')
+    if not 0.0 <= weight_decay:
+        raise ValueError(f"Invalid weight_decay value: {weight_decay}")
     # pylint: enable=unneeded-not

     if centered:
-        return combine.chain(
-            transform.scale_by_stddev(alpha=alpha, eps=eps, initial_scale=initial_scale),
-            (
-                transform.trace(decay=momentum, nesterov=nesterov)
-                if momentum is not None and momentum != 0.0
-                else base.identity()
-            ),
-            _scale_by_lr(lr, maximize=maximize),
-        )
+        rmsprop_scaler = transform.scale_by_stddev
+    else:
+        rmsprop_scaler = transform.scale_by_rms

     return combine.chain(
-        transform.scale_by_rms(alpha=alpha, eps=eps, initial_scale=initial_scale),
+        _flip_sign_and_weight_decay(weight_decay=weight_decay, maximize=maximize),
+        rmsprop_scaler(alpha=alpha, eps=eps, initial_scale=initial_scale),
         (
             transform.trace(decay=momentum, nesterov=nesterov)
             if momentum is not None and momentum != 0.0
             else base.identity()
         ),
-        _scale_by_lr(lr, maximize=maximize),
+        _scale_by_neg_lr(lr),
     )
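With this change the functional aliases take `weight_decay` directly, and `update` must be given the parameters whenever the penalty is active. A hedged usage sketch of the functional API (the toy parameter, loss, and hyperparameters are illustrative assumptions, not part of this diff):

    import torch
    import torchopt

    params = (torch.tensor(1.0, requires_grad=True),)
    optimizer = torchopt.sgd(lr=0.1, momentum=0.9, weight_decay=1e-4)
    opt_state = optimizer.init(params)

    loss = params[0] ** 2
    grads = torch.autograd.grad(loss, params)

    # `params=params` is required whenever weight_decay > 0.
    updates, opt_state = optimizer.update(grads, opt_state, params=params, inplace=False)
    new_params = torchopt.apply_updates(params, updates, inplace=False)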
4 changes: 4 additions & 0 deletions torchopt/_src/optimizer/adam.py
@@ -37,6 +37,7 @@ def __init__(
         lr: ScalarOrSchedule,
         betas: Tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
+        weight_decay: float = 0.0,
         *,
         eps_root: float = 0.0,
         maximize: bool = False,
@@ -54,6 +55,8 @@ def __init__(
             eps: (float, default: :const:`1e-8`)
                 A small constant applied to denominator outside of the square root (as in the Adam
                 paper) to avoid dividing by zero when rescaling.
+            weight_decay: (float, default: :const:`0.0`):
+                Weight decay, add L2 penalty to parameters.
             eps_root: (float, default: :data:`0.0`)
                 A small constant applied to denominator inside the square root (as in RMSProp), to
                 avoid dividing by zero when rescaling. This is needed for example when computing
@@ -69,6 +72,7 @@ def __init__(
                 lr=lr,
                 betas=betas,
                 eps=eps,
+                weight_decay=weight_decay,
                 eps_root=eps_root,
                 moment_requires_grad=False,
                 maximize=maximize,
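The high-level `torchopt.Adam` wrapper simply forwards the new argument to the `adam` alias, so it is used like `torch.optim.Adam`; a brief sketch with an assumed toy module and data:

    import torch
    import torch.nn as nn
    import torchopt

    net = nn.Linear(4, 1)
    optimizer = torchopt.Adam(net.parameters(), lr=1e-3, weight_decay=1e-2)

    x = torch.randn(8, 4)
    loss = net(x).pow(2).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()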
4 changes: 4 additions & 0 deletions torchopt/_src/optimizer/meta/adam.py
@@ -37,6 +37,7 @@ def __init__(
         lr: ScalarOrSchedule = 1e-3,
         betas: Tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
+        weight_decay: float = 0.0,
         *,
         eps_root: float = 0.0,
         moment_requires_grad: bool = True,
@@ -55,6 +56,8 @@ def __init__(
             eps: (float, default: :const:`1e-8`)
                 A small constant applied to denominator outside of the square root (as in the Adam
                 paper) to avoid dividing by zero when rescaling.
+            weight_decay: (float, default: :const:`0.0`):
+                Weight decay, add L2 penalty to parameters.
             eps_root: (float, default: :data:`0.0`)
                 A small constant applied to denominator inside the square root (as in RMSProp), to
                 avoid dividing by zero when rescaling. This is needed for example when computing
@@ -73,6 +76,7 @@ def __init__(
                 lr=lr,
                 betas=betas,
                 eps=eps,
+                weight_decay=weight_decay,
                 eps_root=eps_root,
                 moment_requires_grad=moment_requires_grad,
                 maximize=maximize,
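`MetaAdam` forwards the same argument for differentiable (meta-gradient) optimization, where the update, including the weight-decay term, stays on the computation graph. A hedged sketch with an assumed module and inner loss:

    import torch
    import torch.nn as nn
    import torchopt

    net = nn.Linear(4, 1)
    inner_opt = torchopt.MetaAdam(net, lr=1e-2, weight_decay=1e-3)

    x = torch.randn(8, 4)
    inner_loss = net(x).pow(2).mean()
    inner_opt.step(inner_loss)  # parameters are replaced by differentiable updates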
4 changes: 4 additions & 0 deletions torchopt/_src/optimizer/meta/rmsprop.py
@@ -35,6 +35,7 @@ def __init__(
         lr: ScalarOrSchedule = 1e-2,
         alpha: float = 0.99,
         eps: float = 1e-8,
+        weight_decay: float = 0.0,
         momentum: float = 0.0,
         centered: bool = False,
         *,
@@ -53,6 +54,8 @@ def __init__(
                 Smoothing constant, the decay used to track the magnitude of previous gradients.
             eps: (float, default: :const:`1e-8`)
                 A small numerical constant to avoid dividing by zero when rescaling.
+            weight_decay: (float, default: :const:`0.0`):
+                Weight decay, add L2 penalty to parameters.
             momentum: (float, default: :const:`0.0`)
                 The decay rate used by the momentum term. The momentum is not used when it is set to
                 :const:`0.0`.
@@ -74,6 +77,7 @@ def __init__(
                 lr=lr,
                 alpha=alpha,
                 eps=eps,
+                weight_decay=weight_decay,
                 momentum=momentum,
                 centered=centered,
                 initial_scale=initial_scale,
4 changes: 4 additions & 0 deletions torchopt/_src/optimizer/meta/sgd.py
@@ -34,6 +34,7 @@ def __init__(
         net: nn.Module,
         lr: ScalarOrSchedule,
         momentum: float = 0.0,
+        weight_decay: float = 0.0,
         nesterov: bool = False,
         moment_requires_grad: bool = True,
         maximize: bool = False,
@@ -48,6 +49,8 @@ def __init__(
             momentum: (float, default: :const:`0.0`)
                 The decay rate used by the momentum term. The momentum is not used when it is set to
                 :const:`0.0`.
+            weight_decay: (float, default: :const:`0.0`):
+                Weight decay, add L2 penalty to parameters.
             nesterov: (bool, default: :const:`False`)
                 Whether the nesterov momentum is used.
             moment_requires_grad: (bool, default: :data:`True`)
@@ -61,6 +64,7 @@ def __init__(
             sgd(
                 lr=lr,
                 momentum=momentum,
+                weight_decay=weight_decay,
                 nesterov=nesterov,
                 moment_requires_grad=moment_requires_grad,
                 maximize=maximize,
4 changes: 4 additions & 0 deletions torchopt/_src/optimizer/rmsprop.py
@@ -37,6 +37,7 @@ def __init__(
         lr: ScalarOrSchedule = 1e-2,
         alpha: float = 0.99,
         eps: float = 1e-8,
+        weight_decay: float = 0.0,
         momentum: float = 0.0,
         centered: bool = False,
         *,
@@ -55,6 +56,8 @@ def __init__(
                 Smoothing constant, the decay used to track the magnitude of previous gradients.
             eps: (float, default: :const:`1e-8`)
                 A small numerical constant to avoid dividing by zero when rescaling.
+            weight_decay: (float, default: :const:`0.0`):
+                Weight decay, add L2 penalty to parameters.
             momentum: (float, default: :const:`0.0`)
                 The decay rate used by the momentum term. The momentum is not used when it is set to
                 :const:`0.0`.
@@ -76,6 +79,7 @@ def __init__(
                 lr=lr,
                 alpha=alpha,
                 eps=eps,
+                weight_decay=weight_decay,
                 momentum=momentum,
                 centered=centered,
                 initial_scale=initial_scale,
4 changes: 4 additions & 0 deletions torchopt/_src/optimizer/sgd.py
@@ -36,6 +36,7 @@ def __init__(
         params: Iterable[torch.Tensor],
         lr: ScalarOrSchedule,
         momentum: float = 0.0,
+        weight_decay: float = 0.0,
         nesterov: bool = False,
         maximize: bool = False,
     ):
@@ -49,6 +50,8 @@ def __init__(
             momentum: (float, default: :const:`0.0`)
                 The decay rate used by the momentum term. The momentum is not used when it is set to
                 :const:`0.0`.
+            weight_decay: (float, default: :const:`0.0`):
+                Weight decay, add L2 penalty to parameters.
             nesterov: (bool, default: :data:`False`)
                 Whether the nesterov momentum is used.
             maximize: (bool, default: :data:`False`)
@@ -59,6 +62,7 @@ def __init__(
             sgd(
                 lr=lr,
                 momentum=momentum,
+                weight_decay=weight_decay,
                 nesterov=nesterov,
                 moment_requires_grad=False,
                 maximize=maximize,
3 changes: 1 addition & 2 deletions torchopt/_src/transform.py
@@ -69,8 +69,7 @@ def scale(step_size: float) -> base.GradientTransformation:
         An ``(init_fn, update_fn)`` tuple.
     """

-    def init_fn(params):
-        del params
+    def init_fn(_):
         return ScaleState()

     def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument