diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 0a6c4d6e..bed635e3 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -48,7 +48,7 @@ jobs:
     timeout-minutes: 60
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           submodules: "recursive"
           fetch-depth: 0
@@ -108,7 +108,7 @@ jobs:
     timeout-minutes: 60
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           submodules: "recursive"
           fetch-depth: 0
@@ -132,7 +132,7 @@ jobs:
         run: python .github/workflows/set_cibw_build.py
 
       - name: Build wheels
-        uses: pypa/cibuildwheel@v2.15
+        uses: pypa/cibuildwheel@v2.16
         env:
           CIBW_BUILD: ${{ env.CIBW_BUILD }}
         with:
@@ -158,7 +158,7 @@ jobs:
     timeout-minutes: 60
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           submodules: "recursive"
           fetch-depth: 0
@@ -182,7 +182,7 @@ jobs:
         run: python .github/workflows/set_cibw_build.py
 
       - name: Build wheels
-        uses: pypa/cibuildwheel@v2.15
+        uses: pypa/cibuildwheel@v2.16
         env:
           CIBW_BUILD: ${{ env.CIBW_BUILD }}
         with:
@@ -206,7 +206,7 @@ jobs:
     timeout-minutes: 15
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           submodules: "recursive"
           fetch-depth: 0
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index b338b149..89c26c3c 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -24,15 +24,15 @@ jobs:
     timeout-minutes: 30
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           submodules: "recursive"
           fetch-depth: 1
 
-      - name: Set up Python 3.8
+      - name: Set up Python 3.9
         uses: actions/setup-python@v4
         with:
-          python-version: "3.8"
+          python-version: "3.9"
           update-environment: true
 
       - name: Setup CUDA Toolkit
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 4f6fad50..6a8f7ec8 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -36,7 +36,7 @@ jobs:
     timeout-minutes: 60
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           submodules: "recursive"
           fetch-depth: 1
@@ -106,7 +106,7 @@ jobs:
       fail-fast: false
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           submodules: "recursive"
           fetch-depth: 1
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e5c37d40..975bb7b6 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -9,7 +9,7 @@ ci:
 default_stages: [commit, push, manual]
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v4.5.0
     hooks:
       - id: check-symlinks
       - id: destroyed-symlinks
@@ -26,11 +26,11 @@ repos:
       - id: debug-statements
       - id: double-quote-string-fixer
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v16.0.6
+    rev: v17.0.4
     hooks:
       - id: clang-format
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.0.284
+    rev: v0.1.5
     hooks:
       - id: ruff
         args: [--fix, --exit-non-zero-on-fix]
@@ -39,11 +39,11 @@ repos:
     hooks:
       - id: isort
   - repo: https://github.com/psf/black
-    rev: 23.7.0
+    rev: 23.11.0
     hooks:
       - id: black-jupyter
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.10.1
+    rev: v3.15.0
     hooks:
       - id: pyupgrade
         args: [--py38-plus] # sync with requires-python
@@ -68,7 +68,7 @@ repos:
         ^docs/source/conf.py$
       )
   - repo: https://github.com/codespell-project/codespell
-    rev: v2.2.5
+    rev: v2.2.6
     hooks:
       - id: codespell
         additional_dependencies: [".[toml]"]
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 24e4eea4..315d24db 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -29,6 +29,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ------
 
+## [0.7.3] - 2023-11-10
+
+### Changed
+
+- Set minimal C++ standard to C++17 by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195).
+
+### Fixed
+
+- Fix `optree` compatibility for multi-tree-map with `None` values by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195).
+
+------
+
 ## [0.7.2] - 2023-08-18
 
 ### Added
@@ -195,7 +207,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ------
 
-[Unreleased]: https://github.com/metaopt/torchopt/compare/v0.7.2...HEAD
+[Unreleased]: https://github.com/metaopt/torchopt/compare/v0.7.3...HEAD
+[0.7.3]: https://github.com/metaopt/torchopt/compare/v0.7.2...v0.7.3
 [0.7.2]: https://github.com/metaopt/torchopt/compare/v0.7.1...v0.7.2
 [0.7.1]: https://github.com/metaopt/torchopt/compare/v0.7.0...v0.7.1
 [0.7.0]: https://github.com/metaopt/torchopt/compare/v0.6.0...v0.7.0
diff --git a/CITATION.cff b/CITATION.cff
index 965b6a7f..3c6098bf 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -32,7 +32,7 @@ authors:
     family-names: Yang
     affiliation: Peking University
     email: yaodong.yang@pku.edu.cn
-version: 0.7.2
-date-released: "2023-08-18"
+version: 0.7.3
+date-released: "2023-11-10"
 license: Apache-2.0
 repository-code: "https://github.com/metaopt/torchopt"
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3b091a22..eca14815 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -17,13 +17,13 @@ cmake_minimum_required(VERSION 3.11) # for FetchContent
 project(torchopt LANGUAGES CXX)
 
 include(FetchContent)
-set(PYBIND11_VERSION v2.10.3)
+set(PYBIND11_VERSION v2.11.1)
 
 if(NOT CMAKE_BUILD_TYPE)
   set(CMAKE_BUILD_TYPE Release)
 endif()
 
-set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
 find_package(Threads REQUIRED) # -pthread
diff --git a/conda-recipe.yaml b/conda-recipe.yaml
index 997f11c5..12b3a6d0 100644
--- a/conda-recipe.yaml
+++ b/conda-recipe.yaml
@@ -77,7 +77,6 @@ dependencies:
     - hunspell-en
     - myst-nb
     - ipykernel
-    - pandoc
     - docutils
 
   # Testing
diff --git a/docs/conda-recipe.yaml b/docs/conda-recipe.yaml
index 9a14af3f..30ec372e 100644
--- a/docs/conda-recipe.yaml
+++ b/docs/conda-recipe.yaml
@@ -67,5 +67,4 @@ dependencies:
   - hunspell-en
   - myst-nb
   - ipykernel
-  - pandoc
   - docutils
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 655c64ff..82bf2d91 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -4,17 +4,17 @@ torch >= 1.13
 
 --requirement ../requirements.txt
 
-sphinx >= 5.2.1
+sphinx >= 5.2.1, < 7.0.0a0
+sphinxcontrib-bibtex >= 2.4
+sphinx-autodoc-typehints >= 1.20
+myst-nb >= 0.15
+
 sphinx-autoapi
 sphinx-autobuild
 sphinx-copybutton
 sphinx-rtd-theme
 sphinxcontrib-katex
-sphinxcontrib-bibtex
-sphinx-autodoc-typehints >= 1.19.2
 IPython
 ipykernel
-pandoc
-myst-nb
 docutils
 matplotlib
diff --git a/docs/source/explicit_diff/explicit_diff.rst b/docs/source/explicit_diff/explicit_diff.rst
index 9445adb8..28e06f77 100644
--- a/docs/source/explicit_diff/explicit_diff.rst
+++ b/docs/source/explicit_diff/explicit_diff.rst
@@ -59,7 +59,7 @@ For PyTorch-like API (e.g., ``step()``), we designed a base class :class:`torcho
     torchopt.MetaAdagrad
     torchopt.MetaAdam
     torchopt.MetaAdamW
-    torchopt.AdaMax
+    torchopt.MetaAdaMax
     torchopt.MetaAdamax
     torchopt.MetaRAdam
     torchopt.MetaRMSProp
diff --git a/torchopt/alias/sgd.py b/torchopt/alias/sgd.py
index 6fb3c6db..4c5b8317 100644
--- a/torchopt/alias/sgd.py
+++ b/torchopt/alias/sgd.py
@@ -44,6 +44,7 @@
 __all__ = ['sgd']
 
 
+# pylint: disable-next=too-many-arguments
 def sgd(
     lr: ScalarOrSchedule,
     momentum: float = 0.0,
diff --git a/torchopt/alias/utils.py b/torchopt/alias/utils.py
index 5c8dc97a..2c2ec0e4 100644
--- a/torchopt/alias/utils.py
+++ b/torchopt/alias/utils.py
@@ -108,19 +108,21 @@ def update_fn(
         if inplace:
 
-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 if g.requires_grad:
                     return g.add_(p, alpha=weight_decay)
                 return g.add_(p.data, alpha=weight_decay)
 
-            updates = tree_map_(f, updates, params)
+            tree_map_(f, params, updates)
 
         else:
 
-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
-                return g.add(p, alpha=weight_decay)
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                return g.add(p, alpha=weight_decay) if g is not None else g
 
-            updates = tree_map(f, updates, params)
+            updates = tree_map(f, params, updates)
 
         return updates, state
@@ -139,7 +141,7 @@ def update_fn(
             def f(g: torch.Tensor) -> torch.Tensor:
                 return g.neg_()
 
-            updates = tree_map_(f, updates)
+            tree_map_(f, updates)
 
         else:
@@ -166,19 +168,21 @@ def update_fn(
         if inplace:
 
-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 if g.requires_grad:
                     return g.neg_().add_(p, alpha=weight_decay)
                 return g.neg_().add_(p.data, alpha=weight_decay)
 
-            updates = tree_map_(f, updates, params)
+            tree_map_(f, params, updates)
 
         else:
 
-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
-                return g.neg().add_(p, alpha=weight_decay)
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                return g.neg().add_(p, alpha=weight_decay) if g is not None else g
 
-            updates = tree_map(f, updates, params)
+            updates = tree_map(f, params, updates)
 
         return updates, state
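The argument reordering in the hunks above is the core of PR #195: with optree's multi-tree `tree_map`, the first tree drives the traversal, and a `None` sitting at a leaf position of any later tree is handed to the mapped function as-is. Gradients can legitimately be `None` (parameters that did not take part in the backward pass), so the `params`/moment tree now leads and `updates` trails. Note also that `tree_map_` applies the function for its in-place side effect and returns its *first* input tree, which after the reordering would be `params`; that is why the `updates = ...` assignments around it were dropped. A minimal sketch of the pattern (hypothetical trees, assuming the optree semantics just described):

    from __future__ import annotations

    import torch
    from optree import tree_map

    weight_decay = 1e-2
    params = {'w': torch.ones(3), 'b': torch.ones(3)}
    updates = {'w': torch.full((3,), 0.1), 'b': None}  # 'b' received no gradient

    def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
        # Decoupled weight decay; None gradients pass through untouched.
        return g.add(p, alpha=weight_decay) if g is not None else g

    new_updates = tree_map(f, params, updates)  # params (never None) drives traversal
    # new_updates['w'] == 0.11, new_updates['b'] is None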
diff --git a/torchopt/distributed/api.py b/torchopt/distributed/api.py
index 3a6f0526..fb7461e4 100644
--- a/torchopt/distributed/api.py
+++ b/torchopt/distributed/api.py
@@ -271,6 +271,7 @@ def sum_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor:
     return torch.sum(torch.stack(tuple(results), dim=0), dim=0)
 
 
+# pylint: disable-next=too-many-arguments
 def remote_async_call(
     func: Callable[..., T],
     *,
@@ -328,6 +329,7 @@ def remote_async_call(
     return future
 
 
+# pylint: disable-next=too-many-arguments
 def remote_sync_call(
     func: Callable[..., T],
     *,
diff --git a/torchopt/linalg/cg.py b/torchopt/linalg/cg.py
index 9cd57cd8..42cb6bea 100644
--- a/torchopt/linalg/cg.py
+++ b/torchopt/linalg/cg.py
@@ -53,7 +53,7 @@ def _identity(x: TensorTree) -> TensorTree:
     return x
 
 
-# pylint: disable-next=too-many-locals
+# pylint: disable-next=too-many-arguments,too-many-locals
 def _cg_solve(
     A: Callable[[TensorTree], TensorTree],
     b: TensorTree,
@@ -102,6 +102,7 @@ def body_fn(
     return x_final
 
 
+# pylint: disable-next=too-many-arguments
 def _isolve(
     _isolve_solve: Callable,
     A: TensorTree | Callable[[TensorTree], TensorTree],
@@ -134,6 +135,7 @@ def _isolve(
     return isolve_solve(A, b)
 
 
+# pylint: disable-next=too-many-arguments
 def cg(
     A: TensorTree | Callable[[TensorTree], TensorTree],
     b: TensorTree,
diff --git a/torchopt/linalg/ns.py b/torchopt/linalg/ns.py
index 747ad3cf..ce49fe77 100644
--- a/torchopt/linalg/ns.py
+++ b/torchopt/linalg/ns.py
@@ -123,12 +123,14 @@ def _ns_inv(A: torch.Tensor, maxiter: int, alpha: float | None = None) -> torch.Tensor:
         # A^{-1} = a [I - (I - a A)]^{-1} = a [I + (I - a A) + (I - a A)^2 + (I - a A)^3 + ...]
         M = I - alpha * A
         for rank in range(maxiter):
+            # pylint: disable-next=not-callable
             inv_A_hat = inv_A_hat + torch.linalg.matrix_power(M, rank)
         inv_A_hat = alpha * inv_A_hat
     else:
         # A^{-1} = [I - (I - A)]^{-1} = I + (I - A) + (I - A)^2 + (I - A)^3 + ...
         M = I - A
         for rank in range(maxiter):
+            # pylint: disable-next=not-callable
            inv_A_hat = inv_A_hat + torch.linalg.matrix_power(M, rank)
     return inv_A_hat
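For context, `_ns_inv` truncates the Neumann series written in the code comments above: A^{-1} = alpha * sum_{k>=0} (I - alpha*A)^k, which converges when every eigenvalue of I - alpha*A lies inside the unit circle. A standalone sketch of the `alpha` branch (hypothetical `neumann_inv` helper, not part of torchopt's public API):

    import torch

    def neumann_inv(A: torch.Tensor, maxiter: int, alpha: float) -> torch.Tensor:
        # A^{-1} ~= alpha * sum_{k=0}^{maxiter-1} (I - alpha * A)^k
        I = torch.eye(A.size(-1), dtype=A.dtype, device=A.device)
        M = I - alpha * A
        inv_A_hat = torch.zeros_like(A)
        for rank in range(maxiter):
            inv_A_hat = inv_A_hat + torch.linalg.matrix_power(M, rank)
        return alpha * inv_A_hat

    A = torch.tensor([[2.0, 0.0], [0.0, 4.0]])
    print(neumann_inv(A, maxiter=50, alpha=0.2))  # ~= [[0.5, 0.0], [0.0, 0.25]]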
diff --git a/torchopt/nn/stateless.py b/torchopt/nn/stateless.py
index e547b5cb..8268ca3f 100644
--- a/torchopt/nn/stateless.py
+++ b/torchopt/nn/stateless.py
@@ -66,8 +66,8 @@ def recursive_setattr(path: str, value: torch.Tensor) -> torch.Tensor:
             mod._parameters[attr] = value  # type: ignore[assignment]
         elif hasattr(mod, '_buffers') and attr in mod._buffers:
             mod._buffers[attr] = value
-        elif hasattr(mod, '_meta_parameters') and attr in mod._meta_parameters:  # type: ignore[operator]
-            mod._meta_parameters[attr] = value  # type: ignore[operator,index]
+        elif hasattr(mod, '_meta_parameters') and attr in mod._meta_parameters:
+            mod._meta_parameters[attr] = value
         else:
             setattr(mod, attr, value)
         # pylint: enable=protected-access
diff --git a/torchopt/transform/add_decayed_weights.py b/torchopt/transform/add_decayed_weights.py
index 04d564d7..39948694 100644
--- a/torchopt/transform/add_decayed_weights.py
+++ b/torchopt/transform/add_decayed_weights.py
@@ -226,19 +226,21 @@ def update_fn(
         if inplace:
 
-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 if g.requires_grad:
                     return g.add_(p, alpha=weight_decay)
                 return g.add_(p.data, alpha=weight_decay)
 
-            updates = tree_map_(f, updates, params)
+            tree_map_(f, params, updates)
 
         else:
 
-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
-                return g.add(p, alpha=weight_decay)
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                return g.add(p, alpha=weight_decay) if g is not None else g
 
-            updates = tree_map(f, updates, params)
+            updates = tree_map(f, params, updates)
 
         return updates, state
diff --git a/torchopt/transform/scale_by_adadelta.py b/torchopt/transform/scale_by_adadelta.py
index fb5431a3..bbe40080 100644
--- a/torchopt/transform/scale_by_adadelta.py
+++ b/torchopt/transform/scale_by_adadelta.py
@@ -129,23 +129,15 @@ def update_fn(
         if inplace:
 
-            def f(
-                g: torch.Tensor,  # pylint: disable=unused-argument
-                m: torch.Tensor,
-                v: torch.Tensor,
-            ) -> torch.Tensor:
-                return g.mul_(v.add(eps).div_(m.add(eps)).sqrt_())
+            def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                return g.mul_(v.add(eps).div_(m.add(eps)).sqrt_()) if g is not None else g
 
         else:
 
-            def f(
-                g: torch.Tensor,  # pylint: disable=unused-argument
-                m: torch.Tensor,
-                v: torch.Tensor,
-            ) -> torch.Tensor:
-                return g.mul(v.add(eps).div_(m.add(eps)).sqrt_())
+            def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                return g.mul(v.add(eps).div_(m.add(eps)).sqrt_()) if g is not None else g
 
-        updates = tree_map(f, updates, mu, state.nu)
+        updates = tree_map(f, mu, state.nu, updates)
 
         nu = update_moment.impl(  # type: ignore[attr-defined]
             updates,
diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py
index c3c6254e..cc0ea3b6 100644
--- a/torchopt/transform/scale_by_adam.py
+++ b/torchopt/transform/scale_by_adam.py
@@ -132,6 +132,7 @@ def _scale_by_adam_flat(
     )
 
 
+# pylint: disable-next=too-many-arguments
 def _scale_by_adam(
     b1: float = 0.9,
     b2: float = 0.999,
@@ -200,23 +201,15 @@ def update_fn(
         if inplace:
 
-            def f(
-                g: torch.Tensor,  # pylint: disable=unused-argument
-                m: torch.Tensor,
-                v: torch.Tensor,
-            ) -> torch.Tensor:
-                return m.div_(v.add_(eps_root).sqrt_().add(eps))
+            def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                return m.div_(v.add_(eps_root).sqrt_().add(eps)) if g is not None else g
 
         else:
 
-            def f(
-                g: torch.Tensor,  # pylint: disable=unused-argument
-                m: torch.Tensor,
-                v: torch.Tensor,
-            ) -> torch.Tensor:
-                return m.div(v.add(eps_root).sqrt_().add(eps))
+            def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                return m.div(v.add(eps_root).sqrt_().add(eps)) if g is not None else g
 
-        updates = tree_map(f, updates, mu_hat, nu_hat)
+        updates = tree_map(f, mu_hat, nu_hat, updates)
 
         return updates, ScaleByAdamState(mu=mu, nu=nu, count=count_inc)
 
     return GradientTransformation(init_fn, update_fn)
@@ -283,6 +276,7 @@ def _scale_by_accelerated_adam_flat(
     )
 
 
+# pylint: disable-next=too-many-arguments
 def _scale_by_accelerated_adam(
     b1: float = 0.9,
     b2: float = 0.999,
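The same convention appears here with three trees: the bias-corrected moment trees `mu_hat`/`nu_hat` lead the traversal, while the incoming `updates` tree trails and serves only as a `None` mask for parameters that received no gradient. A small sketch of the out-of-place branch (hypothetical leaf values, assuming the optree multi-tree semantics described earlier):

    from __future__ import annotations

    import torch
    from optree import tree_map

    eps, eps_root = 1e-8, 0.0
    mu_hat = {'w': torch.tensor([0.5]), 'b': torch.tensor([0.2])}
    nu_hat = {'w': torch.tensor([0.25]), 'b': torch.tensor([0.04])}
    updates = {'w': torch.tensor([1.0]), 'b': None}  # 'b' had no gradient

    def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
        # Adam direction m / (sqrt(v + eps_root) + eps), skipped where g is None.
        return m.div(v.add(eps_root).sqrt_().add(eps)) if g is not None else g

    new_updates = tree_map(f, mu_hat, nu_hat, updates)
    # new_updates['w'] ~ 0.5 / 0.5 = 1.0, new_updates['b'] is None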
diff --git a/torchopt/transform/scale_by_adamax.py b/torchopt/transform/scale_by_adamax.py
index 504e82cd..0a1c3ec9 100644
--- a/torchopt/transform/scale_by_adamax.py
+++ b/torchopt/transform/scale_by_adamax.py
@@ -137,23 +137,17 @@ def update_fn(
             already_flattened=already_flattened,
         )
 
-        def update_nu(
-            g: torch.Tensor,
-            n: torch.Tensor,
-        ) -> torch.Tensor:
-            return torch.max(n.mul(b2), g.abs().add_(eps))
+        def update_nu(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+            return torch.max(n.mul(b2), g.abs().add_(eps)) if g is not None else g
 
-        nu = tree_map(update_nu, updates, state.nu)
+        nu = tree_map(update_nu, state.nu, updates)
 
         one_minus_b1_pow_t = 1 - b1**state.t
 
-        def f(
-            n: torch.Tensor,
-            m: torch.Tensor,
-        ) -> torch.Tensor:
-            return m.div(n).div_(one_minus_b1_pow_t)
+        def f(m: torch.Tensor, n: torch.Tensor | None) -> torch.Tensor:
+            return m.div(n).div_(one_minus_b1_pow_t) if n is not None else m
 
-        updates = tree_map(f, nu, mu)
+        updates = tree_map(f, mu, nu)
 
         return updates, ScaleByAdamaxState(mu=mu, nu=nu, t=state.t + 1)
diff --git a/torchopt/transform/scale_by_rms.py b/torchopt/transform/scale_by_rms.py
index ac2fef16..084be839 100644
--- a/torchopt/transform/scale_by_rms.py
+++ b/torchopt/transform/scale_by_rms.py
@@ -136,17 +136,17 @@ def update_fn(
         if inplace:
 
-            def f(g: torch.Tensor, n: torch.Tensor) -> torch.Tensor:  # pylint: disable=invalid-name
-                return g.div_(n.sqrt().add_(eps))
+            def f(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                return g.div_(n.sqrt().add_(eps)) if g is not None else g
 
-            updates = tree_map_(f, updates, nu)
+            tree_map_(f, nu, updates)
 
         else:
 
-            def f(g: torch.Tensor, n: torch.Tensor) -> torch.Tensor:  # pylint: disable=invalid-name
-                return g.div(n.sqrt().add(eps))
+            def f(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                return g.div(n.sqrt().add(eps)) if g is not None else g
 
-            updates = tree_map(f, updates, nu)
+            updates = tree_map(f, nu, updates)
 
         return updates, ScaleByRmsState(nu=nu)
diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py
index 68021e5e..b1f3d2a8 100644
--- a/torchopt/transform/scale_by_rss.py
+++ b/torchopt/transform/scale_by_rss.py
@@ -128,23 +128,21 @@ def update_fn(
         if inplace:
 
-            def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor:
-                return torch.where(
-                    sos > 0.0,
-                    g.div_(sos.sqrt().add_(eps)),
-                    0.0,
+            def f(sos: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                return (
+                    torch.where(sos > 0.0, g.div_(sos.sqrt().add_(eps)), 0.0)
+                    if g is not None
+                    else g
                 )
 
         else:
 
-            def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor:
-                return torch.where(
-                    sos > 0.0,
-                    g.div(sos.sqrt().add(eps)),
-                    0.0,
+            def f(sos: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                return (
+                    torch.where(sos > 0.0, g.div(sos.sqrt().add(eps)), 0.0) if g is not None else g
                 )
 
-        updates = tree_map(f, updates, sum_of_squares)
+        updates = tree_map(f, sum_of_squares, updates)
 
         return updates, ScaleByRssState(sum_of_squares=sum_of_squares)
 
     return GradientTransformation(init_fn, update_fn)
diff --git a/torchopt/transform/scale_by_schedule.py b/torchopt/transform/scale_by_schedule.py
index f27fb7e8..749b1853 100644
--- a/torchopt/transform/scale_by_schedule.py
+++ b/torchopt/transform/scale_by_schedule.py
@@ -96,20 +96,24 @@ def update_fn(
         inplace: bool = True,
     ) -> tuple[Updates, OptState]:
         if inplace:
-
-            def f(g: torch.Tensor, c: Numeric) -> torch.Tensor:  # pylint: disable=invalid-name
+            # pylint: disable-next=invalid-name
+            def f(c: Numeric, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 step_size = step_size_fn(c)
                 return g.mul_(step_size)
 
-            updates = tree_map_(f, updates, state.count)
+            tree_map_(f, state.count, updates)
 
         else:
-
-            def f(g: torch.Tensor, c: Numeric) -> torch.Tensor:  # pylint: disable=invalid-name
+            # pylint: disable-next=invalid-name
+            def f(c: Numeric, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 step_size = step_size_fn(c)
                 return g.mul(step_size)
 
-            updates = tree_map(f, updates, state.count)
+            updates = tree_map(f, state.count, updates)
 
         return (
             updates,
diff --git a/torchopt/transform/scale_by_stddev.py b/torchopt/transform/scale_by_stddev.py
index bbbfb384..d9589c45 100644
--- a/torchopt/transform/scale_by_stddev.py
+++ b/torchopt/transform/scale_by_stddev.py
@@ -148,17 +148,17 @@ def update_fn(
         if inplace:
 
-            def f(g: torch.Tensor, m: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
-                return g.div_(n.addcmul(m, m, value=-1.0).sqrt_().add(eps))
+            def f(m: torch.Tensor, n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                return g.div_(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) if g is not None else g
 
-            updates = tree_map_(f, updates, mu, nu)
+            tree_map_(f, mu, nu, updates)
 
         else:
 
-            def f(g: torch.Tensor, m: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
-                return g.div(n.addcmul(m, m, value=-1.0).sqrt_().add(eps))
+            def f(m: torch.Tensor, n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                return g.div(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) if g is not None else g
 
-            updates = tree_map(f, updates, mu, nu)
+            updates = tree_map(f, mu, nu, updates)
 
         return updates, ScaleByRStdDevState(mu=mu, nu=nu)
diff --git a/torchopt/transform/trace.py b/torchopt/transform/trace.py
index 7a1e1971..d530a676 100644
--- a/torchopt/transform/trace.py
+++ b/torchopt/transform/trace.py
@@ -148,52 +148,60 @@ def update_fn(
         if nesterov:
             if inplace:
 
-                def f1(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+                def f1(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                    if g is None:
+                        return g
                     if first_call:
                         return t.add_(g)
                     return t.mul_(momentum).add_(g)
 
-                def f2(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
-                    return g.add_(t, alpha=momentum)
+                def f2(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                    return g.add_(t, alpha=momentum) if g is not None else g
 
-                new_trace = tree_map(f1, updates, state.trace)
-                updates = tree_map_(f2, updates, new_trace)
+                new_trace = tree_map(f1, state.trace, updates)
+                tree_map_(f2, new_trace, updates)
 
             else:
 
-                def f1(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+                def f1(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                    if g is None:
+                        return g
                     if first_call:
                         return t.add(g)
                     return t.mul(momentum).add_(g)
 
-                def f2(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
-                    return g.add(t, alpha=momentum)
+                def f2(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                    return g.add(t, alpha=momentum) if g is not None else g
 
-                new_trace = tree_map(f1, updates, state.trace)
-                updates = tree_map(f2, updates, new_trace)
+                new_trace = tree_map(f1, state.trace, updates)
+                updates = tree_map(f2, new_trace, updates)
 
         else:
             if inplace:
 
-                def f(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+                def f(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                    if g is None:
+                        return g
                     if first_call:
                         return t.add_(g)
                     return t.mul_(momentum).add_(g, alpha=1.0 - dampening)
 
-                def copy_(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
-                    return g.copy_(t)
+                def copy_to_(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                    return g.copy_(t) if g is not None else g
 
-                new_trace = tree_map(f, updates, state.trace)
-                updates = tree_map_(copy_, updates, new_trace)
+                new_trace = tree_map(f, state.trace, updates)
+                tree_map_(copy_to_, new_trace, updates)
 
             else:
 
-                def f(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+                def f(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                    if g is None:
+                        return g
                     if first_call:
                         return t.add(g)
                     return t.mul(momentum).add_(g, alpha=1.0 - dampening)
 
-                new_trace = tree_map(f, updates, state.trace)
+                new_trace = tree_map(f, state.trace, updates)
                 updates = tree_map(torch.clone, new_trace)
 
         first_call = False
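Written out, the momentum trace recurrence implemented above (which mirrors `torch.optim.SGD`'s momentum/dampening/nesterov convention) is, for trace t, gradient g, and step k:

    t_1 = g_1
    t_k = momentum * t_{k-1} + (1 - dampening) * g_k   (no dampening on the Nesterov path)
    update_k = g_k + momentum * t_k   if nesterov
    update_k = t_k                    otherwise

In the in-place branches, `f2` and `copy_to_` write the result back into the gradient tensors themselves, so `tree_map_` is invoked purely for its side effect.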
diff --git a/torchopt/transform/utils.py b/torchopt/transform/utils.py
index 8c67fd7e..f1ed39da 100644
--- a/torchopt/transform/utils.py
+++ b/torchopt/transform/utils.py
@@ -160,6 +160,7 @@ def _update_moment_flat(
     )
 
 
+# pylint: disable-next=too-many-arguments
 def _update_moment(
     updates: Updates,
     moments: TensorTree,
diff --git a/torchopt/utils.py b/torchopt/utils.py
index 5414db80..ef771966 100644
--- a/torchopt/utils.py
+++ b/torchopt/utils.py
@@ -91,7 +91,7 @@ def fn_(obj: Any) -> None:
 
 
 @overload
-def extract_state_dict(
+def extract_state_dict(  # pylint: disable=too-many-arguments
     target: nn.Module,
     *,
     by: CopyMode = 'reference',
@@ -114,7 +114,7 @@ def extract_state_dict(
     ...
 
 
-# pylint: disable-next=too-many-branches,too-many-locals
+# pylint: disable-next=too-many-arguments,too-many-branches,too-many-locals
 def extract_state_dict(
     target: nn.Module | MetaOptimizer,
     *,
diff --git a/torchopt/version.py b/torchopt/version.py
index 87c4fe49..685735e6 100644
--- a/torchopt/version.py
+++ b/torchopt/version.py
@@ -14,7 +14,7 @@
 # ==============================================================================
 """TorchOpt: a high-performance optimizer library built upon PyTorch."""
 
-__version__ = '0.7.2'
+__version__ = '0.7.3'
 __license__ = 'Apache License, Version 2.0'
 __author__ = 'TorchOpt Contributors'
 __release__ = False