diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index eb6753cc..0a6c4d6e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -132,7 +132,7 @@ jobs: run: python .github/workflows/set_cibw_build.py - name: Build wheels - uses: pypa/cibuildwheel@v2.12.3 + uses: pypa/cibuildwheel@v2.15 env: CIBW_BUILD: ${{ env.CIBW_BUILD }} with: @@ -182,7 +182,7 @@ jobs: run: python .github/workflows/set_cibw_build.py - name: Build wheels - uses: pypa/cibuildwheel@v2.12.3 + uses: pypa/cibuildwheel@v2.15 env: CIBW_BUILD: ${{ env.CIBW_BUILD }} with: diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 19c0cf5b..b338b149 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -91,6 +91,11 @@ jobs: run: | make cpplint + - name: clang-tidy + run: | + sudo apt-get update && sudo apt-get install libomp-dev --yes + make clang-tidy + - name: clang-format run: | ( @@ -101,11 +106,6 @@ jobs: sudo apt-get update && sudo apt-get install clang-format --yes make clang-format - - name: clang-tidy - run: | - sudo apt-get update && sudo apt-get install libomp-dev --yes - make clang-tidy - - name: addlicense run: | make addlicense diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a16ff100..e5c37d40 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,6 +5,7 @@ ci: autofix_prs: true autofix_commit_msg: "fix: [pre-commit.ci] auto fixes [...]" autoupdate_commit_msg: "chore(pre-commit): [pre-commit.ci] autoupdate" + autoupdate_schedule: monthly default_stages: [commit, push, manual] repos: - repo: https://github.com/pre-commit/pre-commit-hooks @@ -25,11 +26,11 @@ repos: - id: debug-statements - id: double-quote-string-fixer - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v16.0.3 + rev: v16.0.6 hooks: - id: clang-format - - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.265 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.0.284 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] @@ -38,11 +39,11 @@ repos: hooks: - id: isort - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 23.7.0 hooks: - id: black-jupyter - repo: https://github.com/asottile/pyupgrade - rev: v3.4.0 + rev: v3.10.1 hooks: - id: pyupgrade args: [--py38-plus] # sync with requires-python @@ -51,7 +52,7 @@ repos: ^examples/ ) - repo: https://github.com/pycqa/flake8 - rev: 6.0.0 + rev: 6.1.0 hooks: - id: flake8 additional_dependencies: @@ -67,7 +68,7 @@ repos: ^docs/source/conf.py$ ) - repo: https://github.com/codespell-project/codespell - rev: v2.2.4 + rev: v2.2.5 hooks: - id: codespell additional_dependencies: [".[toml]"] diff --git a/CHANGELOG.md b/CHANGELOG.md index cb158207..24e4eea4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ------ +## [0.7.2] - 2023-08-18 + +### Added + +- Implement `Adadelta`, `RAdam`, `Adamax` optimizer by [@JieRen98](https://github.com/JieRen98) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#171](https://github.com/metaopt/torchopt/pull/171). 
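A minimal usage sketch of the newly added classic optimizers, mirroring the `zero_grad()`/`step()` flow exercised in `tests/test_optim.py` below (the toy model and data here are illustrative only):

```python
import torch
import torch.nn.functional as F
import torchopt

model = torch.nn.Linear(4, 2)
# The new classes follow the same protocol as the existing torchopt.Adam.
optim = torchopt.AdaDelta(model.parameters(), lr=1.0, rho=0.9)  # or torchopt.RAdam / torchopt.AdaMax

xs, ys = torch.randn(8, 4), torch.randint(0, 2, (8,))
loss = F.cross_entropy(model(xs), ys)

optim.zero_grad()
loss.backward()
optim.step()
```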
+ +------ + ## [0.7.1] - 2023-05-12 ### Added @@ -187,7 +195,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ------ -[Unreleased]: https://github.com/metaopt/torchopt/compare/v0.7.1...HEAD +[Unreleased]: https://github.com/metaopt/torchopt/compare/v0.7.2...HEAD +[0.7.2]: https://github.com/metaopt/torchopt/compare/v0.7.1...v0.7.2 [0.7.1]: https://github.com/metaopt/torchopt/compare/v0.7.0...v0.7.1 [0.7.0]: https://github.com/metaopt/torchopt/compare/v0.6.0...v0.7.0 [0.6.0]: https://github.com/metaopt/torchopt/compare/v0.5.0...v0.6.0 diff --git a/CITATION.cff b/CITATION.cff index e7cf54cb..965b6a7f 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -32,7 +32,7 @@ authors: family-names: Yang affiliation: Peking University email: yaodong.yang@pku.edu.cn -version: 0.7.1 -date-released: "2023-05-12" +version: 0.7.2 +date-released: "2023-08-18" license: Apache-2.0 repository-code: "https://github.com/metaopt/torchopt" diff --git a/Makefile b/Makefile index 906d8a64..0f7dd74e 100644 --- a/Makefile +++ b/Makefile @@ -112,9 +112,9 @@ addlicense-install: go-install # Tests pytest: test-install - cd tests && $(PYTHON) -c 'import $(PROJECT_NAME)' && \ + cd tests && $(PYTHON) -c 'import $(PROJECT_PATH)' && \ $(PYTHON) -m pytest --verbose --color=yes --durations=0 \ - --cov="$(PROJECT_NAME)" --cov-config=.coveragerc --cov-report=xml --cov-report=term-missing \ + --cov="$(PROJECT_PATH)" --cov-config=.coveragerc --cov-report=xml --cov-report=term-missing \ $(PYTESTOPTS) . test: pytest @@ -128,7 +128,7 @@ flake8: flake8-install $(PYTHON) -m flake8 --count --show-source --statistics py-format: py-format-install - $(PYTHON) -m isort --project $(PROJECT_NAME) --check $(PYTHON_FILES) && \ + $(PYTHON) -m isort --project $(PROJECT_PATH) --check $(PYTHON_FILES) && \ $(PYTHON) -m black --check $(PYTHON_FILES) tutorials ruff: ruff-install @@ -189,7 +189,7 @@ clean-docs: lint: ruff flake8 py-format mypy pylint clang-format clang-tidy cpplint addlicense docstyle spelling format: py-format-install ruff-install clang-format-install addlicense-install - $(PYTHON) -m isort --project $(PROJECT_NAME) $(PYTHON_FILES) + $(PYTHON) -m isort --project $(PROJECT_PATH) $(PYTHON_FILES) $(PYTHON) -m black $(PYTHON_FILES) tutorials $(PYTHON) -m ruff check . --fix --exit-zero $(CLANG_FORMAT) -style=file -i $(CXX_FILES) $(CUDA_FILES) diff --git a/README.md b/README.md index 5bc474fc..ee1905ab 100644 --- a/README.md +++ b/README.md @@ -136,7 +136,7 @@ On top of the same optimization function as `torch.optim`, an important benefit This is particularly helpful when the algorithm requires differentiation through optimization updates (such as meta-learning practices). We take as the inputs the gradients and optimizer states, and use non-in-place operators to compute and output the updates. The processes can be automatically implemented, with the only need from users being to pass the argument `inplace=False` to the functions. -Check out the section [Explicit Gradient](#explicit-gradient-eg) (EG)](#explicit-gradient-eg) functional API for example. +Check out the section [Explicit Gradient (EG)](#explicit-gradient-eg) functional API for example. 
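A minimal sketch of that functional flow with one of the optimizers added in this PR, mirroring `tests/test_alias.py::test_adadelta` below (the `functorch.make_functional` toy model is illustrative only):

```python
import functorch
import torch
import torch.nn.functional as F
import torchopt

model = torch.nn.Linear(4, 2)
fmodel, params = functorch.make_functional(model)
optim = torchopt.adadelta(lr=1.0, rho=0.9)
opt_state = optim.init(params)

xs, ys = torch.randn(8, 4), torch.randint(0, 2, (8,))
loss = F.cross_entropy(fmodel(params, xs), ys)
grads = torch.autograd.grad(loss, params)

# Non-in-place update: gradients can flow through the update if needed.
updates, opt_state = optim.update(grads, opt_state, params=params, inplace=False)
params = torchopt.apply_updates(params, updates, inplace=False)
```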
-------------------------------------------------------------------------------- @@ -215,7 +215,7 @@ Refer to the example [iMAML](https://github.com/waterhorse1/torchopt/tree/readme #### Functional API -For the implicit gradient, users need to define the stationary condition and TorchOpt provides the decorator to wrap the solve function for enabling implicit gradient computation. +For the implicit gradient, similar to [JAXopt](https://jaxopt.github.io/stable/implicit_diff.html), users need to define the stationary condition and TorchOpt provides the decorator to wrap the solve function for enabling implicit gradient computation. ```python # The stationary condition for the inner-loop diff --git a/conda-recipe-minimal-cpu.yaml b/conda-recipe-minimal-cpu.yaml new file mode 100644 index 00000000..0404f10c --- /dev/null +++ b/conda-recipe-minimal-cpu.yaml @@ -0,0 +1,49 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Create virtual environment with command: +# +# $ conda env create --file conda-recipe-minimal-cpu.yaml +# + +name: torchopt + +channels: + - pytorch + - defaults + - conda-forge + +dependencies: + - python = 3.10 + - pip + + # Learning + - pytorch::pytorch >= 1.13 # sync with project.dependencies + - pytorch::torchvision + - pytorch::pytorch-mutex = *=*cpu* + - pip: + - torchviz + + # Build toolchain + - cmake >= 3.11 + - make + - cxx-compiler + - pybind11 >= 2.10.1 + + # Misc + - optree >= 0.4.1 + - typing-extensions >= 4.0.0 + - numpy + - python-graphviz diff --git a/docs/source/api/api.rst b/docs/source/api/api.rst index d00e2333..0112e877 100644 --- a/docs/source/api/api.rst +++ b/docs/source/api/api.rst @@ -30,9 +30,12 @@ Functional Optimizers .. autosummary:: FuncOptimizer + adadelta adagrad adam adamw + adamax + radam rmsprop sgd @@ -42,6 +45,11 @@ Wrapper for Function Optimizer .. autoclass:: FuncOptimizer :members: +Functional AdaDelta Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: adadelta + Functional AdaGrad Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -57,6 +65,16 @@ Functional AdamW Optimizer .. autofunction:: adamw +Functional AdaMax Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: adamax + +Functional RAdam Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: radam + Functional RMSProp Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -76,12 +94,23 @@ Classic Optimizers .. autosummary:: + AdaDelta + Adadelta AdaGrad + Adagrad Adam AdamW + AdaMax + Adamax + RAdam RMSProp SGD +Classic AdaDelta Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: AdaDelta + Classic AdaGrad Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -97,6 +126,16 @@ Classic AdamW Optimizer .. autoclass:: AdamW +Classic AdaMax Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: AdaMax + +Classic RAdam Optimizer +~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: RAdam + Classic RMSProp Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -116,12 +155,23 @@ Differentiable Meta-Optimizers .. autosummary:: + MetaAdaDelta + MetaAdadelta MetaAdaGrad + MetaAdagrad MetaAdam MetaAdamW + MetaAdaMax + MetaAdamax + MetaRAdam MetaRMSProp MetaSGD +Differentiable Meta-AdaDelta Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: MetaAdaDelta + Differentiable Meta-AdaGrad Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -137,6 +187,16 @@ Differentiable Meta-AdamW Optimizer .. autoclass:: MetaAdamW +Differentiable Meta-AdaMax Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: MetaAdaMax + +Differentiable Meta-RAdam Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: MetaRAdam + Differentiable Meta-RMSProp Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/explicit_diff/explicit_diff.rst b/docs/source/explicit_diff/explicit_diff.rst index f6b82826..9445adb8 100644 --- a/docs/source/explicit_diff/explicit_diff.rst +++ b/docs/source/explicit_diff/explicit_diff.rst @@ -53,9 +53,15 @@ For PyTorch-like API (e.g., ``step()``), we designed a base class :class:`torcho .. autosummary:: torchopt.MetaOptimizer + torchopt.MetaAdaDelta + torchopt.MetaAdadelta torchopt.MetaAdaGrad + torchopt.MetaAdagrad torchopt.MetaAdam torchopt.MetaAdamW + torchopt.MetaAdaMax + torchopt.MetaAdamax + torchopt.MetaRAdam torchopt.MetaRMSProp torchopt.MetaSGD diff --git a/docs/source/optimizer/optim.rst b/docs/source/optimizer/optim.rst index 54c8ef71..4f2e17f8 100644 --- a/docs/source/optimizer/optim.rst +++ b/docs/source/optimizer/optim.rst @@ -18,9 +18,12 @@ Currently, TorchOpt supports 4 functional optimizers: :func:`sgd`, :func:`adam`, .. autosummary:: torchopt.FuncOptimizer + torchopt.adadelta torchopt.adagrad torchopt.adam torchopt.adamw + torchopt.adamax + torchopt.radam torchopt.rmsprop torchopt.sgd @@ -85,9 +88,15 @@ We offer original PyTorch APIs (e.g., ``zero_grad()`` or ``step()``) for traditi ..
autosummary:: torchopt.Optimizer + torchopt.AdaDelta + torchopt.Adadelta torchopt.AdaGrad + torchopt.Adagrad torchopt.Adam torchopt.AdamW + torchopt.AdaMax + torchopt.Adamax + torchopt.RAdam torchopt.RMSProp torchopt.SGD diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt index 49fdbb69..6e0cca78 100644 --- a/docs/source/spelling_wordlist.txt +++ b/docs/source/spelling_wordlist.txt @@ -175,3 +175,10 @@ ctx Duchi invertible AdaGrad +Adadelta +Zeiler +radam +adamax +RAdam +AdaDelta +AdaMax diff --git a/setup.py b/setup.py index cce04c65..dc1103df 100644 --- a/setup.py +++ b/setup.py @@ -107,9 +107,7 @@ def build_extension(self, ext): ], } -TORCHOPT_NO_EXTENSIONS = ( - bool(os.getenv('TORCHOPT_NO_EXTENSIONS', '')) or WINDOWS or (MACOS and CIBUILDWHEEL) -) +TORCHOPT_NO_EXTENSIONS = bool(os.getenv('TORCHOPT_NO_EXTENSIONS', '')) or WINDOWS or MACOS if TORCHOPT_NO_EXTENSIONS: ext_kwargs.clear() diff --git a/tests/helpers.py b/tests/helpers.py index bedf0fb6..50451496 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -142,7 +142,7 @@ def get_model(): @torch.no_grad() def get_models( - device: torch.types.Device = None, + device: torch.types.Device | None = None, dtype: torch.dtype = torch.float32, ) -> tuple[nn.Module, nn.Module, nn.Module, data.DataLoader]: seed_everything(seed=42) @@ -200,7 +200,7 @@ def assert_model_all_close( def assert_all_close( actual: torch.Tensor, expected: torch.Tensor, - base: torch.Tensor = None, + base: torch.Tensor | None = None, rtol: float | None = None, atol: float | None = None, equal_nan: bool = False, diff --git a/tests/test_alias.py b/tests/test_alias.py index a0a78129..aef35b96 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -144,6 +144,63 @@ def test_sgd( _set_use_chain_flat(True) +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + rho=[0.9, 0.95], + eps=[1e-8], + inplace=[True, False], + weight_decay=[0.0, 1e-2], + use_chain_flat=[True, False], +) +def test_adadelta( + dtype: torch.dtype, + lr: float, + rho: float, + eps: float, + inplace: bool, + weight_decay: float, + use_chain_flat: bool, +) -> None: + _set_use_chain_flat(use_chain_flat) + + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + fmodel, params, buffers = functorch.make_functional_with_buffers(model) + optim = torchopt.adadelta( + lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + ) + optim_state = optim.init(params) + optim_ref = torch.optim.Adadelta( + model_ref.parameters(), + lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = fmodel(params, buffers, xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + grads = torch.autograd.grad(loss, params, allow_unused=True) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + params = torchopt.apply_updates(params, updates, inplace=inplace) + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) + + @helpers.parametrize( dtype=[torch.float64], lr=[1e-2, 1e-3, 1e-4], @@ -210,6 +267,120 @@ def test_adam( _set_use_chain_flat(True) +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + betas=[(0.9, 0.999), (0.95, 0.9995)], + eps=[1e-8], + inplace=[True, False], + weight_decay=[0.0, 1e-2], + use_chain_flat=[True, False], +) +def 
test_radam( + dtype: torch.dtype, + lr: float, + betas: tuple[float, float], + eps: float, + inplace: bool, + weight_decay: float, + use_chain_flat: bool, +) -> None: + _set_use_chain_flat(use_chain_flat) + + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + fmodel, params, buffers = functorch.make_functional_with_buffers(model) + optim = torchopt.radam( + lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + optim_state = optim.init(params) + optim_ref = torch.optim.RAdam( + model_ref.parameters(), + lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = fmodel(params, buffers, xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + grads = torch.autograd.grad(loss, params, allow_unused=True) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + params = torchopt.apply_updates(params, updates, inplace=inplace) + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) + + +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + betas=[(0.9, 0.999), (0.95, 0.9995)], + eps=[1e-8], + inplace=[True, False], + weight_decay=[0.0, 1e-2], + use_chain_flat=[True, False], +) +def test_adamax( + dtype: torch.dtype, + lr: float, + betas: tuple[float, float], + eps: float, + inplace: bool, + weight_decay: float, + use_chain_flat: bool, +) -> None: + _set_use_chain_flat(use_chain_flat) + + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + fmodel, params, buffers = functorch.make_functional_with_buffers(model) + optim = torchopt.adamax( + lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + optim_state = optim.init(params) + optim_ref = torch.optim.Adamax( + model_ref.parameters(), + lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = fmodel(params, buffers, xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + grads = torch.autograd.grad(loss, params, allow_unused=True) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + params = torchopt.apply_updates(params, updates, inplace=inplace) + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) + + @helpers.parametrize( dtype=[torch.float64], outer_lr=[1e-2, 1e-3, 1e-4], diff --git a/tests/test_implicit.py b/tests/test_implicit.py index db19f829..61623a17 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -181,8 +181,7 @@ def imaml_objective_jax(params, meta_params, x, y): regularization_loss = 0 for p1, p2 in zip(params.values(), meta_params.values()): regularization_loss += 0.5 * jnp.sum(jnp.square(p1 - p2)) - loss = loss + regularization_loss - return loss + return loss + regularization_loss @jaxopt.implicit_diff.custom_root( jax.grad(imaml_objective_jax, argnums=0), @@ -310,8 +309,7 @@ def imaml_objective_jax(params, meta_params, x, y): regularization_loss = 0 for p1, p2 in zip(params.values(), meta_params.values()): regularization_loss += 0.5 * jnp.sum(jnp.square(p1 - p2)) - loss = loss + regularization_loss - return loss + return loss + 
regularization_loss @jaxopt.implicit_diff.custom_root( jax.grad(imaml_objective_jax, argnums=0), @@ -398,8 +396,7 @@ def objective(self, x, y): regularization_loss = 0 for p1, p2 in zip(self.parameters(), self.meta_parameters()): regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2)) - loss = loss + regularization_loss - return loss + return loss + regularization_loss def solve(self, x, y): params = tuple(self.parameters()) @@ -425,8 +422,7 @@ def imaml_objective_jax(params, meta_params, x, y): regularization_loss = 0 for p1, p2 in zip(params.values(), meta_params.values()): regularization_loss += 0.5 * jnp.sum(jnp.square(p1 - p2)) - loss = loss + regularization_loss - return loss + return loss + regularization_loss @jaxopt.implicit_diff.custom_root(jax.grad(imaml_objective_jax, argnums=0), has_aux=True) def inner_solver_jax(params, meta_params, x, y): diff --git a/tests/test_import.py b/tests/test_import.py index 1b6dea38..f7523756 100644 --- a/tests/test_import.py +++ b/tests/test_import.py @@ -25,17 +25,24 @@ def test_accelerated_op_import() -> None: def test_alias_import() -> None: + torchopt.adadelta torchopt.adagrad torchopt.adam torchopt.adamw + torchopt.adamax + torchopt.radam torchopt.rmsprop torchopt.sgd + torchopt.alias.adadelta + torchopt.alias.adagrad torchopt.alias.adam torchopt.alias.adamw + torchopt.alias.adamax + torchopt.alias.radam torchopt.alias.rmsprop torchopt.alias.sgd - from torchopt import adagrad, adam, adamw, rmsprop, sgd - from torchopt.alias import adagrad, adam, adamw, rmsprop, sgd + from torchopt import adadelta, adagrad, adam, adamax, adamw, radam, rmsprop, sgd + from torchopt.alias import adadelta, adagrad, adam, adamax, adamw, radam, rmsprop, sgd def test_diff_import() -> None: @@ -108,25 +115,38 @@ def test_nn_import() -> None: def test_optim_import() -> None: torchopt.FuncOptimizer + torchopt.MetaAdaDelta + torchopt.MetaAdadelta torchopt.MetaAdaGrad torchopt.MetaAdagrad torchopt.MetaAdam torchopt.MetaAdamW + torchopt.MetaAdaMax + torchopt.MetaAdamax + torchopt.MetaRAdam torchopt.MetaRMSProp torchopt.MetaRMSprop torchopt.MetaSGD + torchopt.AdaDelta + torchopt.Adadelta torchopt.AdaGrad torchopt.Adagrad torchopt.Adam torchopt.AdamW + torchopt.AdaMax + torchopt.Adamax torchopt.Optimizer torchopt.RMSProp torchopt.RMSprop torchopt.SGD + torchopt.optim.meta.MetaAdaDelta + torchopt.optim.meta.MetaAdadelta torchopt.optim.meta.MetaAdaGrad torchopt.optim.meta.MetaAdagrad torchopt.optim.meta.MetaAdam torchopt.optim.meta.MetaAdamW + torchopt.optim.meta.MetaAdaMax + torchopt.optim.meta.MetaAdamax torchopt.optim.meta.MetaRMSProp torchopt.optim.meta.MetaRMSprop torchopt.optim.meta.MetaSGD @@ -139,14 +159,22 @@ def test_optim_import() -> None: torchopt.optim.func.FuncOptimizer from torchopt import ( SGD, + AdaDelta, + Adadelta, AdaGrad, Adagrad, Adam, + AdaMax, + Adamax, AdamW, FuncOptimizer, + MetaAdaDelta, + MetaAdadelta, MetaAdaGrad, MetaAdagrad, MetaAdam, + MetaAdaMax, + MetaAdamax, MetaAdamW, MetaOptimizer, MetaRMSprop, @@ -158,11 +186,16 @@ def test_optim_import() -> None: from torchopt.optim import SGD, Adam, AdamW, FuncOptimizer, Optimizer, RMSProp from torchopt.optim.func import FuncOptimizer from torchopt.optim.meta import ( + MetaAdaDelta, + MetaAdadelta, MetaAdaGrad, MetaAdagrad, MetaAdam, + MetaAdaMax, + MetaAdamax, MetaAdamW, MetaOptimizer, + MetaRAdam, MetaRMSProp, MetaRMSprop, MetaSGD, diff --git a/tests/test_optim.py b/tests/test_optim.py index 6ec81918..dc3941d9 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -144,6 +144,153 @@ def 
test_Adam( helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + rho=[0.9, 0.95], + eps=[1e-8], + weight_decay=[0.0, 1e-2], +) +def test_Adadelta( + dtype: torch.dtype, + lr: float, + rho: float, + eps: float, + weight_decay: float, +) -> None: + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + optim = torchopt.Adadelta( + model.parameters(), + lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + ) + optim_ref = torch.optim.Adadelta( + model_ref.parameters(), + lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = model(xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + optim.zero_grad() + loss.backward() + optim.step() + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) + + +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + betas=[(0.9, 0.999), (0.95, 0.9995)], + eps=[1e-8], + weight_decay=[0.0, 1e-2], +) +def test_RAdam( + dtype: torch.dtype, + lr: float, + betas: tuple[float, float], + eps: float, + weight_decay: float, +) -> None: + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + optim = torchopt.RAdam( + model.parameters(), + lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + optim_ref = torch.optim.RAdam( + model_ref.parameters(), + lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = model(xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + optim.zero_grad() + loss.backward() + optim.step() + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) + + +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + betas=[(0.9, 0.999), (0.95, 0.9995)], + eps=[1e-8], + weight_decay=[0.0, 1e-2], +) +def test_Adamax( + dtype: torch.dtype, + lr: float, + betas: tuple[float, float], + eps: float, + weight_decay: float, +) -> None: + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + optim = torchopt.Adamax( + model.parameters(), + lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + optim_ref = torch.optim.Adamax( + model_ref.parameters(), + lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = model(xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + optim.zero_grad() + loss.backward() + optim.step() + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) + + @helpers.parametrize( dtype=[torch.float64], lr=[1e-2, 1e-3, 1e-4], diff --git a/torchopt/__init__.py b/torchopt/__init__.py index a8c9fa1d..a089f3dc 100644 --- a/torchopt/__init__.py +++ b/torchopt/__init__.py @@ -33,18 +33,37 @@ visual, ) from torchopt.accelerated_op import is_available as accelerated_op_available -from torchopt.alias import adagrad, adam, adamw, rmsprop, sgd +from torchopt.alias import adadelta, adagrad, adam, adamax, adamw, radam, rmsprop, sgd from torchopt.clip import clip_grad_norm from 
torchopt.combine import chain from torchopt.hook import register_hook -from torchopt.optim import SGD, AdaGrad, Adagrad, Adam, AdamW, Optimizer, RMSProp, RMSprop +from torchopt.optim import ( + SGD, + AdaDelta, + Adadelta, + AdaGrad, + Adagrad, + Adam, + AdaMax, + Adamax, + AdamW, + Optimizer, + RAdam, + RMSProp, + RMSprop, +) from torchopt.optim.func import FuncOptimizer from torchopt.optim.meta import ( + MetaAdaDelta, + MetaAdadelta, MetaAdaGrad, MetaAdagrad, MetaAdam, + MetaAdaMax, + MetaAdamax, MetaAdamW, MetaOptimizer, + MetaRAdam, MetaRMSProp, MetaRMSprop, MetaSGD, @@ -64,6 +83,9 @@ __all__ = [ 'accelerated_op_available', 'adam', + 'adamax', + 'adadelta', + 'radam', 'adamw', 'adagrad', 'rmsprop', @@ -75,6 +97,11 @@ 'Optimizer', 'SGD', 'Adam', + 'AdaMax', + 'Adamax', + 'AdaDelta', + 'Adadelta', + 'RAdam', 'AdamW', 'AdaGrad', 'Adagrad', @@ -83,6 +110,11 @@ 'MetaOptimizer', 'MetaSGD', 'MetaAdam', + 'MetaAdaMax', + 'MetaAdamax', + 'MetaAdaDelta', + 'MetaAdadelta', + 'MetaRAdam', 'MetaAdamW', 'MetaAdaGrad', 'MetaAdagrad', diff --git a/torchopt/alias/__init__.py b/torchopt/alias/__init__.py index ae7dd2b5..3ea721c4 100644 --- a/torchopt/alias/__init__.py +++ b/torchopt/alias/__init__.py @@ -31,11 +31,14 @@ # ============================================================================== r"""The aliases of preset :class:`GradientTransformation`\s for optimizers.""" +from torchopt.alias.adadelta import adadelta from torchopt.alias.adagrad import adagrad from torchopt.alias.adam import adam +from torchopt.alias.adamax import adamax from torchopt.alias.adamw import adamw +from torchopt.alias.radam import radam from torchopt.alias.rmsprop import rmsprop from torchopt.alias.sgd import sgd -__all__ = ['adagrad', 'adam', 'adamw', 'rmsprop', 'sgd'] +__all__ = ['adagrad', 'radam', 'adam', 'adamax', 'adadelta', 'adamw', 'rmsprop', 'sgd'] diff --git a/torchopt/alias/adadelta.py b/torchopt/alias/adadelta.py new file mode 100644 index 00000000..2e3640f2 --- /dev/null +++ b/torchopt/alias/adadelta.py @@ -0,0 +1,98 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset :class:`GradientTransformation` for the Adadelta optimizer.""" + +from __future__ import annotations + +from torchopt.alias.utils import ( + _get_use_chain_flat, + flip_sign_and_add_weight_decay, + scale_by_neg_lr, +) +from torchopt.combine import chain +from torchopt.transform import scale_by_adadelta +from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +__all__ = ['adadelta'] + + +# pylint: disable-next=too-many-arguments +def adadelta( + lr: ScalarOrSchedule = 1e-3, + rho: float = 0.9, + eps: float = 1e-6, + weight_decay: float = 0.0, + *, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """Create a functional version of the AdaDelta optimizer. + + Adadelta is a per-dimension learning rate method for gradient descent. 
+ + References: + - Zeiler, 2012: https://arxiv.org/abs/1212.5701 + + Args: + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + rho (float, optional): Coefficients used for computing running averages of gradient and its square. + (default: :const:`0.9`) + eps (float, optional): A small constant applied to the square root (as in the Adadelta paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with + flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms. + (default: :data:`False`) + + Returns: + The corresponding :class:`GradientTransformation` instance. + + See Also: + The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. + """ + # pylint: disable=unneeded-not + if not (callable(lr) or lr >= 0.0): # pragma: no cover + raise ValueError(f'Invalid learning rate: {lr}') + if not 0 <= rho <= 1: # pragma: no cover + raise ValueError(f'Invalid rho value: {rho}') + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not weight_decay >= 0.0: # pragma: no cover + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + # pylint: enable=unneeded-not + + chain_fn = chain + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay + adadelta_scaler_fn = scale_by_adadelta + scale_by_neg_lr_fn = scale_by_neg_lr + + if _get_use_chain_flat(): # default behavior + chain_fn = chain_fn.flat # type: ignore[attr-defined] + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined] + adadelta_scaler_fn = adadelta_scaler_fn.flat # type: ignore[attr-defined] + scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined] + + return chain_fn( + flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=False), + adadelta_scaler_fn( + rho=rho, + eps=eps, + moment_requires_grad=moment_requires_grad, + ), + scale_by_neg_lr_fn(lr), + ) diff --git a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py index 25910abd..3f983c38 100644 --- a/torchopt/alias/adagrad.py +++ b/torchopt/alias/adagrad.py @@ -91,7 +91,7 @@ def adagrad( *, maximize: bool = False, ) -> GradientTransformation: - """The functional AdaGrad optimizer. + """Create a functional version of the AdaGrad optimizer. AdaGrad is an algorithm for gradient based optimization that anneals the learning rate for each parameter during the course of training. diff --git a/torchopt/alias/adamax.py b/torchopt/alias/adamax.py new file mode 100644 index 00000000..ffa19e37 --- /dev/null +++ b/torchopt/alias/adamax.py @@ -0,0 +1,100 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Preset :class:`GradientTransformation` for the Adamax optimizer.""" + +from __future__ import annotations + +from torchopt.alias.utils import ( + _get_use_chain_flat, + flip_sign_and_add_weight_decay, + scale_by_neg_lr, +) +from torchopt.combine import chain +from torchopt.transform import scale_by_adamax +from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +__all__ = ['adamax'] + + +# pylint: disable-next=too-many-arguments +def adamax( + lr: ScalarOrSchedule = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-6, + weight_decay: float = 0.0, + *, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """Create a functional version of the AdaMax optimizer. + + References: + - Kingma et al., 2014: https://arxiv.org/abs/1412.6980 + + Args: + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to the square root (as in the RAdam paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with + flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms. + (default: :data:`False`) + + Returns: + The corresponding :class:`GradientTransformation` instance. + + See Also: + The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. + """ + b1, b2 = betas # pylint: disable=invalid-name + # pylint: disable=unneeded-not + if not (callable(lr) or lr >= 0.0): # pragma: no cover + raise ValueError(f'Invalid learning rate: {lr}') + if not 0 <= b1 <= 1: # pragma: no cover + raise ValueError(f'Invalid rho value: {b1}') + if not 0 <= b2 <= 1: # pragma: no cover + raise ValueError(f'Invalid rho value: {b2}') + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not weight_decay >= 0.0: # pragma: no cover + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + # pylint: enable=unneeded-not + + chain_fn = chain + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay + adamax_scaler_fn = scale_by_adamax + scale_by_neg_lr_fn = scale_by_neg_lr + + if _get_use_chain_flat(): # default behavior + chain_fn = chain_fn.flat # type: ignore[attr-defined] + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined] + adamax_scaler_fn = adamax_scaler_fn.flat # type: ignore[attr-defined] + scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined] + + return chain_fn( + flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=False), + adamax_scaler_fn( + b1=b1, + b2=b2, + eps=eps, + moment_requires_grad=moment_requires_grad, + ), + scale_by_neg_lr_fn(lr), + ) diff --git a/torchopt/alias/radam.py b/torchopt/alias/radam.py new file mode 100644 index 00000000..230c1151 --- /dev/null +++ b/torchopt/alias/radam.py @@ -0,0 +1,102 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset :class:`GradientTransformation` for the RAdam optimizer.""" + +from __future__ import annotations + +from torchopt.alias.utils import ( + _get_use_chain_flat, + flip_sign_and_add_weight_decay, + scale_by_neg_lr, +) +from torchopt.combine import chain +from torchopt.transform import scale_by_radam +from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +__all__ = ['radam'] + + +# pylint: disable-next=too-many-arguments +def radam( + lr: ScalarOrSchedule = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-6, + weight_decay: float = 0.0, + *, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """Create a functional version of the RAdam optimizer. + + RAdam is a variance of the adaptive learning rate rectified optimizer. + + References: + - Liu, 2019: https://arxiv.org/abs/1908.03265 + + Args: + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to the square root (as in the RAdam paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with + flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms. + (default: :data:`False`) + + Returns: + The corresponding :class:`GradientTransformation` instance. + + See Also: + The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. 
+ """ + b1, b2 = betas # pylint: disable=invalid-name + # pylint: disable=unneeded-not + if not (callable(lr) or lr >= 0.0): # pragma: no cover + raise ValueError(f'Invalid learning rate: {lr}') + if not 0 <= b1 <= 1: # pragma: no cover + raise ValueError(f'Invalid rho value: {b1}') + if not 0 <= b2 <= 1: # pragma: no cover + raise ValueError(f'Invalid rho value: {b2}') + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not weight_decay >= 0.0: # pragma: no cover + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + # pylint: enable=unneeded-not + + chain_fn = chain + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay + radam_scaler_fn = scale_by_radam + scale_by_neg_lr_fn = scale_by_neg_lr + + if _get_use_chain_flat(): # default behavior + chain_fn = chain_fn.flat # type: ignore[attr-defined] + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined] + radam_scaler_fn = radam_scaler_fn.flat # type: ignore[attr-defined] + scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined] + + return chain_fn( + flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=False), + radam_scaler_fn( + b1=b1, + b2=b2, + eps=eps, + moment_requires_grad=moment_requires_grad, + ), + scale_by_neg_lr_fn(lr), + ) diff --git a/torchopt/alias/utils.py b/torchopt/alias/utils.py index 1e626810..5c8dc97a 100644 --- a/torchopt/alias/utils.py +++ b/torchopt/alias/utils.py @@ -209,7 +209,7 @@ def _scale_by_neg_lr( if callable(lr): def schedule_wrapper(count: Numeric) -> Numeric: - return -lr(count) # type: ignore[operator] + return -lr(count) return scale_by_schedule.impl( # type: ignore[attr-defined] schedule_wrapper, diff --git a/torchopt/base.py b/torchopt/base.py index b0a40afa..cab2b49f 100644 --- a/torchopt/base.py +++ b/torchopt/base.py @@ -36,6 +36,7 @@ import itertools from abc import abstractmethod from typing import TYPE_CHECKING, Callable, NamedTuple, Protocol +from typing_extensions import Self # Python 3.11+ if TYPE_CHECKING: @@ -159,7 +160,7 @@ class ChainedGradientTransformation(GradientTransformation): transformations: tuple[GradientTransformation, ...] - def __new__(cls, *transformations: GradientTransformation) -> ChainedGradientTransformation: + def __new__(cls, *transformations: GradientTransformation) -> Self: """Create a new chained gradient transformation.""" transformations = tuple( itertools.chain.from_iterable( @@ -235,7 +236,7 @@ def __reduce__(self) -> tuple[Callable, tuple[tuple[GradientTransformation, ...] class IdentityGradientTransformation(GradientTransformation): """A gradient transformation that does nothing.""" - def __new__(cls) -> IdentityGradientTransformation: + def __new__(cls) -> Self: """Create a new gradient transformation that does nothing.""" return super().__new__(cls, init=cls.init_fn, update=cls.update_fn) diff --git a/torchopt/clip.py b/torchopt/clip.py index 69da9afd..eda4bef3 100644 --- a/torchopt/clip.py +++ b/torchopt/clip.py @@ -33,16 +33,16 @@ def clip_grad_norm( - max_norm: float | int, - norm_type: float | int = 2.0, + max_norm: float, + norm_type: float = 2.0, error_if_nonfinite: bool = False, ) -> GradientTransformation: """Clip gradient norm of an iterable of parameters. Args: - max_norm (float or int): The maximum absolute value for each element in the update. - norm_type (float or int, optional): Type of the used p-norm. Can be ``'inf'`` for infinity - norm. 
(default: :const:`2.0`) + max_norm (float): The maximum absolute value for each element in the update. + norm_type (float, optional): Type of the used p-norm. Can be ``'inf'`` for infinity norm. + (default: :const:`2.0`) error_if_nonfinite (bool, optional): If :data:`True`, an error is thrown if the total norm of the gradients from ``updates`` is ``nan``, ``inf``, or ``-inf``. (default: :data:`False`) diff --git a/torchopt/diff/implicit/decorator.py b/torchopt/diff/implicit/decorator.py index 031aa11f..03720a49 100644 --- a/torchopt/diff/implicit/decorator.py +++ b/torchopt/diff/implicit/decorator.py @@ -12,6 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +# This file is modified from: +# https://github.com/google/jaxopt/blob/main/jaxopt/_src/implicit_diff.py +# ============================================================================== +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== """Implicit Meta-Gradient.""" # pylint: disable=invalid-name @@ -257,7 +274,7 @@ def _custom_root( def make_custom_vjp_solver_fn( solver_fn: Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]], kwarg_keys: Sequence[str], - args_signs: tuple[tuple[int, int, type[tuple] | type[list] | None], ...], + args_signs: tuple[tuple[int, int, type[tuple | list] | None], ...], ) -> type[Function]: # pylint: disable-next=missing-class-docstring,abstract-method class ImplicitMetaGradient(Function): @@ -379,7 +396,7 @@ def wrapped_solver_fn( args, kwargs = _signature_bind(solver_fn_signature, *args, **kwargs) keys, vals = list(kwargs.keys()), list(kwargs.values()) - args_signs: list[tuple[int, int, type[tuple] | type[list] | None]] = [] + args_signs: list[tuple[int, int, type[tuple | list] | None]] = [] flat_args: list[Any] = [] args_offset = 0 for idx, arg in enumerate(args): diff --git a/torchopt/distributed/autograd.py b/torchopt/distributed/autograd.py index c2a4b3e2..4e10d24e 100644 --- a/torchopt/distributed/autograd.py +++ b/torchopt/distributed/autograd.py @@ -38,7 +38,7 @@ def is_available() -> bool: if is_available(): # pylint: disable-next=unused-import,ungrouped-imports - from torch.distributed.autograd import DistAutogradContext, get_gradients # noqa: F401 + from torch.distributed.autograd import DistAutogradContext, get_gradients def backward( autograd_ctx_id: int, @@ -131,4 +131,4 @@ def grad( return tuple(grads) - __all__.extend(['DistAutogradContext', 'get_gradients', 'backward', 'grad']) + __all__ += ['DistAutogradContext', 'get_gradients', 'backward', 'grad'] diff --git a/torchopt/nn/module.py b/torchopt/nn/module.py index 09ab359e..64623146 100644 --- a/torchopt/nn/module.py +++ b/torchopt/nn/module.py @@ -18,6 +18,7 @@ from collections import OrderedDict from typing import Any, Iterator, NamedTuple +from typing_extensions import Self # Python 
3.11+ import torch import torch.nn as nn @@ -40,7 +41,7 @@ class MetaGradientModule(nn.Module): # pylint: disable=abstract-method _meta_parameters: TensorContainer _meta_modules: dict[str, nn.Module | None] - def __new__(cls, *args: Any, **kwargs: Any) -> MetaGradientModule: + def __new__(cls, *args: Any, **kwargs: Any) -> Self: """Create a new module instance.""" instance = super().__new__(cls) flat_args: list[Any] diff --git a/torchopt/optim/__init__.py b/torchopt/optim/__init__.py index 8e390a5c..20da5fca 100644 --- a/torchopt/optim/__init__.py +++ b/torchopt/optim/__init__.py @@ -15,10 +15,13 @@ """object oriented optimizer implementations.""" from torchopt.optim import meta +from torchopt.optim.adadelta import AdaDelta, Adadelta from torchopt.optim.adagrad import AdaGrad, Adagrad from torchopt.optim.adam import Adam +from torchopt.optim.adamax import AdaMax, Adamax from torchopt.optim.adamw import AdamW from torchopt.optim.base import Optimizer from torchopt.optim.func import FuncOptimizer +from torchopt.optim.radam import RAdam from torchopt.optim.rmsprop import RMSProp, RMSprop from torchopt.optim.sgd import SGD diff --git a/torchopt/optim/adadelta.py b/torchopt/optim/adadelta.py new file mode 100644 index 00000000..7c73cb58 --- /dev/null +++ b/torchopt/optim/adadelta.py @@ -0,0 +1,75 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Adadelta optimizer.""" + +from __future__ import annotations + +from typing import Iterable + +import torch + +from torchopt import alias +from torchopt.optim.base import Optimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['AdaDelta', 'Adadelta'] + + +class AdaDelta(Optimizer): + """The classic AdaDelta optimizer. + + See Also: + - The functional AdaDelta optimizer: :func:`torchopt.adadelta`. + - The differentiable meta-AdaDelta optimizer: :class:`torchopt.MetaAdaDelta`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + params: Iterable[torch.Tensor], + lr: ScalarOrSchedule = 1.0, + rho: float = 0.9, + eps: float = 1e-6, + weight_decay: float = 0.0, + ) -> None: + r"""Initialize the AdaDelta optimizer. + + Args: + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + rho (float, optional): Coefficients used for computing running averages of gradient and its square. + (default: :const:`0.9`) + eps (float, optional): A small constant applied to the square root (as in the AdaDelta paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters.
+ (default: :const:`0.0`) + """ + super().__init__( + params, + alias.adadelta( + lr=lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + moment_requires_grad=False, + ), + ) + + +Adadelta = AdaDelta # alias for PyTorch compatibility diff --git a/torchopt/optim/adagrad.py b/torchopt/optim/adagrad.py index 055e0ad5..a7e8c72b 100644 --- a/torchopt/optim/adagrad.py +++ b/torchopt/optim/adagrad.py @@ -33,7 +33,7 @@ class AdaGrad(Optimizer): See Also: - The functional AdaGrad optimizer: :func:`torchopt.adagrad`. - - The differentiable meta AdaGrad optimizer: :class:`torchopt.MetaAdaGrad`. + - The differentiable meta-AdaGrad optimizer: :class:`torchopt.MetaAdaGrad`. """ # pylint: disable-next=too-many-arguments diff --git a/torchopt/optim/adamax.py b/torchopt/optim/adamax.py new file mode 100644 index 00000000..904c05a0 --- /dev/null +++ b/torchopt/optim/adamax.py @@ -0,0 +1,75 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Adamax optimizer.""" + +from __future__ import annotations + +from typing import Iterable + +import torch + +from torchopt import alias +from torchopt.optim.base import Optimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['AdaMax', 'Adamax'] + + +class AdaMax(Optimizer): + """The classic AdaMax optimizer. + + See Also: + - The functional AdaMax optimizer: :func:`torchopt.adamax`. + - The differentiable meta-AdaMax optimizer: :class:`torchopt.MetaAdaMax`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + params: Iterable[torch.Tensor], + lr: ScalarOrSchedule = 2e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + ) -> None: + r"""Initialize the AdaMax optimizer. + + Args: + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to the square root (as in the AdaMax paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + """ + super().__init__( + params, + alias.adamax( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + moment_requires_grad=False, + ), + ) + + +Adamax = AdaMax # alias for PyTorch compatibility diff --git a/torchopt/optim/func/base.py b/torchopt/optim/func/base.py index 94038464..7a7839a3 100644 --- a/torchopt/optim/func/base.py +++ b/torchopt/optim/func/base.py @@ -34,9 +34,12 @@ class FuncOptimizer: # pylint: disable=too-few-public-methods and update the parameters. 
See Also: + - The functional AdaDelta optimizer: :func:`torchopt.adadelta`. - The functional AdaGrad optimizer: :func:`torchopt.adagrad`. - The functional Adam optimizer: :func:`torchopt.adam`. - The functional AdamW optimizer: :func:`torchopt.adamw`. + - The functional AdaMax optimizer: :func:`torchopt.adamax`. + - The functional RAdam optimizer: :func:`torchopt.radam`. - The functional RMSprop optimizer: :func:`torchopt.rmsprop`. - The functional SGD optimizer: :func:`torchopt.sgd`. """ diff --git a/torchopt/optim/meta/__init__.py b/torchopt/optim/meta/__init__.py index 28f374cc..516f2b5f 100644 --- a/torchopt/optim/meta/__init__.py +++ b/torchopt/optim/meta/__init__.py @@ -14,9 +14,12 @@ # ============================================================================== """Differentiable Meta-Optimizers.""" +from torchopt.optim.meta.adadelta import MetaAdaDelta, MetaAdadelta from torchopt.optim.meta.adagrad import MetaAdaGrad, MetaAdagrad from torchopt.optim.meta.adam import MetaAdam +from torchopt.optim.meta.adamax import MetaAdaMax, MetaAdamax from torchopt.optim.meta.adamw import MetaAdamW from torchopt.optim.meta.base import MetaOptimizer +from torchopt.optim.meta.radam import MetaRAdam from torchopt.optim.meta.rmsprop import MetaRMSProp, MetaRMSprop from torchopt.optim.meta.sgd import MetaSGD diff --git a/torchopt/optim/meta/adadelta.py b/torchopt/optim/meta/adadelta.py new file mode 100644 index 00000000..36d8d9ad --- /dev/null +++ b/torchopt/optim/meta/adadelta.py @@ -0,0 +1,77 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Differentiable Adadelta optimizer.""" + +from __future__ import annotations + +import torch.nn as nn + +from torchopt import alias +from torchopt.optim.meta.base import MetaOptimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['MetaAdaDelta', 'MetaAdadelta'] + + +class MetaAdaDelta(MetaOptimizer): + """The differentiable AdaDelta optimizer. + + See Also: + - The functional AdaDelta optimizer: :func:`torchopt.adadelta`. + - The classic AdaDelta optimizer: :class:`torchopt.Adadelta`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + module: nn.Module, + lr: ScalarOrSchedule = 1.0, + rho: float = 0.9, + eps: float = 1e-6, + weight_decay: float = 0.0, + *, + moment_requires_grad: bool = True, + ) -> None: + """Initialize the meta AdaDelta optimizer. + + Args: + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + rho (float, optional): Coefficients used for computing running averages of gradient and its square. + (default: :const:`0.9`) + eps (float, optional): A small constant applied to the square root (as in the AdaDelta paper) + to avoid dividing by zero when rescaling.
+ (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + """ + super().__init__( + module, + alias.adadelta( + lr=lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + moment_requires_grad=moment_requires_grad, + ), + ) + + +MetaAdadelta = MetaAdaDelta # alias for PyTorch compatibility diff --git a/torchopt/optim/meta/adagrad.py b/torchopt/optim/meta/adagrad.py index 079d76db..4e8ef0eb 100644 --- a/torchopt/optim/meta/adagrad.py +++ b/torchopt/optim/meta/adagrad.py @@ -31,7 +31,7 @@ class MetaAdaGrad(MetaOptimizer): See Also: - The functional AdaGrad optimizer: :func:`torchopt.adagrad`. - - The classic AdaGrad optimizer: :class:`torchopt.AdaGrad`. + - The classic AdaGrad optimizer: :class:`torchopt.Adagrad`. """ # pylint: disable-next=too-many-arguments diff --git a/torchopt/optim/meta/adamax.py b/torchopt/optim/meta/adamax.py new file mode 100644 index 00000000..01082af2 --- /dev/null +++ b/torchopt/optim/meta/adamax.py @@ -0,0 +1,77 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Differentiable Adamax optimizer.""" + +from __future__ import annotations + +import torch.nn as nn + +from torchopt import alias +from torchopt.optim.meta.base import MetaOptimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['MetaAdaMax', 'MetaAdamax'] + + +class MetaAdaMax(MetaOptimizer): + """The differentiable AdaMax optimizer. + + See Also: + - The functional AdaMax optimizer: :func:`torchopt.adamax`. + - The classic AdaMax optimizer: :class:`torchopt.Adamax`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + module: nn.Module, + lr: ScalarOrSchedule = 2e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + *, + moment_requires_grad: bool = True, + ) -> None: + """Initialize the meta AdaMax optimizer. + + Args: + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to the square root (as in the AdaMax paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. 
(default: :data:`False`) + """ + super().__init__( + module, + alias.adamax( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + moment_requires_grad=moment_requires_grad, + ), + ) + + +MetaAdamax = MetaAdaMax # alias for PyTorch compatibility diff --git a/torchopt/optim/meta/radam.py b/torchopt/optim/meta/radam.py new file mode 100644 index 00000000..baf4cdd2 --- /dev/null +++ b/torchopt/optim/meta/radam.py @@ -0,0 +1,74 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Differentiable RAdam optimizer.""" + +from __future__ import annotations + +import torch.nn as nn + +from torchopt import alias +from torchopt.optim.meta.base import MetaOptimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['MetaRAdam'] + + +class MetaRAdam(MetaOptimizer): + """The differentiable RAdam optimizer. + + See Also: + - The functional RAdam optimizer: :func:`torchopt.radan`. + - The classic RAdam optimizer: :class:`torchopt.RAdam`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + module: nn.Module, + lr: ScalarOrSchedule = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + *, + moment_requires_grad: bool = True, + ) -> None: + """Initialize the meta-RAdam optimizer. + + Args: + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to the square root (as in the RAdam paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + """ + super().__init__( + module, + alias.radam( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + moment_requires_grad=moment_requires_grad, + ), + ) diff --git a/torchopt/optim/radam.py b/torchopt/optim/radam.py new file mode 100644 index 00000000..c2f6a211 --- /dev/null +++ b/torchopt/optim/radam.py @@ -0,0 +1,72 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RAdam optimizer.""" + +from __future__ import annotations + +from typing import Iterable + +import torch + +from torchopt import alias +from torchopt.optim.base import Optimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['RAdam'] + + +class RAdam(Optimizer): + """The classic RAdam optimizer. + + See Also: + - The functional Adam optimizer: :func:`torchopt.radam`. + - The differentiable meta-RAdam optimizer: :class:`torchopt.MetaRAdam`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + params: Iterable[torch.Tensor], + lr: ScalarOrSchedule = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + ) -> None: + r"""Initialize the RAdam optimizer. + + Args: + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to the square root (as in the RAdam paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. 
+ (default: :const:`0.0`) + """ + super().__init__( + params, + alias.radam( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + moment_requires_grad=False, + ), + ) diff --git a/torchopt/pytree.py b/torchopt/pytree.py index 253cb154..6d41d0fa 100644 --- a/torchopt/pytree.py +++ b/torchopt/pytree.py @@ -194,7 +194,7 @@ def tree_local_value(rref_tree: PyTree[RRef[T]]) -> PyTree[T]: r"""Return the local value of a tree of :class:`RRef`\s.""" return tree_map(lambda x: x.local_value(), rref_tree) - __all__.extend(['tree_as_rref', 'tree_to_here']) + __all__ += ['tree_as_rref', 'tree_to_here'] -del Callable, optree, rpc, Scalar, T, RRef +del optree, rpc diff --git a/torchopt/transform/__init__.py b/torchopt/transform/__init__.py index 47c49ea1..c75fcb5d 100644 --- a/torchopt/transform/__init__.py +++ b/torchopt/transform/__init__.py @@ -34,7 +34,10 @@ from torchopt.transform.add_decayed_weights import add_decayed_weights, masked from torchopt.transform.nan_to_num import nan_to_num from torchopt.transform.scale import scale +from torchopt.transform.scale_by_adadelta import scale_by_adadelta from torchopt.transform.scale_by_adam import scale_by_accelerated_adam, scale_by_adam +from torchopt.transform.scale_by_adamax import scale_by_adamax +from torchopt.transform.scale_by_radam import scale_by_radam from torchopt.transform.scale_by_rms import scale_by_rms from torchopt.transform.scale_by_rss import scale_by_rss from torchopt.transform.scale_by_schedule import scale_by_schedule @@ -49,6 +52,9 @@ 'add_decayed_weights', 'masked', 'scale_by_adam', + 'scale_by_adamax', + 'scale_by_adadelta', + 'scale_by_radam', 'scale_by_accelerated_adam', 'scale_by_rss', 'scale_by_rms', diff --git a/torchopt/transform/scale_by_adadelta.py b/torchopt/transform/scale_by_adadelta.py new file mode 100644 index 00000000..fb5431a3 --- /dev/null +++ b/torchopt/transform/scale_by_adadelta.py @@ -0,0 +1,165 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Preset transformations for scaling updates by Adam.""" + +# pylint: disable=invalid-name + +from __future__ import annotations + +from typing import NamedTuple + +import torch + +from torchopt import pytree +from torchopt.base import GradientTransformation +from torchopt.transform.utils import tree_map_flat, update_moment +from torchopt.typing import OptState, Params, Updates + + +__all__ = ['scale_by_adadelta'] + + +class ScaleByAdadeltaState(NamedTuple): + """State for the Adadelta algorithm.""" + + mu: Updates + nu: Updates + + +def scale_by_adadelta( + rho: float = 0.9, + eps: float = 1e-6, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """Rescale updates according to the Adadelta algorithm. + + References: + - Zeiler, 2012: https://arxiv.org/abs/1212.5701 + + Args: + rho (float, optional): Decay rate for the squared grads. 
+ (default: :const:`0.9`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-6`) + moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag + ``requires_grad = True``. (default: :data:`False`) + + Returns: + An (init_fn, update_fn) tuple. + """ + return _scale_by_adadelta( + rho=rho, + eps=eps, + moment_requires_grad=moment_requires_grad, + already_flattened=False, + ) + + +def _scale_by_adadelta_flat( + rho: float = 0.9, + eps: float = 1e-6, + moment_requires_grad: bool = False, +) -> GradientTransformation: + return _scale_by_adadelta( + rho=rho, + eps=eps, + moment_requires_grad=moment_requires_grad, + already_flattened=True, + ) + + +def _scale_by_adadelta( + rho: float = 0.9, + eps: float = 1e-6, + moment_requires_grad: bool = False, + *, + already_flattened: bool = False, +) -> GradientTransformation: + # pylint: disable=unneeded-not + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= rho < 1.0: # pragma: no cover + raise ValueError(f'Invalid rho parameter at index 0: {rho}') + # pylint: enable=unneeded-not + + if already_flattened: # noqa: SIM108 + tree_map = tree_map_flat + else: + tree_map = pytree.tree_map # type: ignore[assignment] + + def init_fn(params: Params) -> OptState: + mu = tree_map( # first moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ) + nu = tree_map( # second moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ) + return ScaleByAdadeltaState(mu=mu, nu=nu) + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: + mu = update_moment.impl( # type: ignore[attr-defined] + updates, + state.mu, + rho, + order=2, + inplace=inplace, + already_flattened=already_flattened, + ) + + if inplace: + + def f( + g: torch.Tensor, # pylint: disable=unused-argument + m: torch.Tensor, + v: torch.Tensor, + ) -> torch.Tensor: + return g.mul_(v.add(eps).div_(m.add(eps)).sqrt_()) + + else: + + def f( + g: torch.Tensor, # pylint: disable=unused-argument + m: torch.Tensor, + v: torch.Tensor, + ) -> torch.Tensor: + return g.mul(v.add(eps).div_(m.add(eps)).sqrt_()) + + updates = tree_map(f, updates, mu, state.nu) + + nu = update_moment.impl( # type: ignore[attr-defined] + updates, + state.nu, + rho, + order=2, + inplace=inplace, + already_flattened=already_flattened, + ) + + return updates, ScaleByAdadeltaState(mu=mu, nu=nu) + + return GradientTransformation(init_fn, update_fn) + + +scale_by_adadelta.flat = _scale_by_adadelta_flat # type: ignore[attr-defined] +scale_by_adadelta.impl = _scale_by_adadelta # type: ignore[attr-defined] diff --git a/torchopt/transform/scale_by_adamax.py b/torchopt/transform/scale_by_adamax.py new file mode 100644 index 00000000..504e82cd --- /dev/null +++ b/torchopt/transform/scale_by_adamax.py @@ -0,0 +1,164 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Preset transformations for scaling updates by Adamax.""" + +# pylint: disable=invalid-name + +from __future__ import annotations + +from typing import NamedTuple + +import torch + +from torchopt import pytree +from torchopt.base import GradientTransformation +from torchopt.transform.utils import tree_map_flat, update_moment +from torchopt.typing import OptState, Params, Updates + + +__all__ = ['scale_by_adamax'] + + +class ScaleByAdamaxState(NamedTuple): + """State for the Adamax algorithm.""" + + mu: Updates + nu: Updates + t: int + + +def scale_by_adamax( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-6, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """A Adam algorithm variation. + + References: + - Kingma et al., 2014: https://arxiv.org/abs/1412.6980 + + Args: + b1 (float, optional): Decay rate for the exponentially weighted average of grads. + (default: :const:`0.9`) + b2 (float, optional): Decay rate for the exponentially weighted average of squared grads. + (default: :const:`0.999`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-6`) + moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag + ``requires_grad = True``. (default: :data:`False`) + + Returns: + An (init_fn, update_fn) tuple. + """ + return _scale_by_adamax( + b1=b1, + b2=b2, + eps=eps, + moment_requires_grad=moment_requires_grad, + already_flattened=False, + ) + + +def _scale_by_adamax_flat( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-6, + moment_requires_grad: bool = False, +) -> GradientTransformation: + return _scale_by_adamax( + b1=b1, + b2=b2, + eps=eps, + moment_requires_grad=moment_requires_grad, + already_flattened=True, + ) + + +def _scale_by_adamax( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-6, + moment_requires_grad: bool = False, + *, + already_flattened: bool = False, +) -> GradientTransformation: + # pylint: disable=unneeded-not + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= b1 < 1.0: # pragma: no cover + raise ValueError(f'Invalid b1 parameter at index 0: {b1}') + if not 0.0 <= b2 < 1.0: # pragma: no cover + raise ValueError(f'Invalid b1 parameter at index 0: {b2}') + # pylint: enable=unneeded-not + + if already_flattened: # noqa: SIM108 + tree_map = tree_map_flat + else: + tree_map = pytree.tree_map # type: ignore[assignment] + + def init_fn(params: Params) -> OptState: + mu = tree_map( # first moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ) + nu = tree_map( # second moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ) + return ScaleByAdamaxState(mu=mu, nu=nu, t=1) + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: + mu = update_moment.impl( # type: ignore[attr-defined] + updates, + state.mu, + b1, + order=1, + inplace=inplace, + already_flattened=already_flattened, + ) + + def update_nu( + g: torch.Tensor, + n: torch.Tensor, + ) -> torch.Tensor: + return torch.max(n.mul(b2), g.abs().add_(eps)) + + nu = tree_map(update_nu, updates, state.nu) + + one_minus_b1_pow_t = 1 - b1**state.t + + def 
f( + n: torch.Tensor, + m: torch.Tensor, + ) -> torch.Tensor: + return m.div(n).div_(one_minus_b1_pow_t) + + updates = tree_map(f, nu, mu) + + return updates, ScaleByAdamaxState(mu=mu, nu=nu, t=state.t + 1) + + return GradientTransformation(init_fn, update_fn) + + +scale_by_adamax.flat = _scale_by_adamax_flat # type: ignore[attr-defined] +scale_by_adamax.impl = _scale_by_adamax # type: ignore[attr-defined] diff --git a/torchopt/transform/scale_by_radam.py b/torchopt/transform/scale_by_radam.py new file mode 100644 index 00000000..acb85a82 --- /dev/null +++ b/torchopt/transform/scale_by_radam.py @@ -0,0 +1,204 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Preset transformations for scaling updates by RAdam.""" + +# pylint: disable=invalid-name + +from __future__ import annotations + +import math +from typing import NamedTuple + +import torch + +from torchopt import pytree +from torchopt.base import GradientTransformation +from torchopt.transform.utils import tree_map_flat, update_moment +from torchopt.typing import OptState, Params, Updates + + +__all__ = ['scale_by_radam'] + + +class ScaleByRAdamState(NamedTuple): + """State for the RAdam algorithm.""" + + mu: Updates + nu: Updates + t: int + + +def scale_by_radam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-6, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """Rescale updates according to the RAdam algorithm. + + References: + - Liu, 2019: https://arxiv.org/abs/1908.03265 + + Args: + b1 (float, optional): Decay rate for the exponentially weighted average of grads. + (default: :const:`0.9`) + b2 (float, optional): Decay rate for the exponentially weighted average of squared grads. + (default: :const:`0.999`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-6`) + moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag + ``requires_grad = True``. (default: :data:`False`) + + Returns: + An (init_fn, update_fn) tuple. 
+ """ + return _scale_by_radam( + b1=b1, + b2=b2, + eps=eps, + moment_requires_grad=moment_requires_grad, + already_flattened=False, + ) + + +def _scale_by_radam_flat( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-6, + moment_requires_grad: bool = False, +) -> GradientTransformation: + return _scale_by_radam( + b1=b1, + b2=b2, + eps=eps, + moment_requires_grad=moment_requires_grad, + already_flattened=True, + ) + + +def _scale_by_radam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-6, + moment_requires_grad: bool = False, + *, + already_flattened: bool = False, +) -> GradientTransformation: + # pylint: disable=unneeded-not + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= b1 < 1.0: # pragma: no cover + raise ValueError(f'Invalid b1 parameter at index 0: {b1}') + if not 0.0 <= b2 < 1.0: # pragma: no cover + raise ValueError(f'Invalid b1 parameter at index 0: {b2}') + # pylint: enable=unneeded-not + + if already_flattened: # noqa: SIM108 + tree_map = tree_map_flat + else: + tree_map = pytree.tree_map # type: ignore[assignment] + + def init_fn(params: Params) -> OptState: + mu = tree_map( # first moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ) + nu = tree_map( # second moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ) + return ScaleByRAdamState(mu=mu, nu=nu, t=1) + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: + mu = update_moment.impl( # type: ignore[attr-defined] + updates, + state.mu, + b1, + order=1, + inplace=inplace, + already_flattened=already_flattened, + ) + + nu = update_moment.impl( # type: ignore[attr-defined] + updates, + state.nu, + b2, + order=2, + inplace=inplace, + already_flattened=already_flattened, + ) + + rho_inf = 2 / (1 - b2) - 1 + one_minus_b1_pow_t = 1 - b1**state.t + one_minus_b2_pow_t = 1 - b2**state.t + rho = rho_inf - 2 * state.t * b2**state.t / one_minus_b2_pow_t + + if rho > 5: + numerator = math.sqrt( + one_minus_b2_pow_t + * (rho - 4) + * (rho - 2) + * rho_inf + / ((rho_inf - 4) * (rho_inf - 2) * rho), + ) + if inplace: + + def f( + m: torch.Tensor, + v: torch.Tensor, + ) -> torch.Tensor: + return m.mul(numerator / one_minus_b1_pow_t).div_(v.sqrt().add_(eps)) + + else: + + def f( + m: torch.Tensor, + v: torch.Tensor, + ) -> torch.Tensor: + return m.mul(numerator / one_minus_b1_pow_t).div(v.sqrt().add(eps)) + + else: + if inplace: + + def f( + m: torch.Tensor, + v: torch.Tensor, # pylint: disable=unused-argument + ) -> torch.Tensor: + return m.div(one_minus_b1_pow_t) + + else: + + def f( + m: torch.Tensor, + v: torch.Tensor, # pylint: disable=unused-argument + ) -> torch.Tensor: + return m.div(one_minus_b1_pow_t) + + updates = tree_map(f, mu, nu) + + return updates, ScaleByRAdamState(mu=mu, nu=nu, t=state.t + 1) + + return GradientTransformation(init_fn, update_fn) + + +scale_by_radam.flat = _scale_by_radam_flat # type: ignore[attr-defined] +scale_by_radam.impl = _scale_by_radam # type: ignore[attr-defined] diff --git a/torchopt/typing.py b/torchopt/typing.py index 510cb693..c5c76984 100644 --- a/torchopt/typing.py +++ b/torchopt/typing.py @@ -115,7 +115,7 @@ if rpc.is_available(): # pragma: no cover from torch.distributed.rpc import RRef # pylint: disable=ungrouped-imports,unused-import - __all__.extend(['RRef']) + __all__ += ['RRef'] else: # pragma: no cover # 
pylint: disable-next=invalid-name RRef = None # type: ignore[misc,assignment] diff --git a/torchopt/utils.py b/torchopt/utils.py index 69bda9ac..5414db80 100644 --- a/torchopt/utils.py +++ b/torchopt/utils.py @@ -270,8 +270,7 @@ def get_variable(t: torch.Tensor | None) -> torch.Tensor | None: return replicate(t) return t - state = pytree.tree_map(get_variable, state) # type: ignore[arg-type,assignment] - return state + return pytree.tree_map(get_variable, state) # type: ignore[arg-type,return-value] raise RuntimeError(f'Unexpected class of {target}') diff --git a/torchopt/version.py b/torchopt/version.py index 4b091d8a..87c4fe49 100644 --- a/torchopt/version.py +++ b/torchopt/version.py @@ -14,7 +14,7 @@ # ============================================================================== """TorchOpt: a high-performance optimizer library built upon PyTorch.""" -__version__ = '0.7.1' +__version__ = '0.7.2' __license__ = 'Apache License, Version 2.0' __author__ = 'TorchOpt Contributors' __release__ = False diff --git a/torchopt/visual.py b/torchopt/visual.py index 493ffbab..47a7f5d5 100644 --- a/torchopt/visual.py +++ b/torchopt/visual.py @@ -19,7 +19,6 @@ from __future__ import annotations -from collections import namedtuple from typing import Any, Generator, Iterable, Mapping, cast import torch @@ -33,8 +32,6 @@ __all__ = ['make_dot', 'resize_graph'] -Node = namedtuple('Node', ('name', 'inputs', 'attr', 'op')) - # Saved attrs for grad_fn (incl. saved variables) begin with `._saved_*` SAVED_PREFIX = '_saved_' diff --git a/tutorials/1_Functional_Optimizer.ipynb b/tutorials/1_Functional_Optimizer.ipynb index 07a8aeb8..afc55f38 100644 --- a/tutorials/1_Functional_Optimizer.ipynb +++ b/tutorials/1_Functional_Optimizer.ipynb @@ -1,5 +1,5 @@ { - "cells": [ + "cells": [ { "cell_type": "markdown", "metadata": {}, diff --git a/tutorials/2_Visualization.ipynb b/tutorials/2_Visualization.ipynb index 11c68bec..dd58c48d 100644 --- a/tutorials/2_Visualization.ipynb +++ b/tutorials/2_Visualization.ipynb @@ -1,5 +1,5 @@ { - "cells": [ + "cells": [ { "cell_type": "markdown", "metadata": {}, diff --git a/tutorials/3_Meta_Optimizer.ipynb b/tutorials/3_Meta_Optimizer.ipynb index 4a09836c..69be77ed 100644 --- a/tutorials/3_Meta_Optimizer.ipynb +++ b/tutorials/3_Meta_Optimizer.ipynb @@ -1,5 +1,5 @@ { - "cells": [ + "cells": [ { "cell_type": "markdown", "metadata": {}, diff --git a/tutorials/4_Stop_Gradient.ipynb b/tutorials/4_Stop_Gradient.ipynb index 06e6b3c3..d8c24bc6 100644 --- a/tutorials/4_Stop_Gradient.ipynb +++ b/tutorials/4_Stop_Gradient.ipynb @@ -1,5 +1,5 @@ { - "cells": [ + "cells": [ { "cell_type": "markdown", "metadata": {}, diff --git a/tutorials/5_Implicit_Differentiation.ipynb b/tutorials/5_Implicit_Differentiation.ipynb index f8258fcc..23407801 100644 --- a/tutorials/5_Implicit_Differentiation.ipynb +++ b/tutorials/5_Implicit_Differentiation.ipynb @@ -1,5 +1,5 @@ { - "cells": [ + "cells": [ { "cell_type": "markdown", "id": "8850c832-3b54-4971-8ee0-2cd64b585ea8", diff --git a/tutorials/6_Zero_Order_Differentiation.ipynb b/tutorials/6_Zero_Order_Differentiation.ipynb index 968f6b6c..d6cb028c 100644 --- a/tutorials/6_Zero_Order_Differentiation.ipynb +++ b/tutorials/6_Zero_Order_Differentiation.ipynb @@ -1,5 +1,5 @@ { - "cells": [ + "cells": [ { "cell_type": "markdown", "id": "8850c832-3b54-4971-8ee0-2cd64b585ea8",
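The new `RAdam`, `AdaDelta`, and `AdaMax` classes added above follow the existing `torchopt.Optimizer` interface, while the `Meta*` variants build on `MetaOptimizer` so that updates remain differentiable. A minimal usage sketch, assuming the classes are re-exported from the top-level `torchopt` namespace (as the `See Also` cross-references in the new modules indicate); the two-layer network, synthetic data, and losses are hypothetical:

```python
# Usage sketch for the optimizers added in this diff (torchopt 0.7.2).
# Assumption: RAdam and MetaAdaDelta are available at the top-level torchopt
# namespace, as the ``See Also`` cross-references suggest. The model, data,
# and losses below are made up for illustration.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchopt

net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1))
x, y = torch.randn(8, 4), torch.randn(8, 1)

# Classic optimizer: same step()/zero_grad() workflow as torch.optim.
optimizer = torchopt.RAdam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8)
loss = F.mse_loss(net(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Differentiable meta-optimizer: the update stays in the autograd graph,
# so an outer (meta) loss can be backpropagated through the inner step.
meta_opt = torchopt.MetaAdaDelta(net, lr=1.0, rho=0.9, eps=1e-6)
inner_loss = F.mse_loss(net(x), y)
meta_opt.step(inner_loss)    # differentiable in-place update of net's parameters
outer_loss = F.mse_loss(net(x), y)
outer_loss.backward()        # gradients flow back through the AdaDelta update
```

The `moment_requires_grad=True` default in the meta variants is what keeps the optimizer state attached to the autograd graph, matching the note in the new docstrings.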
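At a lower level, `scale_by_adadelta`, `scale_by_adamax`, and `scale_by_radam` are plain `GradientTransformation`s (an `init_fn`/`update_fn` pair, as the new files show), so they compose with the existing transformations. A sketch of that route, assuming `torchopt.chain`, `torchopt.transform.scale`, and `torchopt.apply_updates` keep their established behavior; the parameter pytree and quadratic loss are illustrative:

```python
# Sketch of the low-level functional route through the new transformations.
# Assumption: torchopt.chain, torchopt.transform.scale, and
# torchopt.apply_updates behave as in previous releases; the parameter
# pytree and the quadratic loss are made up for illustration.
import torch
import torchopt
from torchopt.transform import scale, scale_by_radam

params = {'w': torch.randn(3, requires_grad=True)}

# Compose the RAdam rescaling with a negative step size, which is roughly
# what the higher-level radam alias is expected to do internally.
optimizer = torchopt.chain(scale_by_radam(b1=0.9, b2=0.999, eps=1e-6), scale(-1e-3))
opt_state = optimizer.init(params)

def loss_fn(p):
    return (p['w'] ** 2).sum()

grads = torch.autograd.grad(loss_fn(params), tuple(params.values()))
grads = dict(zip(params, grads))  # rebuild the same pytree structure as params
updates, opt_state = optimizer.update(grads, opt_state, params=params, inplace=False)
params = torchopt.apply_updates(params, updates, inplace=False)
```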