diff --git a/CHANGELOG.md b/CHANGELOG.md index a6f2bc8e..b9286125 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Implement AdamW optimizer with masking by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#44](https://github.com/metaopt/torchopt/pull/44). - Add half float support for accelerated OPs by [@XuehaiPan](https://github.com/XuehaiPan) in [#67](https://github.com/metaopt/torchopt/pull/67). - Add MAML example with TorchRL integration by [@vmoens](https://github.com/vmoens) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#12](https://github.com/metaopt/TorchOpt/pull/12). - Add optional argument `params` to update function in gradient transformations by [@XuehaiPan](https://github.com/XuehaiPan) in [#65](https://github.com/metaopt/torchopt/pull/65). diff --git a/docs/source/api/api.rst b/docs/source/api/api.rst index 44da5b93..545a8d54 100644 --- a/docs/source/api/api.rst +++ b/docs/source/api/api.rst @@ -32,12 +32,18 @@ Functional Optimizers adam sgd rmsprop + adamw Functional Adam Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: adam +Functional AdamW Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: adamw + Functional SGD Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~ @@ -60,12 +66,18 @@ Classic Optimizers Adam SGD RMSProp + AdamW Classic Adam Optimizer ~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: Adam +Classic AdamW Optimizer +~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: AdamW + Classic SGD Optimizer ~~~~~~~~~~~~~~~~~~~~~ @@ -88,12 +100,18 @@ Differentiable Meta-Optimizers MetaAdam MetaSGD MetaRMSProp + MetaAdamW Differentiable Meta-Adam Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: MetaAdam +Differentiable Meta-AdamW Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: MetaAdamW + Differentiable Meta-SGD Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt index a48d20d4..ca34dd05 100644 --- a/docs/source/spelling_wordlist.txt +++ b/docs/source/spelling_wordlist.txt @@ -73,3 +73,8 @@ CPython nn Vincent Moens +AdamW +Loshchilov +pytree +booleans +subtrees diff --git a/tests/test_alias.py b/tests/test_alias.py index 1ab20960..6f37e939 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -154,6 +154,64 @@ def test_adam( weight_decay=[0.0, 1e-2], maximize=[False, True], ) +def test_adamw( + dtype: torch.dtype, + lr: float, + betas: Tuple[float, float], + eps: float, + inplace: bool, + weight_decay: float, + maximize: bool, +) -> None: + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + fmodel, params, buffers = functorch.make_functional_with_buffers(model) + optim = torchopt.adamw( + lr, + betas=betas, + eps=eps, + eps_root=0.0, + weight_decay=weight_decay, + maximize=maximize, + ) + optim_state = optim.init(params) + optim_ref = torch.optim.AdamW( + model_ref.parameters(), + lr, + betas=betas, + eps=eps, + amsgrad=False, + weight_decay=weight_decay, + maximize=maximize, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = fmodel(params, buffers, xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + grads = torch.autograd.grad(loss, params) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + params = torchopt.apply_updates(params, updates, inplace=inplace) + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + + +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + betas=[(0.9, 0.999), (0.95, 0.9995)], + eps=[1e-8], + inplace=[True, False], + weight_decay=[1e-2, 1e-1], + maximize=[False, True], +) def test_adam_accelerated_cpu( dtype: torch.dtype, lr: float, diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index b7e6818d..c0db3e34 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -138,6 +138,61 @@ def test_Adam( helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + betas=[(0.9, 0.999), (0.95, 0.9995)], + eps=[1e-8], + weight_decay=[1e-2, 1e-1], + maximize=[False, True], +) +def test_AdamW( + dtype: torch.dtype, + lr: float, + betas: Tuple[float, float], + eps: float, + weight_decay: float, + maximize: bool, +) -> None: + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + optim = torchopt.AdamW( + model.parameters(), + lr, + betas=betas, + eps=eps, + eps_root=0.0, + weight_decay=weight_decay, + maximize=maximize, + ) + optim_ref = torch.optim.AdamW( + model_ref.parameters(), + lr, + betas=betas, + eps=eps, + amsgrad=False, + weight_decay=weight_decay, + maximize=maximize, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = model(xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + optim.zero_grad() + loss.backward() + optim.step() + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) + + @helpers.parametrize( dtype=[torch.float64], lr=[1e-2, 1e-3, 1e-4], diff --git 
a/torchopt/__init__.py b/torchopt/__init__.py index 12fa9d12..ab7a5a4d 100644 --- a/torchopt/__init__.py +++ b/torchopt/__init__.py @@ -15,11 +15,18 @@ """TorchOpt: a high-performance optimizer library built upon PyTorch.""" from torchopt._src import accelerated_op_available, clip, combine, hook, schedule, visual -from torchopt._src.alias import adam, rmsprop, sgd +from torchopt._src.alias import adam, adamw, rmsprop, sgd from torchopt._src.clip import clip_grad_norm from torchopt._src.combine import chain -from torchopt._src.optimizer import SGD, Adam, Optimizer, RMSProp, RMSprop, meta -from torchopt._src.optimizer.meta import MetaAdam, MetaOptimizer, MetaRMSProp, MetaRMSprop, MetaSGD +from torchopt._src.optimizer import SGD, Adam, AdamW, Optimizer, RMSProp, RMSprop, meta +from torchopt._src.optimizer.meta import ( + MetaAdam, + MetaAdamW, + MetaOptimizer, + MetaRMSProp, + MetaRMSprop, + MetaSGD, +) from torchopt._src.update import apply_updates from torchopt._src.utils import extract_state_dict, recover_state_dict, stop_gradient from torchopt.version import __version__ @@ -33,6 +40,7 @@ 'schedule', 'visual', 'adam', + 'adamw', 'rmsprop', 'sgd', 'clip_grad_norm', @@ -40,11 +48,13 @@ 'Optimizer', 'SGD', 'Adam', + 'AdamW', 'RMSProp', 'RMSprop', 'MetaOptimizer', 'MetaSGD', 'MetaAdam', + 'MetaAdamW', 'MetaRMSProp', 'MetaRMSprop', 'apply_updates', diff --git a/torchopt/_src/alias.py b/torchopt/_src/alias.py index 74677729..40b2e92d 100644 --- a/torchopt/_src/alias.py +++ b/torchopt/_src/alias.py @@ -32,7 +32,7 @@ # pylint: disable=invalid-name -from typing import Tuple +from typing import Any, Callable, Optional, Tuple, Union from torchopt._src import base, combine, transform from torchopt._src.typing import ScalarOrSchedule @@ -45,7 +45,7 @@ def _flip_sign_and_weight_decay(weight_decay: float = 0.0, maximize=False): if not maximize and weight_decay == 0.0: return base.identity() - def init_fn(_): + def init_fn(params): # pylint: disable=unused-argument return base.EmptyState() if not maximize: # gradient descent @@ -166,7 +166,7 @@ def adam( eps: (default: :const:`1e-8`) A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`): + weight_decay: (default: :const:`0.0`) Weight decay, add L2 penalty to parameters. eps_root: (default: :data:`0.0`) A small constant applied to denominator inside the square root (as in RMSProp), to avoid @@ -174,7 +174,7 @@ def adam( (meta-)gradients through Adam. moment_requires_grad: (default: :data:`False`) If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta Learning algorithms. + flag is often used in Meta-Learning algorithms. maximize: (default: :data:`False`) Maximize the params based on the objective, instead of minimizing. 
use_accelerated_op: (default: :data:`False`) @@ -218,66 +218,98 @@ def adam( ) -def sgd( - lr: ScalarOrSchedule, - momentum: float = 0.0, - dampening: float = 0.0, - weight_decay: float = 0.0, - nesterov: bool = False, +# pylint: disable-next=too-many-arguments +def adamw( + lr: ScalarOrSchedule = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, *, + eps_root: float = 0.0, + mask: Optional[Union[Any, Callable[['base.Params'], Any]]] = None, moment_requires_grad: bool = False, maximize: bool = False, + use_accelerated_op: bool = False, ) -> base.GradientTransformation: - """The functional version of the canonical Stochastic Gradient Descent optimizer. + """Adam with weight decay regularization. - This implements stochastic gradient descent. It also includes support for momentum, and nesterov - acceleration, as these are standard practice when using stochastic gradient descent to train - deep neural networks. + AdamW uses weight decay to regularize learning towards small weights, as + this leads to better generalization. In SGD you can also use L2 regularization + to implement this as an additive loss term, however L2 regularization + does not behave as intended for adaptive gradient algorithms such as Adam. References: - - Sutskever et al, 2013: http://proceedings.mlr.press/v28/sutskever13.pdf + - Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101 Args: - lr: This is a fixed global scaling factor. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - weight_decay: (default: :const:`0.0`): - Weight decay, add L2 penalty to parameters. - dampening: (default: :const:`0.0`) - Dampening for momentum. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. + lr: (default: :const:`1e-3`) + This is a fixed global scaling factor. + betas: (default: :const:`(0.9, 0.999)`) + Coefficients used for computing running averages of gradient and its square. + eps: (default: :const:`1e-8`) + A small constant applied to denominator outside of the square root (as in the Adam + paper) to avoid dividing by zero when rescaling. + weight_decay: (default: :const:`1e-2`) + Strength of the weight decay regularization. Note that this weight decay is multiplied + with the learning rate. This is consistent with other frameworks such as PyTorch, but + different from (Loshchilov et al, 2019) where the weight decay is only multiplied with + the "schedule multiplier", but not the base learning rate. + eps_root: (default: :data:`0.0`) + A small constant applied to denominator inside the square root (as in RMSProp), to avoid + dividing by zero when rescaling. This is needed for example when computing + (meta-)gradients through Adam. + mask: (default: :data:`None`) + A tree with same structure as (or a prefix of) the params PyTree, or a Callable that + returns such a pytree given the params/updates. The leaves should be booleans, + :data:`True` for leaves/subtrees you want to apply the weight decay to, and + :data:`False` for those you want to skip. Note that the Adam gradient + transformations are applied to all parameters. moment_requires_grad: (default: :data:`False`) If :data:`True` the momentums will be created with flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms. maximize: (default: :data:`False`) Maximize the params based on the objective, instead of minimizing. 
+ use_accelerated_op: (default: :data:`False`) + If :data:`True` use our implemented fused operator. Returns: - A :class:`GradientTransformation` instance. + The corresponding :class:`GradientTransformation` instance. """ + b1, b2 = betas # pylint: disable=unneeded-not if not (callable(lr) or 0.0 <= lr): raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= momentum: - raise ValueError(f'Invalid momentum value: {momentum}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= b1 < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {b1}') + if not 0.0 <= b2 < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {b2}') if not 0.0 <= weight_decay: raise ValueError(f'Invalid weight_decay value: {weight_decay}') - if nesterov and (momentum <= 0.0 or dampening != 0.0): - raise ValueError('Nesterov momentum requires a momentum and zero dampening') # pylint: enable=unneeded-not + if use_accelerated_op: + adam_scaler = transform._scale_by_accelerated_adam # pylint: disable=protected-access + else: + adam_scaler = transform._scale_by_adam # pylint: disable=protected-access + return transform.with_flattened_tree( combine.chain( - _flip_sign_and_weight_decay(weight_decay=weight_decay, maximize=maximize), - transform._trace( # pylint: disable=protected-access - momentum=momentum, - dampening=dampening, - nesterov=nesterov, + _flip_sign_and_weight_decay(weight_decay=0.0, maximize=maximize), + adam_scaler( + b1=b1, + b2=b2, + eps=eps, + eps_root=eps_root, moment_requires_grad=moment_requires_grad, already_flattened=True, ), + transform._add_decayed_weights( # pylint: disable=protected-access + weight_decay=weight_decay, + mask=mask, + already_flattened=True, + ), _scale_by_neg_lr(lr), ) ) @@ -314,7 +346,7 @@ def rmsprop( Smoothing constant, the decay used to track the magnitude of previous gradients. eps: (default: :const:`1e-8`) A small numerical constant to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`): + weight_decay: (default: :const:`0.0`) Weight decay, add L2 penalty to parameters. momentum: (default: :const:`0.0`) The decay rate used by the momentum term. The momentum is not used when it is set to @@ -369,3 +401,68 @@ def rmsprop( _scale_by_neg_lr(lr), ) ) + + +def sgd( + lr: ScalarOrSchedule, + momentum: float = 0.0, + dampening: float = 0.0, + weight_decay: float = 0.0, + nesterov: bool = False, + *, + moment_requires_grad: bool = False, + maximize: bool = False, +) -> base.GradientTransformation: + """The functional version of the canonical Stochastic Gradient Descent optimizer. + + This implements stochastic gradient descent. It also includes support for momentum, and nesterov + acceleration, as these are standard practice when using stochastic gradient descent to train + deep neural networks. + + References: + - Sutskever et al, 2013: http://proceedings.mlr.press/v28/sutskever13.pdf + + Args: + lr: This is a fixed global scaling factor. + momentum: (default: :const:`0.0`) + The decay rate used by the momentum term. The momentum is not used when it is set to + :const:`0.0`. + weight_decay: (default: :const:`0.0`) + Weight decay, add L2 penalty to parameters. + dampening: (default: :const:`0.0`) + Dampening for momentum. + nesterov: (default: :data:`False`) + Whether to use Nesterov momentum. + moment_requires_grad: (default: :data:`False`) + If :data:`True` the momentums will be created with flag ``requires_grad=True``, this + flag is often used in Meta-Learning algorithms. 
+ maximize: (default: :data:`False`) + Maximize the params based on the objective, instead of minimizing. + + Returns: + The corresponding :class:`GradientTransformation` instance. + """ + # pylint: disable=unneeded-not + if not (callable(lr) or 0.0 <= lr): + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= momentum: + raise ValueError(f'Invalid momentum value: {momentum}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + if nesterov and (momentum <= 0.0 or dampening != 0.0): + raise ValueError('Nesterov momentum requires a momentum and zero dampening') + # pylint: enable=unneeded-not + + return transform.with_flattened_tree( + combine.chain( + _flip_sign_and_weight_decay(weight_decay=weight_decay, maximize=maximize), + transform._trace( # pylint: disable=protected-access + momentum=momentum, + dampening=dampening, + nesterov=nesterov, + moment_requires_grad=moment_requires_grad, + already_flattened=True, + ), + _scale_by_neg_lr(lr), + ) + ) diff --git a/torchopt/_src/base.py b/torchopt/_src/base.py index ef92acf2..f17bf00f 100644 --- a/torchopt/_src/base.py +++ b/torchopt/_src/base.py @@ -82,7 +82,9 @@ class TransformUpdateFn(Protocol): # pylint: disable=too-few-public-methods The :func:`update` step takes a tree of candidate parameter ``updates`` (e.g. their gradient with respect to some loss), an arbitrary structured ``state``, and the current ``params`` of the model being optimized. The ``params`` argument is optional, it must however be provided when - using transformations that require access to the current values of the parameters. + using transformations that require access to the current values of the parameters. The + ``inplace`` argument is optional, If :data:`True`, modify updates and state using inplace + operations. """ @abstractmethod @@ -99,6 +101,8 @@ def __call__( Args: updates: A tree of candidate updates. state: The state of the gradient transformation. + params: (optional) + The current value of the parameters. inplace: (optional) If :data:`True`, modify updates and state using inplace operations. diff --git a/torchopt/_src/clip.py b/torchopt/_src/clip.py index 2101dd69..31d54797 100644 --- a/torchopt/_src/clip.py +++ b/torchopt/_src/clip.py @@ -38,7 +38,7 @@ def clip_grad_norm( An ``(init_fn, update_fn)`` tuple. """ - def init_fn(_): + def init_fn(params): # pylint: disable=unused-argument return ClipState() def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument diff --git a/torchopt/_src/hook.py b/torchopt/_src/hook.py index 408e17b4..305c34ca 100644 --- a/torchopt/_src/hook.py +++ b/torchopt/_src/hook.py @@ -33,7 +33,7 @@ def register_hook(hook) -> GradientTransformation: An ``(init_fn, update_fn)`` tuple. 
""" - def init_fn(_): + def init_fn(params): # pylint: disable=unused-argument return EmptyState() def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument diff --git a/torchopt/_src/optimizer/__init__.py b/torchopt/_src/optimizer/__init__.py index 17e8b0a7..8501fb15 100644 --- a/torchopt/_src/optimizer/__init__.py +++ b/torchopt/_src/optimizer/__init__.py @@ -15,6 +15,7 @@ from torchopt._src.optimizer import meta from torchopt._src.optimizer.adam import Adam +from torchopt._src.optimizer.adamw import AdamW from torchopt._src.optimizer.base import Optimizer from torchopt._src.optimizer.rmsprop import RMSProp, RMSprop from torchopt._src.optimizer.sgd import SGD diff --git a/torchopt/_src/optimizer/adam.py b/torchopt/_src/optimizer/adam.py index 2a096102..6776408e 100644 --- a/torchopt/_src/optimizer/adam.py +++ b/torchopt/_src/optimizer/adam.py @@ -55,7 +55,7 @@ def __init__( eps: (default: :const:`1e-8`) A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`): + weight_decay: (default: :const:`0.0`) Weight decay, add L2 penalty to parameters. eps_root: (default: :data:`0.0`) A small constant applied to denominator inside the square root (as in RMSProp), to diff --git a/torchopt/_src/optimizer/adamw.py b/torchopt/_src/optimizer/adamw.py new file mode 100644 index 00000000..886cd77a --- /dev/null +++ b/torchopt/_src/optimizer/adamw.py @@ -0,0 +1,93 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, Callable, Iterable, Optional, Tuple, Union + +import torch + +from torchopt._src import base # pylint: disable=unused-import +from torchopt._src.alias import adamw +from torchopt._src.optimizer.base import Optimizer +from torchopt._src.typing import ScalarOrSchedule + + +class AdamW(Optimizer): + """The classic AdamW optimizer. + + See Also: + - The functional AdamW optimizer: :func:`torchopt.adamw`. + - The differentiable meta-AdamW optimizer: :class:`torchopt.MetaAdamW`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + params: Iterable[torch.Tensor], + lr: ScalarOrSchedule = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + *, + eps_root: float = 0.0, + mask: Optional[Union[Any, Callable[['base.Params'], Any]]] = None, + maximize: bool = False, + use_accelerated_op: bool = False, + ): + r"""The :meth:`init` function. + + Args: + params: (iterable of torch.Tensor) + An iterable of :class:`torch.Tensor`\s. Specifies what tensors should be optimized. + lr: (default: :const:`1e-3`) + This is a fixed global scaling factor. + betas: (default: :const:`(0.9, 0.999)`) + Coefficients used for computing running averages of gradient and its square. 
+ eps: (default: :const:`1e-8`) + A small constant applied to denominator outside of the square root (as in the Adam + paper) to avoid dividing by zero when rescaling. + weight_decay: (default: :const:`1e-2`) + Strength of the weight decay regularization. Note that this weight decay is + multiplied with the learning rate. This is consistent with other frameworks such as + PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only + multiplied with the "schedule multiplier", but not the base learning rate. + eps_root: (default: :data:`0.0`) + A small constant applied to denominator inside the square root (as in RMSProp), to + avoid dividing by zero when rescaling. This is needed for example when computing + (meta-)gradients through Adam. + mask: (default: :data:`None`) + A tree with same structure as (or a prefix of) the params PyTree, or a Callable that + returns such a pytree given the params/updates. The leaves should be booleans, + :data:`True` for leaves/subtrees you want to apply the weight decay to, and + :data:`False` for those you want to skip. Note that the Adam gradient + transformations are applied to all parameters. + maximize: (default: :data:`False`) + Maximize the params based on the objective, instead of minimizing. + use_accelerated_op: (default: :data:`False`) + If :data:`True` use our implemented fused operator. + """ + super().__init__( + params, + adamw( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + eps_root=eps_root, + mask=mask, + moment_requires_grad=False, + maximize=maximize, + use_accelerated_op=use_accelerated_op, + ), + ) diff --git a/torchopt/_src/optimizer/meta/__init__.py b/torchopt/_src/optimizer/meta/__init__.py index 294d9ebd..ec227474 100644 --- a/torchopt/_src/optimizer/meta/__init__.py +++ b/torchopt/_src/optimizer/meta/__init__.py @@ -14,6 +14,7 @@ # ============================================================================== from torchopt._src.optimizer.meta.adam import MetaAdam +from torchopt._src.optimizer.meta.adamw import MetaAdamW from torchopt._src.optimizer.meta.base import MetaOptimizer from torchopt._src.optimizer.meta.rmsprop import MetaRMSProp, MetaRMSprop from torchopt._src.optimizer.meta.sgd import MetaSGD diff --git a/torchopt/_src/optimizer/meta/adam.py b/torchopt/_src/optimizer/meta/adam.py index 442eca10..6b76f959 100644 --- a/torchopt/_src/optimizer/meta/adam.py +++ b/torchopt/_src/optimizer/meta/adam.py @@ -47,7 +47,8 @@ def __init__( """The :meth:`init` function. Args: - net: A network whose parameters should be optimized. + net: (nn.Module) + A network whose parameters should be optimized. lr: (default: :const:`1e-3`) This is a fixed global scaling factor. betas: (default: :const:`(0.9, 0.999)`) @@ -55,7 +56,7 @@ def __init__( eps: (default: :const:`1e-8`) A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`): + weight_decay: (default: :const:`0.0`) Weight decay, add L2 penalty to parameters. eps_root: (default: :data:`0.0`) A small constant applied to denominator inside the square root (as in RMSProp), to diff --git a/torchopt/_src/optimizer/meta/adamw.py b/torchopt/_src/optimizer/meta/adamw.py new file mode 100644 index 00000000..c38f3c5c --- /dev/null +++ b/torchopt/_src/optimizer/meta/adamw.py @@ -0,0 +1,97 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, Callable, Optional, Tuple, Union + +import torch.nn as nn + +from torchopt._src import base # pylint: disable=unused-import +from torchopt._src.alias import adamw +from torchopt._src.optimizer.meta.base import MetaOptimizer +from torchopt._src.typing import ScalarOrSchedule + + +class MetaAdamW(MetaOptimizer): + """The differentiable AdamW optimizer. + + See Also: + - The functional AdamW optimizer: :func:`torchopt.adamw`. + - The classic AdamW optimizer: :class:`torchopt.AdamW`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + net: nn.Module, + lr: ScalarOrSchedule = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + *, + eps_root: float = 0.0, + mask: Optional[Union[Any, Callable[['base.Params'], Any]]] = None, + moment_requires_grad: bool = False, + maximize: bool = False, + use_accelerated_op: bool = False, + ): + """The :meth:`init` function. + + Args: + net: (nn.Module) + A network whose parameters should be optimized. + lr: (default: :const:`1e-3`) + This is a fixed global scaling factor. + betas: (default: :const:`(0.9, 0.999)`) + Coefficients used for computing running averages of gradient and its square. + eps: (default: :const:`1e-8`) + A small constant applied to denominator outside of the square root (as in the Adam + paper) to avoid dividing by zero when rescaling. + weight_decay: (default: :const:`1e-2`) + Strength of the weight decay regularization. Note that this weight decay is + multiplied with the learning rate. This is consistent with other frameworks such as + PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only + multiplied with the "schedule multiplier", but not the base learning rate. + eps_root: (default: :data:`0.0`) + A small constant applied to denominator inside the square root (as in RMSProp), to + avoid dividing by zero when rescaling. This is needed for example when computing + (meta-)gradients through Adam. + mask: (default: :data:`None`) + A tree with same structure as (or a prefix of) the params PyTree, or a Callable that + returns such a pytree given the params/updates. The leaves should be booleans, + :data:`True` for leaves/subtrees you want to apply the weight decay to, and + :data:`False` for those you want to skip. Note that the Adam gradient + transformations are applied to all parameters. + moment_requires_grad: (default: :data:`False`) + If :data:`True` the momentums will be created with flag ``requires_grad=True``, this + flag is often used in Meta-Learning algorithms. + maximize: (default: :data:`False`) + Maximize the params based on the objective, instead of minimizing. + use_accelerated_op: (default: :data:`False`) + If :data:`True` use our implemented fused operator. 
+ """ + super().__init__( + net, + adamw( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + eps_root=eps_root, + mask=mask, + moment_requires_grad=moment_requires_grad, + maximize=maximize, + use_accelerated_op=use_accelerated_op, + ), + ) diff --git a/torchopt/_src/optimizer/meta/base.py b/torchopt/_src/optimizer/meta/base.py index 1acbd1b8..eb5a70b1 100644 --- a/torchopt/_src/optimizer/meta/base.py +++ b/torchopt/_src/optimizer/meta/base.py @@ -28,10 +28,11 @@ def __init__(self, net: nn.Module, impl: GradientTransformation): """The :meth:`init` function. Args: - net (torch.nn.Module): A network whose parameters should be optimized. - impl (GradientTransformation): A low level optimizer function, it could be a optimizer - function provided by ``alias.py`` or a customized ``chain`` provided by - ``combine.py``. + net: (nn.Module) + A network whose parameters should be optimized. + impl: (GradientTransformation) + A low level optimizer function, it could be a optimizer function provided by + ``alias.py`` or a customized ``chain`` provided by ``combine.py``. Note that using ``MetaOptimizer(sgd(moment_requires_grad=True))`` or ``MetaOptimizer(chain(sgd(moment_requires_grad=True)))`` is equivalent to :class:`torchopt.MetaSGD`. @@ -50,8 +51,8 @@ def step(self, loss: torch.Tensor): gradients and update the network parameters without modifying tensors in-place. Args: - loss (torch.Tensor): The loss that is used to compute the gradients to the network - parameters. + loss: (torch.Tensor) + The loss that is used to compute the gradients to the network parameters. """ # pylint: disable=line-too-long # Step parameter only for i, (param_container, new_state) in enumerate( diff --git a/torchopt/_src/optimizer/meta/rmsprop.py b/torchopt/_src/optimizer/meta/rmsprop.py index e3526717..20183236 100644 --- a/torchopt/_src/optimizer/meta/rmsprop.py +++ b/torchopt/_src/optimizer/meta/rmsprop.py @@ -46,14 +46,15 @@ def __init__( """The :meth:`init` function. Args: - net: A network whose parameters should be optimized. + net: (nn.Module) + A network whose parameters should be optimized. lr: (default: :const:`1e-2`) This is a fixed global scaling factor. alpha: (default: :const:`0.99`) Smoothing constant, the decay used to track the magnitude of previous gradients. eps: (default: :const:`1e-8`) A small numerical constant to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`): + weight_decay: (default: :const:`0.0`) Weight decay, add L2 penalty to parameters. momentum: (default: :const:`0.0`) The decay rate used by the momentum term. The momentum is not used when it is set to diff --git a/torchopt/_src/optimizer/meta/sgd.py b/torchopt/_src/optimizer/meta/sgd.py index 534f847b..b8ae5d24 100644 --- a/torchopt/_src/optimizer/meta/sgd.py +++ b/torchopt/_src/optimizer/meta/sgd.py @@ -43,12 +43,13 @@ def __init__( """The :meth:`init` function. Args: - net: A network whose parameters should be optimized. + net: (nn.Module) + A network whose parameters should be optimized. lr: This is a fixed global scaling factor. momentum: (default: :const:`0.0`) The decay rate used by the momentum term. The momentum is not used when it is set to :const:`0.0`. - weight_decay: (default: :const:`0.0`): + weight_decay: (default: :const:`0.0`) Weight decay, add L2 penalty to parameters. dampening: (default: :const:`0.0`) Dampening for momentum. 
diff --git a/torchopt/_src/optimizer/rmsprop.py b/torchopt/_src/optimizer/rmsprop.py index 85a9ac64..3b8634f3 100644 --- a/torchopt/_src/optimizer/rmsprop.py +++ b/torchopt/_src/optimizer/rmsprop.py @@ -56,7 +56,7 @@ def __init__( Smoothing constant, the decay used to track the magnitude of previous gradients. eps: (default: :const:`1e-8`) A small numerical constant to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`): + weight_decay: (default: :const:`0.0`) Weight decay, add L2 penalty to parameters. momentum: (default: :const:`0.0`) The decay rate used by the momentum term. The momentum is not used when it is set to diff --git a/torchopt/_src/optimizer/sgd.py b/torchopt/_src/optimizer/sgd.py index fc2735d6..a7f415f6 100644 --- a/torchopt/_src/optimizer/sgd.py +++ b/torchopt/_src/optimizer/sgd.py @@ -50,7 +50,7 @@ def __init__( momentum: (default: :const:`0.0`) The decay rate used by the momentum term. The momentum is not used when it is set to :const:`0.0`. - weight_decay: (default: :const:`0.0`): + weight_decay: (default: :const:`0.0`) Weight decay, add L2 penalty to parameters. dampening: (default: :const:`0.0`) Dampening for momentum. diff --git a/torchopt/_src/transform.py b/torchopt/_src/transform.py index 7e737671..15bf11ed 100644 --- a/torchopt/_src/transform.py +++ b/torchopt/_src/transform.py @@ -32,7 +32,7 @@ # pylint: disable=invalid-name -from typing import Any, Callable, List, NamedTuple, Sequence +from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Union import torch @@ -111,7 +111,7 @@ def _scale(step_size: float, *, already_flattened: bool = False) -> base.Gradien else: tree_map = pytree.tree_map - def init_fn(_): + def init_fn(params): # pylint: disable=unused-argument return ScaleState() def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument @@ -725,3 +725,173 @@ def f(g, m, n): return updates, ScaleByRStdDevState(mu=mu, nu=nu) return base.GradientTransformation(init_fn, update_fn) + + +class MaskedState(NamedTuple): + """Maintains inner transform state for masked transformations.""" + + inner_state: Any + + +class MaskedNode(NamedTuple): + """A node used to mask out unspecified parts of a tree. + + This node is ignored when mapping functions across the tree e.g. using + :func:`pytree.tree_map` since it is a container without children. It can + therefore be used to mask out parts of a tree. + """ + + +def masked( + inner: base.GradientTransformation, + mask: Union[Any, Callable[[base.Params], Any]], +) -> base.GradientTransformation: + """Mask updates so only some are transformed, the rest are passed through. + + For example, it is common to skip weight decay for BatchNorm scale and all + bias parameters. In many networks, these are the only parameters with only + one dimension. So, you may create a mask function to mask these out as + follows:: + mask_fn = lambda p: pytree.tree_map(lambda x: x.ndim != 1, p) + weight_decay = torchopt.masked(torchopt.add_decayed_weights(0.001), mask_fn) + You may alternatively create the mask pytree upfront:: + mask = pytree.tree_map(lambda x: x.ndim != 1, params) + weight_decay = torchopt.masked(torchopt.add_decayed_weights(0.001), mask) + For the ``inner`` transform, state will only be stored for the parameters that + have a mask value of ``True``. + + Args: + inner: Inner transformation to mask. + mask: a PyTree with same structure as (or a prefix of) the params PyTree, or + a Callable that returns such a pytree given the params/updates. 
The leaves + should be booleans, ``True`` for leaves/subtrees you want to apply the + transformation to, and ``False`` for those you want to skip. The mask must + be static for the gradient transformation to be jit-compilable. + + Returns: + New GradientTransformation wrapping ``inner``. + """ + return _masked( + inner=inner, + mask=mask, + already_flattened=False, + ) + + +def _masked( + inner: base.GradientTransformation, + mask: Union[Any, Callable[[base.Params], Any]], + *, + already_flattened: bool = False, +) -> base.GradientTransformation: + + if already_flattened: + tree_map = map_flattened + else: + tree_map = pytree.tree_map + + def tree_mask(params, mask_tree): + return tree_map(lambda p, m: p if m else MaskedNode(), params, mask_tree) + + def init_fn(params): + mask_tree = mask(params) if callable(mask) else mask + masked_params = tree_mask(params, mask_tree) + return MaskedState(inner_state=inner.init(masked_params)) + + def update_fn(updates, state, params=None, inplace=True): # pylint: disable=unused-argument + mask_tree = mask(updates) if callable(mask) else mask + masked_updates = tree_mask(updates, mask_tree) + masked_params = None if params is None else tree_mask(params, mask_tree) + + new_masked_updates, new_inner_state = inner.update( + masked_updates, state.inner_state, params=masked_params, inplace=inplace + ) + + new_updates = tree_map( + lambda new_u, old_u, m: new_u if m else old_u, new_masked_updates, updates, mask_tree + ) + return new_updates, MaskedState(inner_state=new_inner_state) + + return base.GradientTransformation(init_fn, update_fn) + + +AddDecayedWeightsState = base.EmptyState + + +# mypy: ignore-errors +def add_decayed_weights( + weight_decay: float = 0.0, + mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, +) -> base.GradientTransformation: + """Add parameter scaled by `weight_decay`. + + Args: + weight_decay: a scalar weight decay rate. + mask: a tree with same structure as (or a prefix of) the params PyTree, + or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the transformation to, and `False` for those you want to skip. + + Returns: + An (init_fn, update_fn) tuple. + """ + return _add_decayed_weights( + weight_decay=weight_decay, + mask=mask, + already_flattened=False, + ) + + +# mypy: ignore-errors +def _add_decayed_weights( + weight_decay: float = 0.0, + mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, + *, + already_flattened: bool = False, +) -> base.GradientTransformation: + if not 0.0 <= weight_decay: # pylint: disable=unneeded-not + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + + if weight_decay == 0.0 and mask is None: + return base.identity() + + if already_flattened: + tree_map = map_flattened + else: + tree_map = pytree.tree_map + + def init_fn(params): # pylint: disable=unused-argument + return AddDecayedWeightsState() + + def update_fn(updates, state, params=None, inplace=True): # pylint: disable=unused-argument + assert params is not None, ( + 'Parameters are required for weight decay. ' + 'Call `update(updates, state, params=params)` instead.' 
+ ) + + if inplace: + + def f(g, p): + if g is not None: + if g.requires_grad: + return g.add_(p, alpha=weight_decay) + return g.add_(p.data, alpha=weight_decay) + return None + + else: + + def f(g, p): + return g.add(p, alpha=weight_decay) if g is not None else None + + updates = tree_map(f, updates, params) + return updates, state + + # If mask is not `None`, apply mask to the gradient transformation. + # E.g. it is common to skip weight decay on bias units and batch stats. + if mask is not None: + return _masked( + inner=base.GradientTransformation(init_fn, update_fn), + mask=mask, + already_flattened=already_flattened, + ) + return base.GradientTransformation(init_fn, update_fn)
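
The functional `torchopt.adamw` added in `torchopt/_src/alias.py` is driven the same way as the existing `torchopt.adam` transformation. Below is a minimal sketch mirroring the added `tests/test_alias.py::test_adamw`; the toy model and batch (`model`, `xs`, `ys`) are hypothetical placeholders, and it assumes a torch/functorch version where `functorch.make_functional_with_buffers` is available, as used in the tests.

```python
import functorch
import torch
import torch.nn.functional as F

import torchopt

# Hypothetical toy model and batch.
model = torch.nn.Linear(8, 2)
xs, ys = torch.randn(4, 8), torch.randint(0, 2, (4,))

# Turn the module into a pure function plus parameter/buffer pytrees.
fmodel, params, buffers = functorch.make_functional_with_buffers(model)

# Build the functional AdamW gradient transformation and its state.
optim = torchopt.adamw(lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2)
optim_state = optim.init(params)

pred = fmodel(params, buffers, xs)
loss = F.cross_entropy(pred, ys)
grads = torch.autograd.grad(loss, params)

# Decoupled weight decay needs the current parameter values, hence `params=params`.
updates, optim_state = optim.update(grads, optim_state, params=params, inplace=False)
params = torchopt.apply_updates(params, updates, inplace=False)
```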
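The object-oriented wrappers follow the same pattern as the existing `Adam`/`MetaAdam` classes: `torchopt.AdamW` mirrors the `torch.optim.AdamW` loop, while `torchopt.MetaAdamW` keeps the update differentiable for meta-gradients. A minimal sketch follows; `net`, `meta_net`, `xs`, and `ys` are hypothetical placeholders.

```python
import torch
import torch.nn.functional as F

import torchopt

xs, ys = torch.randn(4, 8), torch.randint(0, 2, (4,))

# Classic optimizer: the familiar zero_grad / backward / step loop,
# as exercised by the added tests/test_optimizer.py::test_AdamW.
net = torch.nn.Linear(8, 2)
optim = torchopt.AdamW(net.parameters(), lr=1e-3, weight_decay=1e-2)
loss = F.cross_entropy(net(xs), ys)
optim.zero_grad()
loss.backward()
optim.step()

# Differentiable meta-optimizer: step() takes the loss directly and updates
# the module's parameters out-of-place so the autograd graph is preserved.
meta_net = torch.nn.Linear(8, 2)
meta_optim = torchopt.MetaAdamW(meta_net, lr=1e-3)
inner_loss = F.cross_entropy(meta_net(xs), ys)
meta_optim.step(inner_loss)
```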
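The `mask` argument threads through `adamw` into `_add_decayed_weights`/`_masked`: leaves that are `True` receive weight decay, leaves that are `False` are passed through untouched. A common use, per the `masked` docstring, is skipping decay for one-dimensional parameters (biases, normalization scales). The sketch below assumes the documented contract of a boolean pytree with the same structure as the params (a flat tuple when using functorch); the toy `net` is a hypothetical placeholder.

```python
import functorch
import torch

import torchopt

net = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 2))
fmodel, params, buffers = functorch.make_functional_with_buffers(net)

# True -> apply weight decay to this parameter; False -> skip it (biases here).
mask = tuple(p.ndim != 1 for p in params)

optim = torchopt.adamw(lr=1e-3, weight_decay=1e-2, mask=mask)
optim_state = optim.init(params)
```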
