From a0cfcdfbdcd14b217515b088b8c0d5936cbe2559 Mon Sep 17 00:00:00 2001 From: Bo Liu Date: Sun, 20 Aug 2023 21:56:43 +0800 Subject: [PATCH 1/6] docs: fix typo in explicit differentiation (#188) --- docs/source/explicit_diff/explicit_diff.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/explicit_diff/explicit_diff.rst b/docs/source/explicit_diff/explicit_diff.rst index 9445adb8..28e06f77 100644 --- a/docs/source/explicit_diff/explicit_diff.rst +++ b/docs/source/explicit_diff/explicit_diff.rst @@ -59,7 +59,7 @@ For PyTorch-like API (e.g., ``step()``), we designed a base class :class:`torcho torchopt.MetaAdagrad torchopt.MetaAdam torchopt.MetaAdamW - torchopt.AdaMax + torchopt.MetaAdaMax torchopt.MetaAdamax torchopt.MetaRAdam torchopt.MetaRMSProp From 91b85ae91ead9349b33260ae1f564d3cfca7457e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 5 Sep 2023 16:48:08 +0800 Subject: [PATCH 2/6] chore(pre-commit): [pre-commit.ci] autoupdate (#190) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e5c37d40..4d2ef96d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,7 +30,7 @@ repos: hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.284 + rev: v0.0.287 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] From 2117c18d4b86dc7dbcb97a23dae6010265698401 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 11 Sep 2023 14:42:52 +0800 Subject: [PATCH 3/6] deps(workflows): bump actions/checkout from 3 to 4 (#192) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/build.yml | 8 ++++---- .github/workflows/lint.yml | 2 +- .github/workflows/tests.yml | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0a6c4d6e..4f15d1a4 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -48,7 +48,7 @@ jobs: timeout-minutes: 60 steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: submodules: "recursive" fetch-depth: 0 @@ -108,7 +108,7 @@ jobs: timeout-minutes: 60 steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: submodules: "recursive" fetch-depth: 0 @@ -158,7 +158,7 @@ jobs: timeout-minutes: 60 steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: submodules: "recursive" fetch-depth: 0 @@ -206,7 +206,7 @@ jobs: timeout-minutes: 15 steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: submodules: "recursive" fetch-depth: 0 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index b338b149..80ed2876 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -24,7 +24,7 @@ jobs: timeout-minutes: 30 steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: submodules: "recursive" fetch-depth: 1 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4f6fad50..6a8f7ec8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -36,7 +36,7 @@ jobs: timeout-minutes: 60 steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: submodules: "recursive" fetch-depth: 1 @@ -106,7 +106,7 @@ jobs: fail-fast: false steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: submodules: "recursive" fetch-depth: 1 From 86b167c054b158c8d49c262b30c430838a06b794 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 25 Sep 2023 14:55:52 +0800 Subject: [PATCH 4/6] deps(workflows): bump pypa/cibuildwheel from 2.15 to 2.16 (#193) Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 2.15 to 2.16. - [Release notes](https://github.com/pypa/cibuildwheel/releases) - [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md) - [Commits](https://github.com/pypa/cibuildwheel/compare/v2.15...v2.16) --- updated-dependencies: - dependency-name: pypa/cibuildwheel dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/build.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4f15d1a4..bed635e3 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -132,7 +132,7 @@ jobs: run: python .github/workflows/set_cibw_build.py - name: Build wheels - uses: pypa/cibuildwheel@v2.15 + uses: pypa/cibuildwheel@v2.16 env: CIBW_BUILD: ${{ env.CIBW_BUILD }} with: @@ -182,7 +182,7 @@ jobs: run: python .github/workflows/set_cibw_build.py - name: Build wheels - uses: pypa/cibuildwheel@v2.15 + uses: pypa/cibuildwheel@v2.16 env: CIBW_BUILD: ${{ env.CIBW_BUILD }} with: From 93cc7ec81040dbc0977cecaa9d55f11fd8c32dcf Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 9 Nov 2023 19:25:45 +0800 Subject: [PATCH 5/6] fix: fix `optree` compatibility for multi-tree-map with `None` values (#195) --- .github/workflows/lint.yml | 4 +-- .pre-commit-config.yaml | 12 +++---- CHANGELOG.md | 4 +-- CMakeLists.txt | 4 +-- conda-recipe.yaml | 1 - docs/conda-recipe.yaml | 1 - docs/requirements.txt | 10 +++--- torchopt/alias/sgd.py | 1 + torchopt/alias/utils.py | 26 ++++++++------ torchopt/distributed/api.py | 2 ++ torchopt/linalg/cg.py | 4 ++- torchopt/linalg/ns.py | 2 ++ torchopt/nn/stateless.py | 4 +-- torchopt/transform/add_decayed_weights.py | 12 ++++--- torchopt/transform/scale_by_adadelta.py | 18 +++------- torchopt/transform/scale_by_adam.py | 20 ++++------- torchopt/transform/scale_by_adamax.py | 18 ++++------ torchopt/transform/scale_by_rms.py | 12 +++---- torchopt/transform/scale_by_rss.py | 20 +++++------ torchopt/transform/scale_by_schedule.py | 16 +++++---- torchopt/transform/scale_by_stddev.py | 12 +++---- torchopt/transform/trace.py | 42 ++++++++++++++--------- torchopt/transform/utils.py | 1 + torchopt/utils.py | 4 +-- 24 files changed, 126 insertions(+), 124 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 80ed2876..89c26c3c 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -29,10 +29,10 @@ jobs: submodules: "recursive" fetch-depth: 1 - - name: Set up Python 3.8 + - name: Set up Python 3.9 uses: actions/setup-python@v4 with: - python-version: "3.8" + python-version: "3.9" update-environment: true - name: Setup CUDA Toolkit diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4d2ef96d..975bb7b6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ ci: default_stages: [commit, push, manual] repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: check-symlinks - id: destroyed-symlinks @@ -26,11 +26,11 @@ repos: - id: debug-statements - id: double-quote-string-fixer - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v16.0.6 + rev: v17.0.4 hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.287 + rev: v0.1.5 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] @@ -39,11 +39,11 @@ repos: hooks: - id: isort - repo: https://github.com/psf/black - rev: 23.7.0 + rev: 23.11.0 hooks: - id: black-jupyter - repo: https://github.com/asottile/pyupgrade - rev: v3.10.1 + rev: v3.15.0 hooks: - id: pyupgrade args: [--py38-plus] # sync with requires-python @@ -68,7 +68,7 @@ repos: ^docs/source/conf.py$ ) - repo: https://github.com/codespell-project/codespell - rev: v2.2.5 + rev: v2.2.6 hooks: - id: codespell additional_dependencies: [".[toml]"] diff --git a/CHANGELOG.md b/CHANGELOG.md index 24e4eea4..797b6ca0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,11 +17,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- +- Set minimal C++ standard to C++17 by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195). ### Fixed -- +- Fix `optree` compatibility for multi-tree-map with `None` values by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195). ### Removed diff --git a/CMakeLists.txt b/CMakeLists.txt index 3b091a22..eca14815 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,13 +17,13 @@ cmake_minimum_required(VERSION 3.11) # for FetchContent project(torchopt LANGUAGES CXX) include(FetchContent) -set(PYBIND11_VERSION v2.10.3) +set(PYBIND11_VERSION v2.11.1) if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) endif() -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) find_package(Threads REQUIRED) # -pthread diff --git a/conda-recipe.yaml b/conda-recipe.yaml index 997f11c5..12b3a6d0 100644 --- a/conda-recipe.yaml +++ b/conda-recipe.yaml @@ -77,7 +77,6 @@ dependencies: - hunspell-en - myst-nb - ipykernel - - pandoc - docutils # Testing diff --git a/docs/conda-recipe.yaml b/docs/conda-recipe.yaml index 9a14af3f..30ec372e 100644 --- a/docs/conda-recipe.yaml +++ b/docs/conda-recipe.yaml @@ -67,5 +67,4 @@ dependencies: - hunspell-en - myst-nb - ipykernel - - pandoc - docutils diff --git a/docs/requirements.txt b/docs/requirements.txt index 655c64ff..82bf2d91 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,17 +4,17 @@ torch >= 1.13 --requirement ../requirements.txt -sphinx >= 5.2.1 +sphinx >= 5.2.1, < 7.0.0a0 +sphinxcontrib-bibtex >= 2.4 +sphinx-autodoc-typehints >= 1.20 +myst-nb >= 0.15 + sphinx-autoapi sphinx-autobuild sphinx-copybutton sphinx-rtd-theme sphinxcontrib-katex -sphinxcontrib-bibtex -sphinx-autodoc-typehints >= 1.19.2 IPython ipykernel -pandoc -myst-nb docutils matplotlib diff --git a/torchopt/alias/sgd.py b/torchopt/alias/sgd.py index 6fb3c6db..4c5b8317 100644 --- a/torchopt/alias/sgd.py +++ b/torchopt/alias/sgd.py @@ -44,6 +44,7 @@ __all__ = ['sgd'] +# pylint: disable-next=too-many-arguments def sgd( lr: ScalarOrSchedule, momentum: float = 0.0, diff --git a/torchopt/alias/utils.py b/torchopt/alias/utils.py index 5c8dc97a..2c2ec0e4 100644 --- a/torchopt/alias/utils.py +++ b/torchopt/alias/utils.py @@ -108,19 +108,21 @@ def update_fn( if inplace: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if g.requires_grad: return g.add_(p, alpha=weight_decay) return g.add_(p.data, alpha=weight_decay) - updates = tree_map_(f, updates, params) + tree_map_(f, params, updates) else: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: - return g.add(p, alpha=weight_decay) + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.add(p, alpha=weight_decay) if g is not None else g - updates = tree_map(f, updates, params) + updates = tree_map(f, params, updates) return updates, state @@ -139,7 +141,7 @@ def update_fn( def f(g: torch.Tensor) -> torch.Tensor: return g.neg_() - updates = tree_map_(f, updates) + tree_map_(f, updates) else: @@ -166,19 +168,21 @@ def update_fn( if inplace: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if g.requires_grad: return g.neg_().add_(p, alpha=weight_decay) return g.neg_().add_(p.data, alpha=weight_decay) - updates = tree_map_(f, updates, params) + tree_map_(f, params, updates) else: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: - return g.neg().add_(p, alpha=weight_decay) + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.neg().add_(p, alpha=weight_decay) if g is not None else g - updates = tree_map(f, updates, params) + updates = tree_map(f, params, updates) return updates, state diff --git a/torchopt/distributed/api.py b/torchopt/distributed/api.py index 3a6f0526..fb7461e4 100644 --- a/torchopt/distributed/api.py +++ b/torchopt/distributed/api.py @@ -271,6 +271,7 @@ def sum_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor: return torch.sum(torch.stack(tuple(results), dim=0), dim=0) +# pylint: disable-next=too-many-arguments def remote_async_call( func: Callable[..., T], *, @@ -328,6 +329,7 @@ def remote_async_call( return future +# pylint: disable-next=too-many-arguments def remote_sync_call( func: Callable[..., T], *, diff --git a/torchopt/linalg/cg.py b/torchopt/linalg/cg.py index 9cd57cd8..42cb6bea 100644 --- a/torchopt/linalg/cg.py +++ b/torchopt/linalg/cg.py @@ -53,7 +53,7 @@ def _identity(x: TensorTree) -> TensorTree: return x -# pylint: disable-next=too-many-locals +# pylint: disable-next=too-many-arguments,too-many-locals def _cg_solve( A: Callable[[TensorTree], TensorTree], b: TensorTree, @@ -102,6 +102,7 @@ def body_fn( return x_final +# pylint: disable-next=too-many-arguments def _isolve( _isolve_solve: Callable, A: TensorTree | Callable[[TensorTree], TensorTree], @@ -134,6 +135,7 @@ def _isolve( return isolve_solve(A, b) +# pylint: disable-next=too-many-arguments def cg( A: TensorTree | Callable[[TensorTree], TensorTree], b: TensorTree, diff --git a/torchopt/linalg/ns.py b/torchopt/linalg/ns.py index 747ad3cf..ce49fe77 100644 --- a/torchopt/linalg/ns.py +++ b/torchopt/linalg/ns.py @@ -123,12 +123,14 @@ def _ns_inv(A: torch.Tensor, maxiter: int, alpha: float | None = None) -> torch. # A^{-1} = a [I - (I - a A)]^{-1} = a [I + (I - a A) + (I - a A)^2 + (I - a A)^3 + ...] M = I - alpha * A for rank in range(maxiter): + # pylint: disable-next=not-callable inv_A_hat = inv_A_hat + torch.linalg.matrix_power(M, rank) inv_A_hat = alpha * inv_A_hat else: # A^{-1} = [I - (I - A)]^{-1} = I + (I - A) + (I - A)^2 + (I - A)^3 + ... M = I - A for rank in range(maxiter): + # pylint: disable-next=not-callable inv_A_hat = inv_A_hat + torch.linalg.matrix_power(M, rank) return inv_A_hat diff --git a/torchopt/nn/stateless.py b/torchopt/nn/stateless.py index e547b5cb..8268ca3f 100644 --- a/torchopt/nn/stateless.py +++ b/torchopt/nn/stateless.py @@ -66,8 +66,8 @@ def recursive_setattr(path: str, value: torch.Tensor) -> torch.Tensor: mod._parameters[attr] = value # type: ignore[assignment] elif hasattr(mod, '_buffers') and attr in mod._buffers: mod._buffers[attr] = value - elif hasattr(mod, '_meta_parameters') and attr in mod._meta_parameters: # type: ignore[operator] - mod._meta_parameters[attr] = value # type: ignore[operator,index] + elif hasattr(mod, '_meta_parameters') and attr in mod._meta_parameters: + mod._meta_parameters[attr] = value else: setattr(mod, attr, value) # pylint: enable=protected-access diff --git a/torchopt/transform/add_decayed_weights.py b/torchopt/transform/add_decayed_weights.py index 04d564d7..39948694 100644 --- a/torchopt/transform/add_decayed_weights.py +++ b/torchopt/transform/add_decayed_weights.py @@ -226,19 +226,21 @@ def update_fn( if inplace: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if g.requires_grad: return g.add_(p, alpha=weight_decay) return g.add_(p.data, alpha=weight_decay) - updates = tree_map_(f, updates, params) + tree_map_(f, params, updates) else: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: - return g.add(p, alpha=weight_decay) + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.add(p, alpha=weight_decay) if g is not None else g - updates = tree_map(f, updates, params) + updates = tree_map(f, params, updates) return updates, state diff --git a/torchopt/transform/scale_by_adadelta.py b/torchopt/transform/scale_by_adadelta.py index fb5431a3..bbe40080 100644 --- a/torchopt/transform/scale_by_adadelta.py +++ b/torchopt/transform/scale_by_adadelta.py @@ -129,23 +129,15 @@ def update_fn( if inplace: - def f( - g: torch.Tensor, # pylint: disable=unused-argument - m: torch.Tensor, - v: torch.Tensor, - ) -> torch.Tensor: - return g.mul_(v.add(eps).div_(m.add(eps)).sqrt_()) + def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.mul_(v.add(eps).div_(m.add(eps)).sqrt_()) if g is not None else g else: - def f( - g: torch.Tensor, # pylint: disable=unused-argument - m: torch.Tensor, - v: torch.Tensor, - ) -> torch.Tensor: - return g.mul(v.add(eps).div_(m.add(eps)).sqrt_()) + def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.mul(v.add(eps).div_(m.add(eps)).sqrt_()) if g is not None else g - updates = tree_map(f, updates, mu, state.nu) + updates = tree_map(f, mu, state.nu, updates) nu = update_moment.impl( # type: ignore[attr-defined] updates, diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py index c3c6254e..cc0ea3b6 100644 --- a/torchopt/transform/scale_by_adam.py +++ b/torchopt/transform/scale_by_adam.py @@ -132,6 +132,7 @@ def _scale_by_adam_flat( ) +# pylint: disable-next=too-many-arguments def _scale_by_adam( b1: float = 0.9, b2: float = 0.999, @@ -200,23 +201,15 @@ def update_fn( if inplace: - def f( - g: torch.Tensor, # pylint: disable=unused-argument - m: torch.Tensor, - v: torch.Tensor, - ) -> torch.Tensor: - return m.div_(v.add_(eps_root).sqrt_().add(eps)) + def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return m.div_(v.add_(eps_root).sqrt_().add(eps)) if g is not None else g else: - def f( - g: torch.Tensor, # pylint: disable=unused-argument - m: torch.Tensor, - v: torch.Tensor, - ) -> torch.Tensor: - return m.div(v.add(eps_root).sqrt_().add(eps)) + def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return m.div(v.add(eps_root).sqrt_().add(eps)) if g is not None else g - updates = tree_map(f, updates, mu_hat, nu_hat) + updates = tree_map(f, mu_hat, nu_hat, updates) return updates, ScaleByAdamState(mu=mu, nu=nu, count=count_inc) return GradientTransformation(init_fn, update_fn) @@ -283,6 +276,7 @@ def _scale_by_accelerated_adam_flat( ) +# pylint: disable-next=too-many-arguments def _scale_by_accelerated_adam( b1: float = 0.9, b2: float = 0.999, diff --git a/torchopt/transform/scale_by_adamax.py b/torchopt/transform/scale_by_adamax.py index 504e82cd..0a1c3ec9 100644 --- a/torchopt/transform/scale_by_adamax.py +++ b/torchopt/transform/scale_by_adamax.py @@ -137,23 +137,17 @@ def update_fn( already_flattened=already_flattened, ) - def update_nu( - g: torch.Tensor, - n: torch.Tensor, - ) -> torch.Tensor: - return torch.max(n.mul(b2), g.abs().add_(eps)) + def update_nu(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return torch.max(n.mul(b2), g.abs().add_(eps)) if g is not None else g - nu = tree_map(update_nu, updates, state.nu) + nu = tree_map(update_nu, state.nu, updates) one_minus_b1_pow_t = 1 - b1**state.t - def f( - n: torch.Tensor, - m: torch.Tensor, - ) -> torch.Tensor: - return m.div(n).div_(one_minus_b1_pow_t) + def f(m: torch.Tensor, n: torch.Tensor | None) -> torch.Tensor: + return m.div(n).div_(one_minus_b1_pow_t) if n is not None else m - updates = tree_map(f, nu, mu) + updates = tree_map(f, mu, nu) return updates, ScaleByAdamaxState(mu=mu, nu=nu, t=state.t + 1) diff --git a/torchopt/transform/scale_by_rms.py b/torchopt/transform/scale_by_rms.py index ac2fef16..084be839 100644 --- a/torchopt/transform/scale_by_rms.py +++ b/torchopt/transform/scale_by_rms.py @@ -136,17 +136,17 @@ def update_fn( if inplace: - def f(g: torch.Tensor, n: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name - return g.div_(n.sqrt().add_(eps)) + def f(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.div_(n.sqrt().add_(eps)) if g is not None else g - updates = tree_map_(f, updates, nu) + tree_map_(f, nu, updates) else: - def f(g: torch.Tensor, n: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name - return g.div(n.sqrt().add(eps)) + def f(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.div(n.sqrt().add(eps)) if g is not None else g - updates = tree_map(f, updates, nu) + updates = tree_map(f, nu, updates) return updates, ScaleByRmsState(nu=nu) diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index 68021e5e..b1f3d2a8 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -128,23 +128,21 @@ def update_fn( if inplace: - def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor: - return torch.where( - sos > 0.0, - g.div_(sos.sqrt().add_(eps)), - 0.0, + def f(sos: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return ( + torch.where(sos > 0.0, g.div_(sos.sqrt().add_(eps)), 0.0) + if g is not None + else g ) else: - def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor: - return torch.where( - sos > 0.0, - g.div(sos.sqrt().add(eps)), - 0.0, + def f(sos: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return ( + torch.where(sos > 0.0, g.div(sos.sqrt().add(eps)), 0.0) if g is not None else g ) - updates = tree_map(f, updates, sum_of_squares) + updates = tree_map(f, sum_of_squares, updates) return updates, ScaleByRssState(sum_of_squares=sum_of_squares) return GradientTransformation(init_fn, update_fn) diff --git a/torchopt/transform/scale_by_schedule.py b/torchopt/transform/scale_by_schedule.py index f27fb7e8..749b1853 100644 --- a/torchopt/transform/scale_by_schedule.py +++ b/torchopt/transform/scale_by_schedule.py @@ -96,20 +96,24 @@ def update_fn( inplace: bool = True, ) -> tuple[Updates, OptState]: if inplace: - - def f(g: torch.Tensor, c: Numeric) -> torch.Tensor: # pylint: disable=invalid-name + # pylint: disable-next=invalid-name + def f(c: Numeric, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g step_size = step_size_fn(c) return g.mul_(step_size) - updates = tree_map_(f, updates, state.count) + tree_map_(f, state.count, updates) else: - - def f(g: torch.Tensor, c: Numeric) -> torch.Tensor: # pylint: disable=invalid-name + # pylint: disable-next=invalid-name + def f(c: Numeric, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g step_size = step_size_fn(c) return g.mul(step_size) - updates = tree_map(f, updates, state.count) + updates = tree_map(f, state.count, updates) return ( updates, diff --git a/torchopt/transform/scale_by_stddev.py b/torchopt/transform/scale_by_stddev.py index bbbfb384..d9589c45 100644 --- a/torchopt/transform/scale_by_stddev.py +++ b/torchopt/transform/scale_by_stddev.py @@ -148,17 +148,17 @@ def update_fn( if inplace: - def f(g: torch.Tensor, m: torch.Tensor, n: torch.Tensor) -> torch.Tensor: - return g.div_(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) + def f(m: torch.Tensor, n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.div_(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) if g is not None else g - updates = tree_map_(f, updates, mu, nu) + tree_map_(f, mu, nu, updates) else: - def f(g: torch.Tensor, m: torch.Tensor, n: torch.Tensor) -> torch.Tensor: - return g.div(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) + def f(m: torch.Tensor, n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.div(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) if g is not None else g - updates = tree_map(f, updates, mu, nu) + updates = tree_map(f, mu, nu, updates) return updates, ScaleByRStdDevState(mu=mu, nu=nu) diff --git a/torchopt/transform/trace.py b/torchopt/transform/trace.py index 7a1e1971..d530a676 100644 --- a/torchopt/transform/trace.py +++ b/torchopt/transform/trace.py @@ -148,52 +148,60 @@ def update_fn( if nesterov: if inplace: - def f1(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + def f1(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if first_call: return t.add_(g) return t.mul_(momentum).add_(g) - def f2(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: - return g.add_(t, alpha=momentum) + def f2(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.add_(t, alpha=momentum) if g is not None else g - new_trace = tree_map(f1, updates, state.trace) - updates = tree_map_(f2, updates, new_trace) + new_trace = tree_map(f1, state.trace, updates) + tree_map_(f2, new_trace, updates) else: - def f1(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + def f1(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if first_call: return t.add(g) return t.mul(momentum).add_(g) - def f2(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: - return g.add(t, alpha=momentum) + def f2(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.add(t, alpha=momentum) if g is not None else g - new_trace = tree_map(f1, updates, state.trace) - updates = tree_map(f2, updates, new_trace) + new_trace = tree_map(f1, state.trace, updates) + updates = tree_map(f2, new_trace, updates) else: if inplace: - def f(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + def f(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if first_call: return t.add_(g) return t.mul_(momentum).add_(g, alpha=1.0 - dampening) - def copy_(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: - return g.copy_(t) + def copy_to_(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.copy_(t) if g is not None else g - new_trace = tree_map(f, updates, state.trace) - updates = tree_map_(copy_, updates, new_trace) + new_trace = tree_map(f, state.trace, updates) + tree_map_(copy_to_, new_trace, updates) else: - def f(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + def f(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if first_call: return t.add(g) return t.mul(momentum).add_(g, alpha=1.0 - dampening) - new_trace = tree_map(f, updates, state.trace) + new_trace = tree_map(f, state.trace, updates) updates = tree_map(torch.clone, new_trace) first_call = False diff --git a/torchopt/transform/utils.py b/torchopt/transform/utils.py index 8c67fd7e..f1ed39da 100644 --- a/torchopt/transform/utils.py +++ b/torchopt/transform/utils.py @@ -160,6 +160,7 @@ def _update_moment_flat( ) +# pylint: disable-next=too-many-arguments def _update_moment( updates: Updates, moments: TensorTree, diff --git a/torchopt/utils.py b/torchopt/utils.py index 5414db80..ef771966 100644 --- a/torchopt/utils.py +++ b/torchopt/utils.py @@ -91,7 +91,7 @@ def fn_(obj: Any) -> None: @overload -def extract_state_dict( +def extract_state_dict( # pylint: disable=too-many-arguments target: nn.Module, *, by: CopyMode = 'reference', @@ -114,7 +114,7 @@ def extract_state_dict( ... -# pylint: disable-next=too-many-branches,too-many-locals +# pylint: disable-next=too-many-arguments,too-many-branches,too-many-locals def extract_state_dict( target: nn.Module | MetaOptimizer, *, From 26e51d430cddb855126c1518450aa8181c8038e8 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 10 Nov 2023 10:53:40 +0800 Subject: [PATCH 6/6] ver: bump version to 0.7.3 --- CHANGELOG.md | 19 ++++++++++++++++--- CITATION.cff | 4 ++-- torchopt/version.py | 2 +- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 797b6ca0..315d24db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,11 +17,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- Set minimal C++ standard to C++17 by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195). +- ### Fixed -- Fix `optree` compatibility for multi-tree-map with `None` values by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195). +- ### Removed @@ -29,6 +29,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ------ +## [0.7.3] - 2023-11-10 + +### Changed + +- Set minimal C++ standard to C++17 by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195). + +### Fixed + +- Fix `optree` compatibility for multi-tree-map with `None` values by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195). + +------ + ## [0.7.2] - 2023-08-18 ### Added @@ -195,7 +207,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ------ -[Unreleased]: https://github.com/metaopt/torchopt/compare/v0.7.2...HEAD +[Unreleased]: https://github.com/metaopt/torchopt/compare/v0.7.3...HEAD +[0.7.3]: https://github.com/metaopt/torchopt/compare/v0.7.2...v0.7.3 [0.7.2]: https://github.com/metaopt/torchopt/compare/v0.7.1...v0.7.2 [0.7.1]: https://github.com/metaopt/torchopt/compare/v0.7.0...v0.7.1 [0.7.0]: https://github.com/metaopt/torchopt/compare/v0.6.0...v0.7.0 diff --git a/CITATION.cff b/CITATION.cff index 965b6a7f..3c6098bf 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -32,7 +32,7 @@ authors: family-names: Yang affiliation: Peking University email: yaodong.yang@pku.edu.cn -version: 0.7.2 -date-released: "2023-08-18" +version: 0.7.3 +date-released: "2023-11-10" license: Apache-2.0 repository-code: "https://github.com/metaopt/torchopt" diff --git a/torchopt/version.py b/torchopt/version.py index 87c4fe49..685735e6 100644 --- a/torchopt/version.py +++ b/torchopt/version.py @@ -14,7 +14,7 @@ # ============================================================================== """TorchOpt: a high-performance optimizer library built upon PyTorch.""" -__version__ = '0.7.2' +__version__ = '0.7.3' __license__ = 'Apache License, Version 2.0' __author__ = 'TorchOpt Contributors' __release__ = False pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy