From a0cfcdfbdcd14b217515b088b8c0d5936cbe2559 Mon Sep 17 00:00:00 2001 From: Bo Liu Date: Sun, 20 Aug 2023 21:56:43 +0800 Subject: [PATCH 01/26] docs: fix typo in explicit differentiation (#188) --- docs/source/explicit_diff/explicit_diff.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/explicit_diff/explicit_diff.rst b/docs/source/explicit_diff/explicit_diff.rst index 9445adb8..28e06f77 100644 --- a/docs/source/explicit_diff/explicit_diff.rst +++ b/docs/source/explicit_diff/explicit_diff.rst @@ -59,7 +59,7 @@ For PyTorch-like API (e.g., ``step()``), we designed a base class :class:`torcho torchopt.MetaAdagrad torchopt.MetaAdam torchopt.MetaAdamW - torchopt.AdaMax + torchopt.MetaAdaMax torchopt.MetaAdamax torchopt.MetaRAdam torchopt.MetaRMSProp From 91b85ae91ead9349b33260ae1f564d3cfca7457e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 5 Sep 2023 16:48:08 +0800 Subject: [PATCH 02/26] chore(pre-commit): [pre-commit.ci] autoupdate (#190) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e5c37d40..4d2ef96d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,7 +30,7 @@ repos: hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.284 + rev: v0.0.287 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] From 2117c18d4b86dc7dbcb97a23dae6010265698401 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 11 Sep 2023 14:42:52 +0800 Subject: [PATCH 03/26] deps(workflows): bump actions/checkout from 3 to 4 (#192) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/build.yml | 8 ++++---- .github/workflows/lint.yml | 2 +- .github/workflows/tests.yml | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0a6c4d6e..4f15d1a4 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -48,7 +48,7 @@ jobs: timeout-minutes: 60 steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: submodules: "recursive" fetch-depth: 0 @@ -108,7 +108,7 @@ jobs: timeout-minutes: 60 steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: submodules: "recursive" fetch-depth: 0 @@ -158,7 +158,7 @@ jobs: timeout-minutes: 60 steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: submodules: "recursive" fetch-depth: 0 @@ -206,7 +206,7 @@ jobs: timeout-minutes: 15 steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: submodules: "recursive" fetch-depth: 0 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index b338b149..80ed2876 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -24,7 +24,7 @@ jobs: timeout-minutes: 30 steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: submodules: "recursive" fetch-depth: 1 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4f6fad50..6a8f7ec8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -36,7 +36,7 @@ jobs: timeout-minutes: 60 steps: - name: Checkout - uses: actions/checkout@v3 + uses: 
actions/checkout@v4 with: submodules: "recursive" fetch-depth: 1 @@ -106,7 +106,7 @@ jobs: fail-fast: false steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: submodules: "recursive" fetch-depth: 1 From 86b167c054b158c8d49c262b30c430838a06b794 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 25 Sep 2023 14:55:52 +0800 Subject: [PATCH 04/26] deps(workflows): bump pypa/cibuildwheel from 2.15 to 2.16 (#193) Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 2.15 to 2.16. - [Release notes](https://github.com/pypa/cibuildwheel/releases) - [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md) - [Commits](https://github.com/pypa/cibuildwheel/compare/v2.15...v2.16) --- updated-dependencies: - dependency-name: pypa/cibuildwheel dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/build.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4f15d1a4..bed635e3 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -132,7 +132,7 @@ jobs: run: python .github/workflows/set_cibw_build.py - name: Build wheels - uses: pypa/cibuildwheel@v2.15 + uses: pypa/cibuildwheel@v2.16 env: CIBW_BUILD: ${{ env.CIBW_BUILD }} with: @@ -182,7 +182,7 @@ jobs: run: python .github/workflows/set_cibw_build.py - name: Build wheels - uses: pypa/cibuildwheel@v2.15 + uses: pypa/cibuildwheel@v2.16 env: CIBW_BUILD: ${{ env.CIBW_BUILD }} with: From 93cc7ec81040dbc0977cecaa9d55f11fd8c32dcf Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 9 Nov 2023 19:25:45 +0800 Subject: [PATCH 05/26] fix: fix `optree` compatibility for multi-tree-map with `None` values (#195) --- .github/workflows/lint.yml | 4 +-- .pre-commit-config.yaml | 12 +++---- CHANGELOG.md | 4 +-- CMakeLists.txt | 4 +-- conda-recipe.yaml | 1 - docs/conda-recipe.yaml | 1 - docs/requirements.txt | 10 +++--- torchopt/alias/sgd.py | 1 + torchopt/alias/utils.py | 26 ++++++++------ torchopt/distributed/api.py | 2 ++ torchopt/linalg/cg.py | 4 ++- torchopt/linalg/ns.py | 2 ++ torchopt/nn/stateless.py | 4 +-- torchopt/transform/add_decayed_weights.py | 12 ++++--- torchopt/transform/scale_by_adadelta.py | 18 +++------- torchopt/transform/scale_by_adam.py | 20 ++++------- torchopt/transform/scale_by_adamax.py | 18 ++++------ torchopt/transform/scale_by_rms.py | 12 +++---- torchopt/transform/scale_by_rss.py | 20 +++++------ torchopt/transform/scale_by_schedule.py | 16 +++++---- torchopt/transform/scale_by_stddev.py | 12 +++---- torchopt/transform/trace.py | 42 ++++++++++++++--------- torchopt/transform/utils.py | 1 + torchopt/utils.py | 4 +-- 24 files changed, 126 insertions(+), 124 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 80ed2876..89c26c3c 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -29,10 +29,10 @@ jobs: submodules: "recursive" fetch-depth: 1 - - name: Set up Python 3.8 + - name: Set up Python 3.9 uses: actions/setup-python@v4 with: - python-version: "3.8" + python-version: "3.9" update-environment: true - name: Setup CUDA Toolkit diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4d2ef96d..975bb7b6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ 
ci: default_stages: [commit, push, manual] repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: check-symlinks - id: destroyed-symlinks @@ -26,11 +26,11 @@ repos: - id: debug-statements - id: double-quote-string-fixer - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v16.0.6 + rev: v17.0.4 hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.287 + rev: v0.1.5 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] @@ -39,11 +39,11 @@ repos: hooks: - id: isort - repo: https://github.com/psf/black - rev: 23.7.0 + rev: 23.11.0 hooks: - id: black-jupyter - repo: https://github.com/asottile/pyupgrade - rev: v3.10.1 + rev: v3.15.0 hooks: - id: pyupgrade args: [--py38-plus] # sync with requires-python @@ -68,7 +68,7 @@ repos: ^docs/source/conf.py$ ) - repo: https://github.com/codespell-project/codespell - rev: v2.2.5 + rev: v2.2.6 hooks: - id: codespell additional_dependencies: [".[toml]"] diff --git a/CHANGELOG.md b/CHANGELOG.md index 24e4eea4..797b6ca0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,11 +17,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- +- Set minimal C++ standard to C++17 by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195). ### Fixed -- +- Fix `optree` compatibility for multi-tree-map with `None` values by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195). ### Removed diff --git a/CMakeLists.txt b/CMakeLists.txt index 3b091a22..eca14815 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,13 +17,13 @@ cmake_minimum_required(VERSION 3.11) # for FetchContent project(torchopt LANGUAGES CXX) include(FetchContent) -set(PYBIND11_VERSION v2.10.3) +set(PYBIND11_VERSION v2.11.1) if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) endif() -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) find_package(Threads REQUIRED) # -pthread diff --git a/conda-recipe.yaml b/conda-recipe.yaml index 997f11c5..12b3a6d0 100644 --- a/conda-recipe.yaml +++ b/conda-recipe.yaml @@ -77,7 +77,6 @@ dependencies: - hunspell-en - myst-nb - ipykernel - - pandoc - docutils # Testing diff --git a/docs/conda-recipe.yaml b/docs/conda-recipe.yaml index 9a14af3f..30ec372e 100644 --- a/docs/conda-recipe.yaml +++ b/docs/conda-recipe.yaml @@ -67,5 +67,4 @@ dependencies: - hunspell-en - myst-nb - ipykernel - - pandoc - docutils diff --git a/docs/requirements.txt b/docs/requirements.txt index 655c64ff..82bf2d91 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,17 +4,17 @@ torch >= 1.13 --requirement ../requirements.txt -sphinx >= 5.2.1 +sphinx >= 5.2.1, < 7.0.0a0 +sphinxcontrib-bibtex >= 2.4 +sphinx-autodoc-typehints >= 1.20 +myst-nb >= 0.15 + sphinx-autoapi sphinx-autobuild sphinx-copybutton sphinx-rtd-theme sphinxcontrib-katex -sphinxcontrib-bibtex -sphinx-autodoc-typehints >= 1.19.2 IPython ipykernel -pandoc -myst-nb docutils matplotlib diff --git a/torchopt/alias/sgd.py b/torchopt/alias/sgd.py index 6fb3c6db..4c5b8317 100644 --- a/torchopt/alias/sgd.py +++ b/torchopt/alias/sgd.py @@ -44,6 +44,7 @@ __all__ = ['sgd'] +# pylint: disable-next=too-many-arguments def sgd( lr: ScalarOrSchedule, momentum: float = 0.0, diff --git a/torchopt/alias/utils.py b/torchopt/alias/utils.py index 5c8dc97a..2c2ec0e4 100644 --- a/torchopt/alias/utils.py +++ b/torchopt/alias/utils.py @@ -108,19 +108,21 @@ def update_fn( 
if inplace: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if g.requires_grad: return g.add_(p, alpha=weight_decay) return g.add_(p.data, alpha=weight_decay) - updates = tree_map_(f, updates, params) + tree_map_(f, params, updates) else: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: - return g.add(p, alpha=weight_decay) + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.add(p, alpha=weight_decay) if g is not None else g - updates = tree_map(f, updates, params) + updates = tree_map(f, params, updates) return updates, state @@ -139,7 +141,7 @@ def update_fn( def f(g: torch.Tensor) -> torch.Tensor: return g.neg_() - updates = tree_map_(f, updates) + tree_map_(f, updates) else: @@ -166,19 +168,21 @@ def update_fn( if inplace: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if g.requires_grad: return g.neg_().add_(p, alpha=weight_decay) return g.neg_().add_(p.data, alpha=weight_decay) - updates = tree_map_(f, updates, params) + tree_map_(f, params, updates) else: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: - return g.neg().add_(p, alpha=weight_decay) + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.neg().add_(p, alpha=weight_decay) if g is not None else g - updates = tree_map(f, updates, params) + updates = tree_map(f, params, updates) return updates, state diff --git a/torchopt/distributed/api.py b/torchopt/distributed/api.py index 3a6f0526..fb7461e4 100644 --- a/torchopt/distributed/api.py +++ b/torchopt/distributed/api.py @@ -271,6 +271,7 @@ def sum_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor: return torch.sum(torch.stack(tuple(results), dim=0), dim=0) +# pylint: disable-next=too-many-arguments def remote_async_call( func: Callable[..., T], *, @@ -328,6 +329,7 @@ def remote_async_call( return future +# pylint: disable-next=too-many-arguments def remote_sync_call( func: Callable[..., T], *, diff --git a/torchopt/linalg/cg.py b/torchopt/linalg/cg.py index 9cd57cd8..42cb6bea 100644 --- a/torchopt/linalg/cg.py +++ b/torchopt/linalg/cg.py @@ -53,7 +53,7 @@ def _identity(x: TensorTree) -> TensorTree: return x -# pylint: disable-next=too-many-locals +# pylint: disable-next=too-many-arguments,too-many-locals def _cg_solve( A: Callable[[TensorTree], TensorTree], b: TensorTree, @@ -102,6 +102,7 @@ def body_fn( return x_final +# pylint: disable-next=too-many-arguments def _isolve( _isolve_solve: Callable, A: TensorTree | Callable[[TensorTree], TensorTree], @@ -134,6 +135,7 @@ def _isolve( return isolve_solve(A, b) +# pylint: disable-next=too-many-arguments def cg( A: TensorTree | Callable[[TensorTree], TensorTree], b: TensorTree, diff --git a/torchopt/linalg/ns.py b/torchopt/linalg/ns.py index 747ad3cf..ce49fe77 100644 --- a/torchopt/linalg/ns.py +++ b/torchopt/linalg/ns.py @@ -123,12 +123,14 @@ def _ns_inv(A: torch.Tensor, maxiter: int, alpha: float | None = None) -> torch. # A^{-1} = a [I - (I - a A)]^{-1} = a [I + (I - a A) + (I - a A)^2 + (I - a A)^3 + ...] M = I - alpha * A for rank in range(maxiter): + # pylint: disable-next=not-callable inv_A_hat = inv_A_hat + torch.linalg.matrix_power(M, rank) inv_A_hat = alpha * inv_A_hat else: # A^{-1} = [I - (I - A)]^{-1} = I + (I - A) + (I - A)^2 + (I - A)^3 + ... 
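# Note (added for clarity, grounded in the surrounding code): in both branches the loop
# below adds the Neumann-series terms I, M, M^2, ..., M^{maxiter - 1} to ``inv_A_hat``
# via ``torch.linalg.matrix_power``; the truncated series approximates (I - M)^{-1},
# which converges when the spectral radius of M is below 1.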
M = I - A for rank in range(maxiter): + # pylint: disable-next=not-callable inv_A_hat = inv_A_hat + torch.linalg.matrix_power(M, rank) return inv_A_hat diff --git a/torchopt/nn/stateless.py b/torchopt/nn/stateless.py index e547b5cb..8268ca3f 100644 --- a/torchopt/nn/stateless.py +++ b/torchopt/nn/stateless.py @@ -66,8 +66,8 @@ def recursive_setattr(path: str, value: torch.Tensor) -> torch.Tensor: mod._parameters[attr] = value # type: ignore[assignment] elif hasattr(mod, '_buffers') and attr in mod._buffers: mod._buffers[attr] = value - elif hasattr(mod, '_meta_parameters') and attr in mod._meta_parameters: # type: ignore[operator] - mod._meta_parameters[attr] = value # type: ignore[operator,index] + elif hasattr(mod, '_meta_parameters') and attr in mod._meta_parameters: + mod._meta_parameters[attr] = value else: setattr(mod, attr, value) # pylint: enable=protected-access diff --git a/torchopt/transform/add_decayed_weights.py b/torchopt/transform/add_decayed_weights.py index 04d564d7..39948694 100644 --- a/torchopt/transform/add_decayed_weights.py +++ b/torchopt/transform/add_decayed_weights.py @@ -226,19 +226,21 @@ def update_fn( if inplace: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if g.requires_grad: return g.add_(p, alpha=weight_decay) return g.add_(p.data, alpha=weight_decay) - updates = tree_map_(f, updates, params) + tree_map_(f, params, updates) else: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: - return g.add(p, alpha=weight_decay) + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.add(p, alpha=weight_decay) if g is not None else g - updates = tree_map(f, updates, params) + updates = tree_map(f, params, updates) return updates, state diff --git a/torchopt/transform/scale_by_adadelta.py b/torchopt/transform/scale_by_adadelta.py index fb5431a3..bbe40080 100644 --- a/torchopt/transform/scale_by_adadelta.py +++ b/torchopt/transform/scale_by_adadelta.py @@ -129,23 +129,15 @@ def update_fn( if inplace: - def f( - g: torch.Tensor, # pylint: disable=unused-argument - m: torch.Tensor, - v: torch.Tensor, - ) -> torch.Tensor: - return g.mul_(v.add(eps).div_(m.add(eps)).sqrt_()) + def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.mul_(v.add(eps).div_(m.add(eps)).sqrt_()) if g is not None else g else: - def f( - g: torch.Tensor, # pylint: disable=unused-argument - m: torch.Tensor, - v: torch.Tensor, - ) -> torch.Tensor: - return g.mul(v.add(eps).div_(m.add(eps)).sqrt_()) + def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.mul(v.add(eps).div_(m.add(eps)).sqrt_()) if g is not None else g - updates = tree_map(f, updates, mu, state.nu) + updates = tree_map(f, mu, state.nu, updates) nu = update_moment.impl( # type: ignore[attr-defined] updates, diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py index c3c6254e..cc0ea3b6 100644 --- a/torchopt/transform/scale_by_adam.py +++ b/torchopt/transform/scale_by_adam.py @@ -132,6 +132,7 @@ def _scale_by_adam_flat( ) +# pylint: disable-next=too-many-arguments def _scale_by_adam( b1: float = 0.9, b2: float = 0.999, @@ -200,23 +201,15 @@ def update_fn( if inplace: - def f( - g: torch.Tensor, # pylint: disable=unused-argument - m: torch.Tensor, - v: torch.Tensor, - ) -> torch.Tensor: - return m.div_(v.add_(eps_root).sqrt_().add(eps)) + def f(m: torch.Tensor, v: 
torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return m.div_(v.add_(eps_root).sqrt_().add(eps)) if g is not None else g else: - def f( - g: torch.Tensor, # pylint: disable=unused-argument - m: torch.Tensor, - v: torch.Tensor, - ) -> torch.Tensor: - return m.div(v.add(eps_root).sqrt_().add(eps)) + def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return m.div(v.add(eps_root).sqrt_().add(eps)) if g is not None else g - updates = tree_map(f, updates, mu_hat, nu_hat) + updates = tree_map(f, mu_hat, nu_hat, updates) return updates, ScaleByAdamState(mu=mu, nu=nu, count=count_inc) return GradientTransformation(init_fn, update_fn) @@ -283,6 +276,7 @@ def _scale_by_accelerated_adam_flat( ) +# pylint: disable-next=too-many-arguments def _scale_by_accelerated_adam( b1: float = 0.9, b2: float = 0.999, diff --git a/torchopt/transform/scale_by_adamax.py b/torchopt/transform/scale_by_adamax.py index 504e82cd..0a1c3ec9 100644 --- a/torchopt/transform/scale_by_adamax.py +++ b/torchopt/transform/scale_by_adamax.py @@ -137,23 +137,17 @@ def update_fn( already_flattened=already_flattened, ) - def update_nu( - g: torch.Tensor, - n: torch.Tensor, - ) -> torch.Tensor: - return torch.max(n.mul(b2), g.abs().add_(eps)) + def update_nu(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return torch.max(n.mul(b2), g.abs().add_(eps)) if g is not None else g - nu = tree_map(update_nu, updates, state.nu) + nu = tree_map(update_nu, state.nu, updates) one_minus_b1_pow_t = 1 - b1**state.t - def f( - n: torch.Tensor, - m: torch.Tensor, - ) -> torch.Tensor: - return m.div(n).div_(one_minus_b1_pow_t) + def f(m: torch.Tensor, n: torch.Tensor | None) -> torch.Tensor: + return m.div(n).div_(one_minus_b1_pow_t) if n is not None else m - updates = tree_map(f, nu, mu) + updates = tree_map(f, mu, nu) return updates, ScaleByAdamaxState(mu=mu, nu=nu, t=state.t + 1) diff --git a/torchopt/transform/scale_by_rms.py b/torchopt/transform/scale_by_rms.py index ac2fef16..084be839 100644 --- a/torchopt/transform/scale_by_rms.py +++ b/torchopt/transform/scale_by_rms.py @@ -136,17 +136,17 @@ def update_fn( if inplace: - def f(g: torch.Tensor, n: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name - return g.div_(n.sqrt().add_(eps)) + def f(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.div_(n.sqrt().add_(eps)) if g is not None else g - updates = tree_map_(f, updates, nu) + tree_map_(f, nu, updates) else: - def f(g: torch.Tensor, n: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name - return g.div(n.sqrt().add(eps)) + def f(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.div(n.sqrt().add(eps)) if g is not None else g - updates = tree_map(f, updates, nu) + updates = tree_map(f, nu, updates) return updates, ScaleByRmsState(nu=nu) diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index 68021e5e..b1f3d2a8 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -128,23 +128,21 @@ def update_fn( if inplace: - def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor: - return torch.where( - sos > 0.0, - g.div_(sos.sqrt().add_(eps)), - 0.0, + def f(sos: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return ( + torch.where(sos > 0.0, g.div_(sos.sqrt().add_(eps)), 0.0) + if g is not None + else g ) else: - def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor: - return torch.where( - sos > 0.0, - 
g.div(sos.sqrt().add(eps)), - 0.0, + def f(sos: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return ( + torch.where(sos > 0.0, g.div(sos.sqrt().add(eps)), 0.0) if g is not None else g ) - updates = tree_map(f, updates, sum_of_squares) + updates = tree_map(f, sum_of_squares, updates) return updates, ScaleByRssState(sum_of_squares=sum_of_squares) return GradientTransformation(init_fn, update_fn) diff --git a/torchopt/transform/scale_by_schedule.py b/torchopt/transform/scale_by_schedule.py index f27fb7e8..749b1853 100644 --- a/torchopt/transform/scale_by_schedule.py +++ b/torchopt/transform/scale_by_schedule.py @@ -96,20 +96,24 @@ def update_fn( inplace: bool = True, ) -> tuple[Updates, OptState]: if inplace: - - def f(g: torch.Tensor, c: Numeric) -> torch.Tensor: # pylint: disable=invalid-name + # pylint: disable-next=invalid-name + def f(c: Numeric, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g step_size = step_size_fn(c) return g.mul_(step_size) - updates = tree_map_(f, updates, state.count) + tree_map_(f, state.count, updates) else: - - def f(g: torch.Tensor, c: Numeric) -> torch.Tensor: # pylint: disable=invalid-name + # pylint: disable-next=invalid-name + def f(c: Numeric, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g step_size = step_size_fn(c) return g.mul(step_size) - updates = tree_map(f, updates, state.count) + updates = tree_map(f, state.count, updates) return ( updates, diff --git a/torchopt/transform/scale_by_stddev.py b/torchopt/transform/scale_by_stddev.py index bbbfb384..d9589c45 100644 --- a/torchopt/transform/scale_by_stddev.py +++ b/torchopt/transform/scale_by_stddev.py @@ -148,17 +148,17 @@ def update_fn( if inplace: - def f(g: torch.Tensor, m: torch.Tensor, n: torch.Tensor) -> torch.Tensor: - return g.div_(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) + def f(m: torch.Tensor, n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.div_(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) if g is not None else g - updates = tree_map_(f, updates, mu, nu) + tree_map_(f, mu, nu, updates) else: - def f(g: torch.Tensor, m: torch.Tensor, n: torch.Tensor) -> torch.Tensor: - return g.div(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) + def f(m: torch.Tensor, n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.div(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) if g is not None else g - updates = tree_map(f, updates, mu, nu) + updates = tree_map(f, mu, nu, updates) return updates, ScaleByRStdDevState(mu=mu, nu=nu) diff --git a/torchopt/transform/trace.py b/torchopt/transform/trace.py index 7a1e1971..d530a676 100644 --- a/torchopt/transform/trace.py +++ b/torchopt/transform/trace.py @@ -148,52 +148,60 @@ def update_fn( if nesterov: if inplace: - def f1(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + def f1(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if first_call: return t.add_(g) return t.mul_(momentum).add_(g) - def f2(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: - return g.add_(t, alpha=momentum) + def f2(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.add_(t, alpha=momentum) if g is not None else g - new_trace = tree_map(f1, updates, state.trace) - updates = tree_map_(f2, updates, new_trace) + new_trace = tree_map(f1, state.trace, updates) + tree_map_(f2, new_trace, updates) else: - def f1(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + def f1(t: torch.Tensor, g: torch.Tensor | 
None) -> torch.Tensor | None: + if g is None: + return g if first_call: return t.add(g) return t.mul(momentum).add_(g) - def f2(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: - return g.add(t, alpha=momentum) + def f2(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.add(t, alpha=momentum) if g is not None else g - new_trace = tree_map(f1, updates, state.trace) - updates = tree_map(f2, updates, new_trace) + new_trace = tree_map(f1, state.trace, updates) + updates = tree_map(f2, new_trace, updates) else: if inplace: - def f(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + def f(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if first_call: return t.add_(g) return t.mul_(momentum).add_(g, alpha=1.0 - dampening) - def copy_(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: - return g.copy_(t) + def copy_to_(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.copy_(t) if g is not None else g - new_trace = tree_map(f, updates, state.trace) - updates = tree_map_(copy_, updates, new_trace) + new_trace = tree_map(f, state.trace, updates) + tree_map_(copy_to_, new_trace, updates) else: - def f(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + def f(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if first_call: return t.add(g) return t.mul(momentum).add_(g, alpha=1.0 - dampening) - new_trace = tree_map(f, updates, state.trace) + new_trace = tree_map(f, state.trace, updates) updates = tree_map(torch.clone, new_trace) first_call = False diff --git a/torchopt/transform/utils.py b/torchopt/transform/utils.py index 8c67fd7e..f1ed39da 100644 --- a/torchopt/transform/utils.py +++ b/torchopt/transform/utils.py @@ -160,6 +160,7 @@ def _update_moment_flat( ) +# pylint: disable-next=too-many-arguments def _update_moment( updates: Updates, moments: TensorTree, diff --git a/torchopt/utils.py b/torchopt/utils.py index 5414db80..ef771966 100644 --- a/torchopt/utils.py +++ b/torchopt/utils.py @@ -91,7 +91,7 @@ def fn_(obj: Any) -> None: @overload -def extract_state_dict( +def extract_state_dict( # pylint: disable=too-many-arguments target: nn.Module, *, by: CopyMode = 'reference', @@ -114,7 +114,7 @@ def extract_state_dict( ... -# pylint: disable-next=too-many-branches,too-many-locals +# pylint: disable-next=too-many-arguments,too-many-branches,too-many-locals def extract_state_dict( target: nn.Module | MetaOptimizer, *, From 26e51d430cddb855126c1518450aa8181c8038e8 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 10 Nov 2023 10:53:40 +0800 Subject: [PATCH 06/26] ver: bump version to 0.7.3 --- CHANGELOG.md | 19 ++++++++++++++++--- CITATION.cff | 4 ++-- torchopt/version.py | 2 +- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 797b6ca0..315d24db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,11 +17,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- Set minimal C++ standard to C++17 by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195). +- ### Fixed -- Fix `optree` compatibility for multi-tree-map with `None` values by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195). 
+- ### Removed @@ -29,6 +29,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ------ +## [0.7.3] - 2023-11-10 + +### Changed + +- Set minimal C++ standard to C++17 by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195). + +### Fixed + +- Fix `optree` compatibility for multi-tree-map with `None` values by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195). + +------ + ## [0.7.2] - 2023-08-18 ### Added @@ -195,7 +207,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ------ -[Unreleased]: https://github.com/metaopt/torchopt/compare/v0.7.2...HEAD +[Unreleased]: https://github.com/metaopt/torchopt/compare/v0.7.3...HEAD +[0.7.3]: https://github.com/metaopt/torchopt/compare/v0.7.2...v0.7.3 [0.7.2]: https://github.com/metaopt/torchopt/compare/v0.7.1...v0.7.2 [0.7.1]: https://github.com/metaopt/torchopt/compare/v0.7.0...v0.7.1 [0.7.0]: https://github.com/metaopt/torchopt/compare/v0.6.0...v0.7.0 diff --git a/CITATION.cff b/CITATION.cff index 965b6a7f..3c6098bf 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -32,7 +32,7 @@ authors: family-names: Yang affiliation: Peking University email: yaodong.yang@pku.edu.cn -version: 0.7.2 -date-released: "2023-08-18" +version: 0.7.3 +date-released: "2023-11-10" license: Apache-2.0 repository-code: "https://github.com/metaopt/torchopt" diff --git a/torchopt/version.py b/torchopt/version.py index 87c4fe49..685735e6 100644 --- a/torchopt/version.py +++ b/torchopt/version.py @@ -14,7 +14,7 @@ # ============================================================================== """TorchOpt: a high-performance optimizer library built upon PyTorch.""" -__version__ = '0.7.2' +__version__ = '0.7.3' __license__ = 'Apache License, Version 2.0' __author__ = 'TorchOpt Contributors' __release__ = False From 4fd9d290fbe0142efeea84cc4bd808072a4f0466 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 10 Dec 2023 11:45:35 +0800 Subject: [PATCH 07/26] chore(pre-commit): [pre-commit.ci] autoupdate (#196) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xuehai Pan --- .pre-commit-config.yaml | 4 ++-- pyproject.toml | 1 + torchopt/diff/implicit/decorator.py | 2 +- torchopt/distributed/api.py | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 975bb7b6..3ca22436 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,11 +26,11 @@ repos: - id: debug-statements - id: double-quote-string-fixer - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v17.0.4 + rev: v17.0.6 hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.5 + rev: v0.1.6 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/pyproject.toml b/pyproject.toml index 47424855..31d20f88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -314,4 +314,5 @@ filterwarnings = [ 'ignore:Explicitly requested dtype float64 requested in .* is not available, and will be truncated to dtype float32\.:UserWarning', 'ignore:jax\.numpy\.DeviceArray is deprecated\. 
Use jax\.Array\.:DeprecationWarning', 'ignore:.*functorch.*deprecate.*:UserWarning', + 'ignore:.*Apple Paravirtual device.*:UserWarning', ] diff --git a/torchopt/diff/implicit/decorator.py b/torchopt/diff/implicit/decorator.py index 03720a49..1fc9bb4a 100644 --- a/torchopt/diff/implicit/decorator.py +++ b/torchopt/diff/implicit/decorator.py @@ -279,7 +279,7 @@ def make_custom_vjp_solver_fn( # pylint: disable-next=missing-class-docstring,abstract-method class ImplicitMetaGradient(Function): @staticmethod - def forward( # type: ignore[override] # pylint: disable=arguments-differ + def forward( # pylint: disable=arguments-differ ctx: Any, *flat_args: Any, ) -> tuple[Any, ...]: diff --git a/torchopt/distributed/api.py b/torchopt/distributed/api.py index fb7461e4..86c2cfe8 100644 --- a/torchopt/distributed/api.py +++ b/torchopt/distributed/api.py @@ -324,7 +324,7 @@ def remote_async_call( if reducer is not None: return cast( Future[U], - future.then(lambda fut: cast(Callable[[Iterable[T]], U], reducer)(fut.wait())), + future.then(lambda fut: reducer(fut.wait())), ) return future From 2769849072685cceadaea6d9545bdd9d6d2eca5a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 11 Dec 2023 15:50:21 +0800 Subject: [PATCH 08/26] deps(workflows): bump actions/setup-python from 4 to 5 (#198) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Xuehai Pan --- .github/workflows/build.yml | 8 ++++---- .github/workflows/lint.yml | 2 +- .github/workflows/tests.yml | 4 ++-- .pre-commit-config.yaml | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index bed635e3..56230154 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -54,7 +54,7 @@ jobs: fetch-depth: 0 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.8 - 3.11" # sync with requires-python in pyproject.toml update-environment: true @@ -114,7 +114,7 @@ jobs: fetch-depth: 0 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} update-environment: true @@ -164,7 +164,7 @@ jobs: fetch-depth: 0 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} update-environment: true @@ -212,7 +212,7 @@ jobs: fetch-depth: 0 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 if: startsWith(github.ref, 'refs/tags/') with: python-version: "3.8 - 3.11" # sync with requires-python in pyproject.toml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 89c26c3c..62f96340 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -30,7 +30,7 @@ jobs: fetch-depth: 1 - name: Set up Python 3.9 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.9" update-environment: true diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6a8f7ec8..2f9f03d9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -42,7 +42,7 @@ jobs: fetch-depth: 1 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.8" # the lowest version we support (sync with requires-python in pyproject.toml) update-environment: true @@ -112,7 +112,7 @@ jobs: fetch-depth: 1 - 
name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.8" # the lowest version we support (sync with requires-python in pyproject.toml) update-environment: true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3ca22436..3a6e633e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,12 +30,12 @@ repos: hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.6 + rev: v0.1.7 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/PyCQA/isort - rev: 5.12.0 + rev: 5.13.0 hooks: - id: isort - repo: https://github.com/psf/black From 6d1191a3ba7c70f2df23241ad82a67ff322e04b3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 19 Dec 2023 15:00:31 +0800 Subject: [PATCH 09/26] deps(workflows): bump actions/{upload,download}-artifact from 3 to 4 (#200) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Xuehai Pan --- .github/workflows/build.yml | 86 ++++++++++++++++++++++--------------- .pre-commit-config.yaml | 6 +-- 2 files changed, 55 insertions(+), 37 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 56230154..921cf4af 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -78,7 +78,7 @@ jobs: TORCHOPT_NO_EXTENSIONS: "true" - name: Upload artifact - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: build path: dist/* @@ -97,12 +97,13 @@ jobs: make pytest build-wheels-py38: - name: Build wheels for Python ${{ matrix.python-version }} on ubuntu-latest - runs-on: ubuntu-latest + name: Build wheels for Python ${{ matrix.python-version }} on ${{ matrix.os }} + runs-on: ${{ matrix.os }} needs: [build] if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/')) strategy: matrix: + os: [ubuntu-latest] python-version: ["3.8"] # sync with requires-python in pyproject.toml fail-fast: false timeout-minutes: 60 @@ -140,19 +141,20 @@ jobs: output-dir: wheelhouse config-file: "{package}/pyproject.toml" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: - name: wheels-py38 + name: wheels-${{ matrix.python-version }}-${{ matrix.os }} path: wheelhouse/*.whl if-no-files-found: error build-wheels: - name: Build wheels for Python ${{ matrix.python-version }} on ubuntu-latest - runs-on: ubuntu-latest + name: Build wheels for Python ${{ matrix.python-version }} on ${{ matrix.os }} + runs-on: ${{ matrix.os }} needs: [build, build-wheels-py38] if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/')) strategy: matrix: + os: [ubuntu-latest] python-version: ["3.9", "3.10", "3.11"] # sync with requires-python in pyproject.toml fail-fast: false timeout-minutes: 60 @@ -190,15 +192,47 @@ jobs: output-dir: wheelhouse config-file: "{package}/pyproject.toml" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: - name: wheels + name: wheels-${{ matrix.python-version }}-${{ matrix.os }} path: wheelhouse/*.whl if-no-files-found: error - publish: + list-artifacts: + name: List artifacts runs-on: ubuntu-latest needs: [build, build-wheels-py38, build-wheels] + if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/')) + timeout-minutes: 15 + steps: + - 
name: Download built sdist + uses: actions/download-artifact@v4 + with: + # unpacks default artifact into dist/ + # if `name: artifact` is omitted, the action will create extra parent dir + name: build + path: dist + + - name: Download built wheels + uses: actions/download-artifact@v4 + with: + pattern: wheels-* + path: dist + merge-multiple: true + + - name: List distributions + run: ls -lh dist/* + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: artifacts + path: dist/* + if-no-files-found: error + + publish: + runs-on: ubuntu-latest + needs: [list-artifacts] if: | github.repository == 'metaopt/torchopt' && github.event_name != 'pull_request' && (github.event_name != 'workflow_dispatch' || github.event.inputs.task == 'build-and-publish') && @@ -236,28 +270,12 @@ jobs: exit 1 fi - - name: Download built sdist - uses: actions/download-artifact@v3 - with: - # unpacks default artifact into dist/ - # if `name: artifact` is omitted, the action will create extra parent dir - name: build - path: dist - - - name: Download built wheels - uses: actions/download-artifact@v3 - with: - # unpacks default artifact into dist/ - # if `name: artifact` is omitted, the action will create extra parent dir - name: wheels-py38 - path: dist - - - name: Download built wheels - uses: actions/download-artifact@v3 + - name: Download built artifacts + uses: actions/download-artifact@v4 with: # unpacks default artifact into dist/ # if `name: artifact` is omitted, the action will create extra parent dir - name: wheels + name: artifacts path: dist - name: List distributions @@ -269,10 +287,10 @@ jobs: with: user: __token__ password: ${{ secrets.TESTPYPI_UPLOAD_TOKEN }} - repository_url: https://test.pypi.org/legacy/ + repository-url: https://test.pypi.org/legacy/ verbose: true - print_hash: true - skip_existing: true + print-hash: true + skip-existing: true - name: Publish to PyPI if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' @@ -281,5 +299,5 @@ jobs: user: __token__ password: ${{ secrets.PYPI_UPLOAD_TOKEN }} verbose: true - print_hash: true - skip_existing: true + print-hash: true + skip-existing: true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3a6e633e..15af5635 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,16 +30,16 @@ repos: hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.7 + rev: v0.1.8 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/PyCQA/isort - rev: 5.13.0 + rev: 5.13.2 hooks: - id: isort - repo: https://github.com/psf/black - rev: 23.11.0 + rev: 23.12.0 hooks: - id: black-jupyter - repo: https://github.com/asottile/pyupgrade From 3b1efe938e740a3d91f1e7b1564ff2935c0cf024 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 20 Dec 2023 19:34:01 +0800 Subject: [PATCH 10/26] docs: update citation references (#201) --- README.md | 12 ++++++++---- docs/source/index.rst | 14 +++++++++----- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index ee1905ab..2300f619 100644 --- a/README.md +++ b/README.md @@ -469,11 +469,15 @@ See [CHANGELOG.md](CHANGELOG.md). If you find TorchOpt useful, please cite it in your publications. 
```bibtex -@article{torchopt, +@article{JMLR:TorchOpt, + author = {Jie Ren* and Xidong Feng* and Bo Liu* and Xuehai Pan* and Yao Fu and Luo Mai and Yaodong Yang}, title = {TorchOpt: An Efficient Library for Differentiable Optimization}, - author = {Ren, Jie and Feng, Xidong and Liu, Bo and Pan, Xuehai and Fu, Yao and Mai, Luo and Yang, Yaodong}, - journal = {arXiv preprint arXiv:2211.06934}, - year = {2022} + journal = {Journal of Machine Learning Research}, + year = {2023}, + volume = {24}, + number = {367}, + pages = {1--14}, + url = {http://jmlr.org/papers/v24/23-0191.html} } ``` diff --git a/docs/source/index.rst b/docs/source/index.rst index 02fab843..ec9a3749 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -118,11 +118,15 @@ If you find TorchOpt useful, please cite it in your publications. .. code-block:: bibtex - @article{torchopt, - title = {TorchOpt: An Efficient Library for Differentiable Optimization}, - author = {Ren, Jie and Feng, Xidong and Liu, Bo and Pan, Xuehai and Fu, Yao and Mai, Luo and Yang, Yaodong}, - journal = {arXiv preprint arXiv:2211.06934}, - year = {2022} + @article{JMLR:TorchOpt, + author = {Jie Ren* and Xidong Feng* and Bo Liu* and Xuehai Pan* and Yao Fu and Luo Mai and Yaodong Yang}, + title = {TorchOpt: An Efficient Library for Differentiable Optimization}, + journal = {Journal of Machine Learning Research}, + year = {2023}, + volume = {24}, + number = {367}, + pages = {1--14}, + url = {http://jmlr.org/papers/v24/23-0191.html} } From aef565c281c8516c27bbe96954cf3b714177afef Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Jan 2024 04:30:09 +0800 Subject: [PATCH 11/26] chore(pre-commit): [pre-commit.ci] autoupdate (#202) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.1.8 → v0.1.9](https://github.com/astral-sh/ruff-pre-commit/compare/v0.1.8...v0.1.9) - [github.com/psf/black: 23.12.0 → 23.12.1](https://github.com/psf/black/compare/23.12.0...23.12.1) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 15af5635..99e5782d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,7 +30,7 @@ repos: hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.8 + rev: v0.1.9 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] @@ -39,7 +39,7 @@ repos: hooks: - id: isort - repo: https://github.com/psf/black - rev: 23.12.0 + rev: 23.12.1 hooks: - id: black-jupyter - repo: https://github.com/asottile/pyupgrade From d439ea31eb4d14d9b3604a169d31dd60b26353ce Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Mar 2024 17:16:37 +0800 Subject: [PATCH 12/26] chore(pre-commit): [pre-commit.ci] autoupdate (#206) Co-authored-by: Xuehai Pan --- .pre-commit-config.yaml | 10 ++++----- .readthedocs.yaml | 4 ++-- CMakeLists.txt | 2 +- LICENSE | 2 +- Makefile | 19 ++++++++++++++--- conda-recipe-minimal-cpu.yaml | 2 +- conda-recipe-minimal.yaml | 2 +- conda-recipe.yaml | 2 +- docs/conda-recipe.yaml | 2 +- docs/source/conf.py | 4 ++-- examples/FuncTorch/maml_omniglot_vmap.py | 2 +- examples/L2R/helpers/utils.py | 2 +- examples/LOLA/helpers/agent.py | 2 +- examples/MAML-RL/func_maml.py | 2 +- 
.../distributed/few-shot/maml_omniglot.py | 2 +- .../few-shot/maml_omniglot_local_loader.py | 2 +- examples/few-shot/maml_omniglot.py | 2 +- examples/iMAML/imaml_omniglot.py | 2 +- examples/iMAML/imaml_omniglot_functional.py | 2 +- include/adam_op/adam_op.h | 2 +- include/adam_op/adam_op_impl_cpu.h | 2 +- include/adam_op/adam_op_impl_cuda.cuh | 2 +- include/common.h | 2 +- include/utils.h | 2 +- pyproject.toml | 21 ++++++++++--------- src/adam_op/adam_op.cpp | 2 +- src/adam_op/adam_op_impl_cpu.cpp | 2 +- src/adam_op/adam_op_impl_cuda.cu | 2 +- tests/conftest.py | 2 +- tests/helpers.py | 2 +- tests/requirements.txt | 3 ++- tests/test_accelerated_op.py | 2 +- tests/test_alias.py | 2 +- tests/test_clip.py | 2 +- tests/test_combine.py | 2 +- tests/test_hook.py | 2 +- tests/test_implicit.py | 2 +- tests/test_import.py | 2 +- tests/test_linalg.py | 2 +- tests/test_meta_optim.py | 2 +- tests/test_nn.py | 2 +- tests/test_optim.py | 2 +- tests/test_pytree.py | 2 +- tests/test_schedule.py | 2 +- tests/test_transform.py | 2 +- tests/test_utils.py | 2 +- tests/test_zero_order.py | 2 +- torchopt/_C/adam_op.pyi | 2 +- torchopt/__init__.py | 2 +- torchopt/accelerated_op/__init__.py | 2 +- torchopt/accelerated_op/_src/adam_op.py | 2 +- torchopt/accelerated_op/adam_op.py | 2 +- torchopt/alias/__init__.py | 2 +- torchopt/alias/adadelta.py | 2 +- torchopt/alias/adagrad.py | 2 +- torchopt/alias/adam.py | 2 +- torchopt/alias/adamax.py | 2 +- torchopt/alias/adamw.py | 2 +- torchopt/alias/radam.py | 2 +- torchopt/alias/rmsprop.py | 2 +- torchopt/alias/sgd.py | 2 +- torchopt/alias/utils.py | 2 +- torchopt/base.py | 10 +++++---- torchopt/clip.py | 2 +- torchopt/combine.py | 2 +- torchopt/diff/__init__.py | 2 +- torchopt/diff/implicit/decorator.py | 2 +- torchopt/diff/implicit/nn/__init__.py | 2 +- torchopt/diff/implicit/nn/module.py | 2 +- torchopt/diff/zero_order/__init__.py | 2 +- torchopt/diff/zero_order/decorator.py | 2 +- torchopt/diff/zero_order/nn/__init__.py | 2 +- torchopt/diff/zero_order/nn/module.py | 2 +- torchopt/distributed/__init__.py | 2 +- torchopt/distributed/api.py | 2 +- torchopt/hook.py | 2 +- torchopt/nn/__init__.py | 2 +- torchopt/nn/module.py | 2 +- torchopt/nn/stateless.py | 2 +- torchopt/optim/__init__.py | 2 +- torchopt/optim/adadelta.py | 2 +- torchopt/optim/adagrad.py | 2 +- torchopt/optim/adam.py | 2 +- torchopt/optim/adamax.py | 2 +- torchopt/optim/adamw.py | 2 +- torchopt/optim/func/base.py | 2 +- torchopt/optim/meta/__init__.py | 2 +- torchopt/optim/meta/adagrad.py | 2 +- torchopt/optim/meta/adamw.py | 2 +- torchopt/optim/meta/base.py | 2 +- torchopt/optim/radam.py | 2 +- torchopt/pytree.py | 2 +- torchopt/schedule/__init__.py | 2 +- torchopt/schedule/exponential_decay.py | 2 +- torchopt/schedule/polynomial.py | 2 +- torchopt/transform/__init__.py | 2 +- torchopt/transform/add_decayed_weights.py | 2 +- torchopt/transform/nan_to_num.py | 2 +- torchopt/transform/scale.py | 2 +- torchopt/transform/scale_by_adadelta.py | 2 +- torchopt/transform/scale_by_adam.py | 2 +- torchopt/transform/scale_by_adamax.py | 2 +- torchopt/transform/scale_by_radam.py | 2 +- torchopt/transform/scale_by_rms.py | 6 +++--- torchopt/transform/scale_by_rss.py | 2 +- torchopt/transform/scale_by_schedule.py | 2 +- torchopt/transform/scale_by_stddev.py | 2 +- torchopt/transform/trace.py | 2 +- torchopt/transform/utils.py | 2 +- torchopt/typing.py | 2 +- torchopt/utils.py | 2 +- torchopt/visual.py | 4 ++-- 112 files changed, 152 insertions(+), 135 deletions(-) diff --git a/.pre-commit-config.yaml 
b/.pre-commit-config.yaml index 99e5782d..d36f0472 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,11 +26,11 @@ repos: - id: debug-statements - id: double-quote-string-fixer - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v17.0.6 + rev: v18.1.1 hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.9 + rev: v0.3.3 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] @@ -39,11 +39,11 @@ repos: hooks: - id: isort - repo: https://github.com/psf/black - rev: 23.12.1 + rev: 24.3.0 hooks: - id: black-jupyter - repo: https://github.com/asottile/pyupgrade - rev: v3.15.0 + rev: v3.15.1 hooks: - id: pyupgrade args: [--py38-plus] # sync with requires-python @@ -52,7 +52,7 @@ repos: ^examples/ ) - repo: https://github.com/pycqa/flake8 - rev: 6.1.0 + rev: 7.0.0 hooks: - id: flake8 additional_dependencies: diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 6a9c387e..c014fa97 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -7,9 +7,9 @@ version: 2 # Set the version of Python and other tools you might need build: - os: ubuntu-20.04 + os: ubuntu-lts-latest tools: - python: mambaforge-4.10 + python: mambaforge-latest jobs: post_install: - python -m pip install --upgrade pip setuptools diff --git a/CMakeLists.txt b/CMakeLists.txt index eca14815..bf17f38c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/LICENSE b/LICENSE index 8d26c203..185e0144 100644 --- a/LICENSE +++ b/LICENSE @@ -187,7 +187,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright [2022-2023] [MetaOPT Team. All Rights Reserved.] + Copyright [2022-2024] [MetaOPT Team. All Rights Reserved.] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/Makefile b/Makefile index 0f7dd74e..1c42fcd7 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -print-% : ; @echo $* = $($*) +print-%: ; @echo $* = $($*) PROJECT_NAME = torchopt COPYRIGHT = "MetaOPT Team. All Rights Reserved." PROJECT_PATH = $(PROJECT_NAME) @@ -22,7 +22,7 @@ install: install-editable: $(PYTHON) -m pip install --upgrade pip $(PYTHON) -m pip install --upgrade setuptools wheel - $(PYTHON) -m pip install torch numpy pybind11 + $(PYTHON) -m pip install torch numpy pybind11 cmake USE_FP16=ON TORCH_CUDA_ARCH_LIST=Auto $(PYTHON) -m pip install -vvv --no-build-isolation --editable . 
install-e: install-editable # alias @@ -112,6 +112,7 @@ addlicense-install: go-install # Tests pytest: test-install + $(PYTHON) -m pytest --version cd tests && $(PYTHON) -c 'import $(PROJECT_PATH)' && \ $(PYTHON) -m pytest --verbose --color=yes --durations=0 \ --cov="$(PROJECT_PATH)" --cov-config=.coveragerc --cov-report=xml --cov-report=term-missing \ @@ -122,30 +123,39 @@ test: pytest # Python linters pylint: pylint-install + $(PYTHON) -m pylint --version $(PYTHON) -m pylint $(PROJECT_PATH) flake8: flake8-install + $(PYTHON) -m flake8 --version $(PYTHON) -m flake8 --count --show-source --statistics py-format: py-format-install + $(PYTHON) -m isort --version + $(PYTHON) -m black --version $(PYTHON) -m isort --project $(PROJECT_PATH) --check $(PYTHON_FILES) && \ $(PYTHON) -m black --check $(PYTHON_FILES) tutorials ruff: ruff-install + $(PYTHON) -m ruff --version $(PYTHON) -m ruff check . ruff-fix: ruff-install + $(PYTHON) -m ruff --version $(PYTHON) -m ruff check . --fix --exit-non-zero-on-fix mypy: mypy-install + $(PYTHON) -m mypy --version $(PYTHON) -m mypy $(PROJECT_PATH) --install-types --non-interactive pre-commit: pre-commit-install + $(PYTHON) -m pre_commit --version $(PYTHON) -m pre_commit run --all-files # C++ linters cmake-configure: cmake-install + cmake --version cmake -S . -B cmake-build-debug \ -DCMAKE_BUILD_TYPE=Debug \ -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ @@ -157,13 +167,16 @@ cmake-build: cmake-configure cmake: cmake-build cpplint: cpplint-install + $(PYTHON) -m cpplint --version $(PYTHON) -m cpplint $(CXX_FILES) $(CUDA_FILES) clang-format: clang-format-install + $(CLANG_FORMAT) --version $(CLANG_FORMAT) --style=file -i $(CXX_FILES) $(CUDA_FILES) -n --Werror clang-tidy: clang-tidy-install cmake-configure - clang-tidy -p=cmake-build-debug $(CXX_FILES) + clang-tidy --version + clang-tidy --extra-arg="-v" -p=cmake-build-debug $(CXX_FILES) # Documentation diff --git a/conda-recipe-minimal-cpu.yaml b/conda-recipe-minimal-cpu.yaml index 0404f10c..c872924d 100644 --- a/conda-recipe-minimal-cpu.yaml +++ b/conda-recipe-minimal-cpu.yaml @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/conda-recipe-minimal.yaml b/conda-recipe-minimal.yaml index c3d155b8..39ff72cc 100644 --- a/conda-recipe-minimal.yaml +++ b/conda-recipe-minimal.yaml @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/conda-recipe.yaml b/conda-recipe.yaml index 12b3a6d0..c82eb3c6 100644 --- a/conda-recipe.yaml +++ b/conda-recipe.yaml @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/docs/conda-recipe.yaml b/docs/conda-recipe.yaml index 30ec372e..b6f1b580 100644 --- a/docs/conda-recipe.yaml +++ b/docs/conda-recipe.yaml @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/docs/source/conf.py b/docs/source/conf.py index f5d206c7..a4f23533 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -66,7 +66,7 @@ def filter(self, record: logging.LogRecord) -> bool: # -- Project information ------------------------------------------------------- project = 'TorchOpt' -copyright = '2022-2023 MetaOPT Team' +copyright = '2022-2024 MetaOPT Team' author = 'TorchOpt Contributors' # The full version, including alpha/beta/rc tags diff --git a/examples/FuncTorch/maml_omniglot_vmap.py b/examples/FuncTorch/maml_omniglot_vmap.py index e1cfe95e..2f42e050 100644 --- a/examples/FuncTorch/maml_omniglot_vmap.py +++ b/examples/FuncTorch/maml_omniglot_vmap.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/L2R/helpers/utils.py b/examples/L2R/helpers/utils.py index 7e95ca6f..ade64236 100644 --- a/examples/L2R/helpers/utils.py +++ b/examples/L2R/helpers/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/LOLA/helpers/agent.py b/examples/LOLA/helpers/agent.py index a8f8ee31..78946ee7 100644 --- a/examples/LOLA/helpers/agent.py +++ b/examples/LOLA/helpers/agent.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/MAML-RL/func_maml.py b/examples/MAML-RL/func_maml.py index f3a00642..475c1b12 100644 --- a/examples/MAML-RL/func_maml.py +++ b/examples/MAML-RL/func_maml.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/distributed/few-shot/maml_omniglot.py b/examples/distributed/few-shot/maml_omniglot.py index 24601dfa..f840e65e 100644 --- a/examples/distributed/few-shot/maml_omniglot.py +++ b/examples/distributed/few-shot/maml_omniglot.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/examples/distributed/few-shot/maml_omniglot_local_loader.py b/examples/distributed/few-shot/maml_omniglot_local_loader.py index d7413770..fb737d4f 100644 --- a/examples/distributed/few-shot/maml_omniglot_local_loader.py +++ b/examples/distributed/few-shot/maml_omniglot_local_loader.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/few-shot/maml_omniglot.py b/examples/few-shot/maml_omniglot.py index d798aa1d..7f7f67fe 100644 --- a/examples/few-shot/maml_omniglot.py +++ b/examples/few-shot/maml_omniglot.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/iMAML/imaml_omniglot.py b/examples/iMAML/imaml_omniglot.py index 8a6960ba..1db08427 100644 --- a/examples/iMAML/imaml_omniglot.py +++ b/examples/iMAML/imaml_omniglot.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/iMAML/imaml_omniglot_functional.py b/examples/iMAML/imaml_omniglot_functional.py index 60fd4108..7bc1e9da 100644 --- a/examples/iMAML/imaml_omniglot_functional.py +++ b/examples/iMAML/imaml_omniglot_functional.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/include/adam_op/adam_op.h b/include/adam_op/adam_op.h index a49b0a06..2d0abcd3 100644 --- a/include/adam_op/adam_op.h +++ b/include/adam_op/adam_op.h @@ -1,4 +1,4 @@ -// Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2024 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/include/adam_op/adam_op_impl_cpu.h b/include/adam_op/adam_op_impl_cpu.h index 37aba528..4d54377e 100644 --- a/include/adam_op/adam_op_impl_cpu.h +++ b/include/adam_op/adam_op_impl_cpu.h @@ -1,4 +1,4 @@ -// Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2024 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/include/adam_op/adam_op_impl_cuda.cuh b/include/adam_op/adam_op_impl_cuda.cuh index 6e661564..17002b36 100644 --- a/include/adam_op/adam_op_impl_cuda.cuh +++ b/include/adam_op/adam_op_impl_cuda.cuh @@ -1,4 +1,4 @@ -// Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2024 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/include/common.h b/include/common.h index 65f9ef33..256b0ca1 100644 --- a/include/common.h +++ b/include/common.h @@ -1,4 +1,4 @@ -// Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2024 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/include/utils.h b/include/utils.h index 0ef98539..cefabfac 100644 --- a/include/utils.h +++ b/include/utils.h @@ -1,4 +1,4 @@ -// Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2024 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pyproject.toml b/pyproject.toml index 31d20f88..f3a5645b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -176,7 +176,6 @@ test-command = """ # Linter tools ################################################################# [tool.black] -safe = true line-length = 100 skip-string-normalization = true # Sync with requires-python @@ -194,15 +193,15 @@ multi_line_output = 3 [tool.mypy] # Sync with requires-python -python_version = 3.8 +python_version = "3.8" pretty = true show_error_codes = true show_error_context = true show_traceback = true allow_redefinition = true check_untyped_defs = true -disallow_incomplete_defs = false -disallow_untyped_defs = false +disallow_incomplete_defs = true +disallow_untyped_defs = true ignore_missing_imports = true no_implicit_optional = true strict_equality = true @@ -226,9 +225,11 @@ ignore-words = "docs/source/spelling_wordlist.txt" # Sync with requires-python target-version = "py38" line-length = 100 -show-source = true +output-format = "full" src = ["torchopt", "tests"] extend-exclude = ["examples"] + +[tool.ruff.lint] select = [ "E", "W", # pycodestyle "F", # pyflakes @@ -271,7 +272,7 @@ ignore = [ ] typing-modules = ["torchopt.typing"] -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "__init__.py" = [ "F401", # unused-import ] @@ -294,18 +295,18 @@ typing-modules = ["torchopt.typing"] "F811", # redefined-while-unused ] -[tool.ruff.flake8-annotations] +[tool.ruff.lint.flake8-annotations] allow-star-arg-any = true -[tool.ruff.flake8-quotes] +[tool.ruff.lint.flake8-quotes] docstring-quotes = "double" multiline-quotes = "double" inline-quotes = "single" -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" -[tool.ruff.pylint] +[tool.ruff.lint.pylint] allow-magic-value-types = ["int", "str", "float"] [tool.pytest.ini_options] diff --git a/src/adam_op/adam_op.cpp b/src/adam_op/adam_op.cpp index 08c9fb74..47f5d7f1 100644 --- a/src/adam_op/adam_op.cpp +++ b/src/adam_op/adam_op.cpp @@ -1,4 +1,4 @@ -// Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2024 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/src/adam_op/adam_op_impl_cpu.cpp b/src/adam_op/adam_op_impl_cpu.cpp index 1135206d..9c460685 100644 --- a/src/adam_op/adam_op_impl_cpu.cpp +++ b/src/adam_op/adam_op_impl_cpu.cpp @@ -1,4 +1,4 @@ -// Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2024 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/src/adam_op/adam_op_impl_cuda.cu b/src/adam_op/adam_op_impl_cuda.cu index ea1526a6..a12eca4f 100644 --- a/src/adam_op/adam_op_impl_cuda.cu +++ b/src/adam_op/adam_op_impl_cuda.cu @@ -1,4 +1,4 @@ -// Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2024 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/tests/conftest.py b/tests/conftest.py index eaa734b2..bb2b1cf2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/helpers.py b/tests/helpers.py index 50451496..1b624ce9 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/requirements.txt b/tests/requirements.txt index 87c994e1..397ac350 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -5,7 +5,8 @@ torch >= 1.13 jax[cpu] >= 0.3; platform_system != 'Windows' jaxopt; platform_system != 'Windows' -optax; platform_system != 'Windows' +optax < 0.1.8a0; platform_system != 'Windows' and python_version < '3.9' +optax >= 0.1.8; platform_system != 'Windows' and python_version >= '3.9' pytest pytest-cov diff --git a/tests/test_accelerated_op.py b/tests/test_accelerated_op.py index 6cb45ca0..668c9b9a 100644 --- a/tests/test_accelerated_op.py +++ b/tests/test_accelerated_op.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_alias.py b/tests/test_alias.py index aef35b96..58b5a328 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_clip.py b/tests/test_clip.py index 0b191cfe..2614781e 100644 --- a/tests/test_clip.py +++ b/tests/test_clip.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_combine.py b/tests/test_combine.py index 39b3e37f..1a026b9e 100644 --- a/tests/test_combine.py +++ b/tests/test_combine.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/tests/test_hook.py b/tests/test_hook.py index 1f3024c7..e89bb178 100644 --- a/tests/test_hook.py +++ b/tests/test_hook.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_implicit.py b/tests/test_implicit.py index 61623a17..ff0ba15c 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_import.py b/tests/test_import.py index f7523756..04d0ebbb 100644 --- a/tests/test_import.py +++ b/tests/test_import.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 7758b7db..c5b07618 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_meta_optim.py b/tests/test_meta_optim.py index 61f8a7ad..55712bdf 100644 --- a/tests/test_meta_optim.py +++ b/tests/test_meta_optim.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_nn.py b/tests/test_nn.py index 8e89bdb5..f77c20ec 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_optim.py b/tests/test_optim.py index dc3941d9..1257054f 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_pytree.py b/tests/test_pytree.py index d82d81f2..6ee2939b 100644 --- a/tests/test_pytree.py +++ b/tests/test_pytree.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_schedule.py b/tests/test_schedule.py index 1fdc4669..e4c0ac0a 100644 --- a/tests/test_schedule.py +++ b/tests/test_schedule.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. 
All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_transform.py b/tests/test_transform.py index 9598386d..0a7bd498 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_utils.py b/tests/test_utils.py index d1be7c6f..5215e7b3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_zero_order.py b/tests/test_zero_order.py index 61f75f9a..65642559 100644 --- a/tests/test_zero_order.py +++ b/tests/test_zero_order.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/_C/adam_op.pyi b/torchopt/_C/adam_op.pyi index 04f141fd..5ef572aa 100644 --- a/torchopt/_C/adam_op.pyi +++ b/torchopt/_C/adam_op.pyi @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/__init__.py b/torchopt/__init__.py index a089f3dc..5e568526 100644 --- a/torchopt/__init__.py +++ b/torchopt/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/accelerated_op/__init__.py b/torchopt/accelerated_op/__init__.py index 3ac943e3..103b6fc0 100644 --- a/torchopt/accelerated_op/__init__.py +++ b/torchopt/accelerated_op/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/accelerated_op/_src/adam_op.py b/torchopt/accelerated_op/_src/adam_op.py index c8fc8898..bc999766 100644 --- a/torchopt/accelerated_op/_src/adam_op.py +++ b/torchopt/accelerated_op/_src/adam_op.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/accelerated_op/adam_op.py b/torchopt/accelerated_op/adam_op.py index d6f9e9f9..43ac26cd 100644 --- a/torchopt/accelerated_op/adam_op.py +++ b/torchopt/accelerated_op/adam_op.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/alias/__init__.py b/torchopt/alias/__init__.py index 3ea721c4..3cfb5b8b 100644 --- a/torchopt/alias/__init__.py +++ b/torchopt/alias/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/alias/adadelta.py b/torchopt/alias/adadelta.py index 2e3640f2..fb0b551a 100644 --- a/torchopt/alias/adadelta.py +++ b/torchopt/alias/adadelta.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py index 3f983c38..6fdb4aa3 100644 --- a/torchopt/alias/adagrad.py +++ b/torchopt/alias/adagrad.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/alias/adam.py b/torchopt/alias/adam.py index dc889285..9419e908 100644 --- a/torchopt/alias/adam.py +++ b/torchopt/alias/adam.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/alias/adamax.py b/torchopt/alias/adamax.py index ffa19e37..f80c0c2f 100644 --- a/torchopt/alias/adamax.py +++ b/torchopt/alias/adamax.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/alias/adamw.py b/torchopt/alias/adamw.py index e8bed2ab..38d4d5ac 100644 --- a/torchopt/alias/adamw.py +++ b/torchopt/alias/adamw.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/alias/radam.py b/torchopt/alias/radam.py index 230c1151..56d3d3d5 100644 --- a/torchopt/alias/radam.py +++ b/torchopt/alias/radam.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/alias/rmsprop.py b/torchopt/alias/rmsprop.py index 96092548..612e4f45 100644 --- a/torchopt/alias/rmsprop.py +++ b/torchopt/alias/rmsprop.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/torchopt/alias/sgd.py b/torchopt/alias/sgd.py index 4c5b8317..6d5935bc 100644 --- a/torchopt/alias/sgd.py +++ b/torchopt/alias/sgd.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/alias/utils.py b/torchopt/alias/utils.py index 2c2ec0e4..49f8784d 100644 --- a/torchopt/alias/utils.py +++ b/torchopt/alias/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/base.py b/torchopt/base.py index cab2b49f..572708e2 100644 --- a/torchopt/base.py +++ b/torchopt/base.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -164,9 +164,11 @@ def __new__(cls, *transformations: GradientTransformation) -> Self: """Create a new chained gradient transformation.""" transformations = tuple( itertools.chain.from_iterable( - t.transformations - if isinstance(t, ChainedGradientTransformation) - else ((t,) if not isinstance(t, IdentityGradientTransformation) else ()) + ( + t.transformations + if isinstance(t, ChainedGradientTransformation) + else ((t,) if not isinstance(t, IdentityGradientTransformation) else ()) + ) for t in transformations ), ) diff --git a/torchopt/clip.py b/torchopt/clip.py index eda4bef3..55ae83fc 100644 --- a/torchopt/clip.py +++ b/torchopt/clip.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/combine.py b/torchopt/combine.py index fc1a7152..158ec982 100644 --- a/torchopt/combine.py +++ b/torchopt/combine.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/diff/__init__.py b/torchopt/diff/__init__.py index 984841ed..194512f5 100644 --- a/torchopt/diff/__init__.py +++ b/torchopt/diff/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/diff/implicit/decorator.py b/torchopt/diff/implicit/decorator.py index 1fc9bb4a..d3efda2c 100644 --- a/torchopt/diff/implicit/decorator.py +++ b/torchopt/diff/implicit/decorator.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/torchopt/diff/implicit/nn/__init__.py b/torchopt/diff/implicit/nn/__init__.py index 5bc7aa8d..e91ef8ed 100644 --- a/torchopt/diff/implicit/nn/__init__.py +++ b/torchopt/diff/implicit/nn/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py index a72e5304..8719f675 100644 --- a/torchopt/diff/implicit/nn/module.py +++ b/torchopt/diff/implicit/nn/module.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/diff/zero_order/__init__.py b/torchopt/diff/zero_order/__init__.py index b621ffdc..f00e097a 100644 --- a/torchopt/diff/zero_order/__init__.py +++ b/torchopt/diff/zero_order/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/diff/zero_order/decorator.py b/torchopt/diff/zero_order/decorator.py index f63f0574..b1126636 100644 --- a/torchopt/diff/zero_order/decorator.py +++ b/torchopt/diff/zero_order/decorator.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/diff/zero_order/nn/__init__.py b/torchopt/diff/zero_order/nn/__init__.py index 1bf64efe..f2753b27 100644 --- a/torchopt/diff/zero_order/nn/__init__.py +++ b/torchopt/diff/zero_order/nn/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/diff/zero_order/nn/module.py b/torchopt/diff/zero_order/nn/module.py index 75da28f9..7ac12bb4 100644 --- a/torchopt/diff/zero_order/nn/module.py +++ b/torchopt/diff/zero_order/nn/module.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/distributed/__init__.py b/torchopt/distributed/__init__.py index 534b2dea..31f1283b 100644 --- a/torchopt/distributed/__init__.py +++ b/torchopt/distributed/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/distributed/api.py b/torchopt/distributed/api.py index 86c2cfe8..117af9ab 100644 --- a/torchopt/distributed/api.py +++ b/torchopt/distributed/api.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. 
All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/hook.py b/torchopt/hook.py index 13ed6abf..b51e29eb 100644 --- a/torchopt/hook.py +++ b/torchopt/hook.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/nn/__init__.py b/torchopt/nn/__init__.py index 8271ad7d..7665f201 100644 --- a/torchopt/nn/__init__.py +++ b/torchopt/nn/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/nn/module.py b/torchopt/nn/module.py index 64623146..419afb6a 100644 --- a/torchopt/nn/module.py +++ b/torchopt/nn/module.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/nn/stateless.py b/torchopt/nn/stateless.py index 8268ca3f..d3437d0d 100644 --- a/torchopt/nn/stateless.py +++ b/torchopt/nn/stateless.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/__init__.py b/torchopt/optim/__init__.py index 20da5fca..f620608c 100644 --- a/torchopt/optim/__init__.py +++ b/torchopt/optim/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/adadelta.py b/torchopt/optim/adadelta.py index 7c73cb58..a64e00e4 100644 --- a/torchopt/optim/adadelta.py +++ b/torchopt/optim/adadelta.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/adagrad.py b/torchopt/optim/adagrad.py index a7e8c72b..277b7105 100644 --- a/torchopt/optim/adagrad.py +++ b/torchopt/optim/adagrad.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/adam.py b/torchopt/optim/adam.py index 5d85cbdc..6ff68a69 100644 --- a/torchopt/optim/adam.py +++ b/torchopt/optim/adam.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/adamax.py b/torchopt/optim/adamax.py index 904c05a0..f693723c 100644 --- a/torchopt/optim/adamax.py +++ b/torchopt/optim/adamax.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/adamw.py b/torchopt/optim/adamw.py index be8c6727..463f245f 100644 --- a/torchopt/optim/adamw.py +++ b/torchopt/optim/adamw.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/func/base.py b/torchopt/optim/func/base.py index 7a7839a3..7bb27877 100644 --- a/torchopt/optim/func/base.py +++ b/torchopt/optim/func/base.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/meta/__init__.py b/torchopt/optim/meta/__init__.py index 516f2b5f..9e30dfef 100644 --- a/torchopt/optim/meta/__init__.py +++ b/torchopt/optim/meta/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/meta/adagrad.py b/torchopt/optim/meta/adagrad.py index 4e8ef0eb..58d913aa 100644 --- a/torchopt/optim/meta/adagrad.py +++ b/torchopt/optim/meta/adagrad.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/meta/adamw.py b/torchopt/optim/meta/adamw.py index 204a5428..05387b77 100644 --- a/torchopt/optim/meta/adamw.py +++ b/torchopt/optim/meta/adamw.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/meta/base.py b/torchopt/optim/meta/base.py index 54327f3b..73ecdde7 100644 --- a/torchopt/optim/meta/base.py +++ b/torchopt/optim/meta/base.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/radam.py b/torchopt/optim/radam.py index c2f6a211..bba8c0d4 100644 --- a/torchopt/optim/radam.py +++ b/torchopt/optim/radam.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/pytree.py b/torchopt/pytree.py index 6d41d0fa..6adea0e8 100644 --- a/torchopt/pytree.py +++ b/torchopt/pytree.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/schedule/__init__.py b/torchopt/schedule/__init__.py index b9916783..8e5545a4 100644 --- a/torchopt/schedule/__init__.py +++ b/torchopt/schedule/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/schedule/exponential_decay.py b/torchopt/schedule/exponential_decay.py index 8811b353..0925e164 100644 --- a/torchopt/schedule/exponential_decay.py +++ b/torchopt/schedule/exponential_decay.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/schedule/polynomial.py b/torchopt/schedule/polynomial.py index 39629c38..2482f769 100644 --- a/torchopt/schedule/polynomial.py +++ b/torchopt/schedule/polynomial.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/transform/__init__.py b/torchopt/transform/__init__.py index c75fcb5d..adef5596 100644 --- a/torchopt/transform/__init__.py +++ b/torchopt/transform/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/transform/add_decayed_weights.py b/torchopt/transform/add_decayed_weights.py index 39948694..950682cf 100644 --- a/torchopt/transform/add_decayed_weights.py +++ b/torchopt/transform/add_decayed_weights.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/transform/nan_to_num.py b/torchopt/transform/nan_to_num.py index 27d87499..d3530853 100644 --- a/torchopt/transform/nan_to_num.py +++ b/torchopt/transform/nan_to_num.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/transform/scale.py b/torchopt/transform/scale.py index c731003c..493b7196 100644 --- a/torchopt/transform/scale.py +++ b/torchopt/transform/scale.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. 
All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/transform/scale_by_adadelta.py b/torchopt/transform/scale_by_adadelta.py index bbe40080..f389d293 100644 --- a/torchopt/transform/scale_by_adadelta.py +++ b/torchopt/transform/scale_by_adadelta.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py index cc0ea3b6..b08c6a14 100644 --- a/torchopt/transform/scale_by_adam.py +++ b/torchopt/transform/scale_by_adam.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/transform/scale_by_adamax.py b/torchopt/transform/scale_by_adamax.py index 0a1c3ec9..f11ed311 100644 --- a/torchopt/transform/scale_by_adamax.py +++ b/torchopt/transform/scale_by_adamax.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/transform/scale_by_radam.py b/torchopt/transform/scale_by_radam.py index acb85a82..fad32b13 100644 --- a/torchopt/transform/scale_by_radam.py +++ b/torchopt/transform/scale_by_radam.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/transform/scale_by_rms.py b/torchopt/transform/scale_by_rms.py index 084be839..4ee67ed0 100644 --- a/torchopt/transform/scale_by_rms.py +++ b/torchopt/transform/scale_by_rms.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -135,14 +135,14 @@ def update_fn( ) if inplace: - + # pylint: disable-next=invalid-name def f(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: return g.div_(n.sqrt().add_(eps)) if g is not None else g tree_map_(f, nu, updates) else: - + # pylint: disable-next=invalid-name def f(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: return g.div(n.sqrt().add(eps)) if g is not None else g diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index b1f3d2a8..9bc97206 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/torchopt/transform/scale_by_schedule.py b/torchopt/transform/scale_by_schedule.py index 749b1853..48f3f271 100644 --- a/torchopt/transform/scale_by_schedule.py +++ b/torchopt/transform/scale_by_schedule.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/transform/scale_by_stddev.py b/torchopt/transform/scale_by_stddev.py index d9589c45..6b99f31a 100644 --- a/torchopt/transform/scale_by_stddev.py +++ b/torchopt/transform/scale_by_stddev.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/transform/trace.py b/torchopt/transform/trace.py index d530a676..9bf37e2f 100644 --- a/torchopt/transform/trace.py +++ b/torchopt/transform/trace.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/transform/utils.py b/torchopt/transform/utils.py index f1ed39da..ec4e51c1 100644 --- a/torchopt/transform/utils.py +++ b/torchopt/transform/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/typing.py b/torchopt/typing.py index c5c76984..60d11e0e 100644 --- a/torchopt/typing.py +++ b/torchopt/typing.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/utils.py b/torchopt/utils.py index ef771966..c067d570 100644 --- a/torchopt/utils.py +++ b/torchopt/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/visual.py b/torchopt/visual.py index 47a7f5d5..d7885889 100644 --- a/torchopt/visual.py +++ b/torchopt/visual.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -145,7 +145,7 @@ def size_to_str(size: tuple[int, ...]) -> str: def get_var_name(var: torch.Tensor, name: str | None = None) -> str: if not name: - name = param_map[var] if var in param_map else '' + name = param_map.get(var, '') return f'{name}\n{size_to_str(var.size())}' def get_var_name_with_flag(var: torch.Tensor) -> str | None: From 90f805ef7523c8e231c9585679c34f934c898b8a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 19 Mar 2024 23:57:06 +0800 Subject: [PATCH 13/26] deps(workflows): bump pypa/cibuildwheel from 2.16 to 2.17 (#208) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Xuehai Pan --- .github/workflows/build.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 921cf4af..894442dd 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -133,7 +133,7 @@ jobs: run: python .github/workflows/set_cibw_build.py - name: Build wheels - uses: pypa/cibuildwheel@v2.16 + uses: pypa/cibuildwheel@v2.17 env: CIBW_BUILD: ${{ env.CIBW_BUILD }} with: @@ -184,7 +184,7 @@ jobs: run: python .github/workflows/set_cibw_build.py - name: Build wheels - uses: pypa/cibuildwheel@v2.16 + uses: pypa/cibuildwheel@v2.17 env: CIBW_BUILD: ${{ env.CIBW_BUILD }} with: From 62c53478aee325f37c3ed4f30d1091777dcf0260 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 20 Mar 2024 01:58:40 +0800 Subject: [PATCH 14/26] deps(workflows): bump codecov/codecov-action from 3 to 4 (#205) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Xuehai Pan --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2f9f03d9..6e51ecad 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -88,7 +88,7 @@ jobs: - name: Upload coverage to Codecov if: runner.os == 'Linux' - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} file: ./tests/coverage.xml @@ -137,7 +137,7 @@ jobs: - name: Upload coverage to Codecov if: runner.os == 'Linux' - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} file: ./tests/coverage.xml From ae04ce0464d3c1ff353754cd59029aae9d6e638d Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 21 Mar 2024 13:48:41 +0000 Subject: [PATCH 15/26] chore: update license header --- docs/source/_static/css/style.css | 2 +- examples/FuncTorch/parallel_train_torchopt.py | 2 +- examples/L2R/helpers/argument.py | 2 +- examples/L2R/helpers/model.py | 2 +- examples/L2R/l2r.py | 2 +- examples/LOLA/helpers/argument.py | 2 +- examples/LOLA/helpers/env.py | 2 +- examples/LOLA/helpers/utils.py | 2 +- examples/LOLA/lola_dice.py | 2 +- examples/LOLA/visualize.py | 2 +- examples/MAML-RL/helpers/__init__.py | 2 +- examples/MAML-RL/helpers/policy.py | 2 +- examples/MAML-RL/helpers/policy_torchrl.py | 2 +- examples/MAML-RL/helpers/tabular_mdp.py | 2 +- examples/MAML-RL/maml.py | 2 +- examples/MAML-RL/maml_torchrl.py | 2 +- examples/MGRL/mgrl.py | 2 +- examples/visualize.py | 2 +- src/CMakeLists.txt | 2 +- src/extension.cpp | 2 +- torchopt/accelerated_op/_src/__init__.py | 2 +- torchopt/diff/implicit/__init__.py | 2 
+- torchopt/distributed/autograd.py | 2 +- torchopt/distributed/world.py | 2 +- torchopt/linalg/__init__.py | 2 +- torchopt/linalg/cg.py | 2 +- torchopt/linalg/ns.py | 2 +- torchopt/linalg/utils.py | 2 +- torchopt/linear_solve/__init__.py | 2 +- torchopt/linear_solve/cg.py | 2 +- torchopt/linear_solve/inv.py | 2 +- torchopt/linear_solve/normal_cg.py | 2 +- torchopt/linear_solve/utils.py | 2 +- torchopt/optim/base.py | 2 +- torchopt/optim/func/__init__.py | 2 +- torchopt/optim/meta/adadelta.py | 2 +- torchopt/optim/meta/adam.py | 2 +- torchopt/optim/meta/adamax.py | 2 +- torchopt/optim/meta/radam.py | 2 +- torchopt/optim/meta/rmsprop.py | 2 +- torchopt/optim/meta/sgd.py | 2 +- torchopt/optim/rmsprop.py | 2 +- torchopt/optim/sgd.py | 2 +- torchopt/update.py | 2 +- torchopt/version.py | 2 +- 45 files changed, 45 insertions(+), 45 deletions(-) diff --git a/docs/source/_static/css/style.css b/docs/source/_static/css/style.css index 15b8d2e2..60f33012 100644 --- a/docs/source/_static/css/style.css +++ b/docs/source/_static/css/style.css @@ -1,5 +1,5 @@ /** - * Copyright 2022 MetaOPT Team. All Rights Reserved. + * Copyright 2022-2024 MetaOPT Team. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/examples/FuncTorch/parallel_train_torchopt.py b/examples/FuncTorch/parallel_train_torchopt.py index f28bded7..523515e1 100644 --- a/examples/FuncTorch/parallel_train_torchopt.py +++ b/examples/FuncTorch/parallel_train_torchopt.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/L2R/helpers/argument.py b/examples/L2R/helpers/argument.py index 5df9f314..7db6c982 100644 --- a/examples/L2R/helpers/argument.py +++ b/examples/L2R/helpers/argument.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/L2R/helpers/model.py b/examples/L2R/helpers/model.py index dbde0e8d..877ad50a 100644 --- a/examples/L2R/helpers/model.py +++ b/examples/L2R/helpers/model.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/L2R/l2r.py b/examples/L2R/l2r.py index 64990976..a0ae764b 100644 --- a/examples/L2R/l2r.py +++ b/examples/L2R/l2r.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/LOLA/helpers/argument.py b/examples/LOLA/helpers/argument.py index 39618134..ad53c056 100644 --- a/examples/LOLA/helpers/argument.py +++ b/examples/LOLA/helpers/argument.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/LOLA/helpers/env.py b/examples/LOLA/helpers/env.py index f496276e..e1576a7d 100644 --- a/examples/LOLA/helpers/env.py +++ b/examples/LOLA/helpers/env.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/LOLA/helpers/utils.py b/examples/LOLA/helpers/utils.py index 20f67be5..4dd436ec 100644 --- a/examples/LOLA/helpers/utils.py +++ b/examples/LOLA/helpers/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/LOLA/lola_dice.py b/examples/LOLA/lola_dice.py index 6dbaaf24..485de894 100644 --- a/examples/LOLA/lola_dice.py +++ b/examples/LOLA/lola_dice.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/LOLA/visualize.py b/examples/LOLA/visualize.py index 6dc54ddf..7af19b21 100755 --- a/examples/LOLA/visualize.py +++ b/examples/LOLA/visualize.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/MAML-RL/helpers/__init__.py b/examples/MAML-RL/helpers/__init__.py index 9855e0b3..31d45c37 100644 --- a/examples/MAML-RL/helpers/__init__.py +++ b/examples/MAML-RL/helpers/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/MAML-RL/helpers/policy.py b/examples/MAML-RL/helpers/policy.py index 0a43a5b1..0bf8e188 100644 --- a/examples/MAML-RL/helpers/policy.py +++ b/examples/MAML-RL/helpers/policy.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/MAML-RL/helpers/policy_torchrl.py b/examples/MAML-RL/helpers/policy_torchrl.py index 91bdb269..5bcb3863 100644 --- a/examples/MAML-RL/helpers/policy_torchrl.py +++ b/examples/MAML-RL/helpers/policy_torchrl.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/examples/MAML-RL/helpers/tabular_mdp.py b/examples/MAML-RL/helpers/tabular_mdp.py index f8feb7b7..0d8f5f7f 100644 --- a/examples/MAML-RL/helpers/tabular_mdp.py +++ b/examples/MAML-RL/helpers/tabular_mdp.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/MAML-RL/maml.py b/examples/MAML-RL/maml.py index 42fddbac..0cb57a92 100644 --- a/examples/MAML-RL/maml.py +++ b/examples/MAML-RL/maml.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/MAML-RL/maml_torchrl.py b/examples/MAML-RL/maml_torchrl.py index 225f73bc..56db91ef 100644 --- a/examples/MAML-RL/maml_torchrl.py +++ b/examples/MAML-RL/maml_torchrl.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/MGRL/mgrl.py b/examples/MGRL/mgrl.py index 49eb79c4..bdc4b140 100644 --- a/examples/MGRL/mgrl.py +++ b/examples/MGRL/mgrl.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/visualize.py b/examples/visualize.py index 5e08267f..067fa511 100644 --- a/examples/visualize.py +++ b/examples/visualize.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2f4ae731..30d18335 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/extension.cpp b/src/extension.cpp index 45880bf6..3d3d52b3 100644 --- a/src/extension.cpp +++ b/src/extension.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2024 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/torchopt/accelerated_op/_src/__init__.py b/torchopt/accelerated_op/_src/__init__.py index bbf0b4cd..8c2f7b03 100644 --- a/torchopt/accelerated_op/_src/__init__.py +++ b/torchopt/accelerated_op/_src/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/torchopt/diff/implicit/__init__.py b/torchopt/diff/implicit/__init__.py index 4e50b615..21737015 100644 --- a/torchopt/diff/implicit/__init__.py +++ b/torchopt/diff/implicit/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/distributed/autograd.py b/torchopt/distributed/autograd.py index 4e10d24e..f7da4f46 100644 --- a/torchopt/distributed/autograd.py +++ b/torchopt/distributed/autograd.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/distributed/world.py b/torchopt/distributed/world.py index a9821ee0..a61280c5 100644 --- a/torchopt/distributed/world.py +++ b/torchopt/distributed/world.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/linalg/__init__.py b/torchopt/linalg/__init__.py index 20dc16aa..fc499d67 100644 --- a/torchopt/linalg/__init__.py +++ b/torchopt/linalg/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/linalg/cg.py b/torchopt/linalg/cg.py index 42cb6bea..a82ff877 100644 --- a/torchopt/linalg/cg.py +++ b/torchopt/linalg/cg.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/linalg/ns.py b/torchopt/linalg/ns.py index ce49fe77..b049a5ad 100644 --- a/torchopt/linalg/ns.py +++ b/torchopt/linalg/ns.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/linalg/utils.py b/torchopt/linalg/utils.py index e3cd197e..a5ac765d 100644 --- a/torchopt/linalg/utils.py +++ b/torchopt/linalg/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/linear_solve/__init__.py b/torchopt/linear_solve/__init__.py index 8d9115d3..2d61eb6d 100644 --- a/torchopt/linear_solve/__init__.py +++ b/torchopt/linear_solve/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/torchopt/linear_solve/cg.py b/torchopt/linear_solve/cg.py index e8f9fb77..f4127639 100644 --- a/torchopt/linear_solve/cg.py +++ b/torchopt/linear_solve/cg.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/linear_solve/inv.py b/torchopt/linear_solve/inv.py index e2a377d5..f37be8c5 100644 --- a/torchopt/linear_solve/inv.py +++ b/torchopt/linear_solve/inv.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/linear_solve/normal_cg.py b/torchopt/linear_solve/normal_cg.py index 78813ecb..405ab43c 100644 --- a/torchopt/linear_solve/normal_cg.py +++ b/torchopt/linear_solve/normal_cg.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/linear_solve/utils.py b/torchopt/linear_solve/utils.py index 22dcec6f..5e4bf7bd 100644 --- a/torchopt/linear_solve/utils.py +++ b/torchopt/linear_solve/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/base.py b/torchopt/optim/base.py index d0be2fd1..bdaa0d67 100644 --- a/torchopt/optim/base.py +++ b/torchopt/optim/base.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/func/__init__.py b/torchopt/optim/func/__init__.py index f14fc6ae..f136f808 100644 --- a/torchopt/optim/func/__init__.py +++ b/torchopt/optim/func/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/meta/adadelta.py b/torchopt/optim/meta/adadelta.py index 36d8d9ad..49bdf23c 100644 --- a/torchopt/optim/meta/adadelta.py +++ b/torchopt/optim/meta/adadelta.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/meta/adam.py b/torchopt/optim/meta/adam.py index bd9804b9..bac71790 100644 --- a/torchopt/optim/meta/adam.py +++ b/torchopt/optim/meta/adam.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/torchopt/optim/meta/adamax.py b/torchopt/optim/meta/adamax.py index 01082af2..568a46f7 100644 --- a/torchopt/optim/meta/adamax.py +++ b/torchopt/optim/meta/adamax.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/meta/radam.py b/torchopt/optim/meta/radam.py index baf4cdd2..a32670d0 100644 --- a/torchopt/optim/meta/radam.py +++ b/torchopt/optim/meta/radam.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/meta/rmsprop.py b/torchopt/optim/meta/rmsprop.py index 3aff20e1..a8b4abfa 100644 --- a/torchopt/optim/meta/rmsprop.py +++ b/torchopt/optim/meta/rmsprop.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/meta/sgd.py b/torchopt/optim/meta/sgd.py index 476ed9d6..81e04413 100644 --- a/torchopt/optim/meta/sgd.py +++ b/torchopt/optim/meta/sgd.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/rmsprop.py b/torchopt/optim/rmsprop.py index 5c4e536f..032e5864 100644 --- a/torchopt/optim/rmsprop.py +++ b/torchopt/optim/rmsprop.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/sgd.py b/torchopt/optim/sgd.py index 3da9595a..27cd53c1 100644 --- a/torchopt/optim/sgd.py +++ b/torchopt/optim/sgd.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/update.py b/torchopt/update.py index 3a2a6984..8636d7a4 100644 --- a/torchopt/update.py +++ b/torchopt/update.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/version.py b/torchopt/version.py index 685735e6..a1618caf 100644 --- a/torchopt/version.py +++ b/torchopt/version.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
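The year-range bump in this patch is purely mechanical: every license header in the repository receives the same one-line change. A minimal sketch of how such a bulk update could be reproduced locally is below (a hypothetical helper script, not part of the TorchOpt codebase; it assumes it is run from the repository root):

```python
# Hypothetical helper (not part of TorchOpt): bump `Copyright 2022` license
# headers to `Copyright 2022-2024` across Python, C++, and CMake sources.
import pathlib
import re

PATTERN = re.compile(r'Copyright 2022 MetaOPT Team\. All Rights Reserved\.')
REPLACEMENT = 'Copyright 2022-2024 MetaOPT Team. All Rights Reserved.'
SUFFIXES = {'.py', '.cpp', '.h', '.cu', '.cuh', '.txt', '.rst'}

for path in pathlib.Path('.').rglob('*'):
    # Only touch tracked source-like files; skip directories and other suffixes.
    if not path.is_file() or path.suffix not in SUFFIXES:
        continue
    text = path.read_text(encoding='utf-8')
    updated = PATTERN.sub(REPLACEMENT, text)
    if updated != text:
        path.write_text(updated, encoding='utf-8')
```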
From 54d139b459ab3cc1f24ae6f24769e3fea6cd1f38 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 25 Mar 2024 09:03:12 +0000 Subject: [PATCH 16/26] chore(pre-commit): update pre-commit hooks --- .pre-commit-config.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d36f0472..a358c69f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,11 +26,11 @@ repos: - id: debug-statements - id: double-quote-string-fixer - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v18.1.1 + rev: v18.1.2 hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.3.3 + rev: v0.3.4 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] @@ -43,7 +43,7 @@ repos: hooks: - id: black-jupyter - repo: https://github.com/asottile/pyupgrade - rev: v3.15.1 + rev: v3.15.2 hooks: - id: pyupgrade args: [--py38-plus] # sync with requires-python From 3b6bb672fa65b6a9ab84ec183ddb947aa5eedb25 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Apr 2024 19:36:31 +0800 Subject: [PATCH 17/26] chore(pre-commit): [pre-commit.ci] autoupdate (#209) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.3.4 → v0.3.5](https://github.com/astral-sh/ruff-pre-commit/compare/v0.3.4...v0.3.5) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a358c69f..e150d058 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,7 +30,7 @@ repos: hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.3.4 + rev: v0.3.5 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] From 4c4c646eca156b3cf75ba0a9efe931fee1c1f470 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 May 2024 20:17:19 +0800 Subject: [PATCH 18/26] chore(pre-commit): [pre-commit.ci] autoupdate (#213) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/pre-commit/pre-commit-hooks: v4.5.0 → v4.6.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.5.0...v4.6.0) - [github.com/pre-commit/mirrors-clang-format: v18.1.2 → v18.1.4](https://github.com/pre-commit/mirrors-clang-format/compare/v18.1.2...v18.1.4) - [github.com/astral-sh/ruff-pre-commit: v0.3.5 → v0.4.3](https://github.com/astral-sh/ruff-pre-commit/compare/v0.3.5...v0.4.3) - [github.com/psf/black: 24.3.0 → 24.4.2](https://github.com/psf/black/compare/24.3.0...24.4.2) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e150d058..8e501e70 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ ci: default_stages: [commit, push, manual] repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: check-symlinks - id: destroyed-symlinks @@ -26,11 +26,11 @@ repos: - id: debug-statements - id: double-quote-string-fixer - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v18.1.2 + rev: v18.1.4 hooks: - 
id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.3.5 + rev: v0.4.3 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] @@ -39,7 +39,7 @@ repos: hooks: - id: isort - repo: https://github.com/psf/black - rev: 24.3.0 + rev: 24.4.2 hooks: - id: black-jupyter - repo: https://github.com/asottile/pyupgrade From 190ca7233089a8ca1d53bc448d4cd5ff8719e5c1 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 10 May 2024 17:19:14 +0800 Subject: [PATCH 19/26] deps(torch): support PyTorch 2.0+ only (#215) --- .github/workflows/build.yml | 4 +-- .github/workflows/lint.yml | 2 +- .github/workflows/tests.yml | 2 +- .pre-commit-config.yaml | 4 +-- CHANGELOG.md | 2 +- CMakeLists.txt | 2 +- Dockerfile | 2 +- Makefile | 8 ++++-- README.md | 6 ++-- conda-recipe-minimal-cpu.yaml | 8 +++--- conda-recipe-minimal.yaml | 18 ++++++------ conda-recipe.yaml | 40 +++++++++++++------------- docs/conda-recipe.yaml | 16 +++++------ docs/requirements.txt | 2 +- docs/source/developer/contributing.rst | 8 +++--- docs/source/index.rst | 2 +- examples/requirements.txt | 4 +-- pyproject.toml | 19 ++++++------ requirements.txt | 4 +-- tests/helpers.py | 2 +- tests/requirements.txt | 14 ++++----- tutorials/requirements.txt | 6 ++-- 22 files changed, 90 insertions(+), 85 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 894442dd..57857770 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -37,8 +37,8 @@ concurrency: cancel-in-progress: ${{ github.event_name == 'pull_request' }} env: - CUDA_VERSION: "11.7" - TEST_TORCH_SPECS: "cpu cu116" + CUDA_VERSION: "12.1" + TEST_TORCH_SPECS: "cpu cu118" jobs: build: diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 62f96340..472d5967 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -16,7 +16,7 @@ concurrency: cancel-in-progress: ${{ github.event_name == 'pull_request' }} env: - CUDA_VERSION: "11.7" + CUDA_VERSION: "12.1" jobs: lint: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6e51ecad..24d06fe7 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,7 +27,7 @@ concurrency: cancel-in-progress: ${{ github.event_name == 'pull_request' }} env: - CUDA_VERSION: "11.7" + CUDA_VERSION: "12.1" jobs: test: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8e501e70..621632c9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,11 +26,11 @@ repos: - id: debug-statements - id: double-quote-string-fixer - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v18.1.4 + rev: v18.1.5 hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.3 + rev: v0.4.4 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/CHANGELOG.md b/CHANGELOG.md index 315d24db..70b44628 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,7 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Removed -- +- Drop PyTorch 1.x support by [@XuehaiPan](https://github.com/XuehaiPan) in [#215](https://github.com/metaopt/torchopt/pull/215). 
------ diff --git a/CMakeLists.txt b/CMakeLists.txt index bf17f38c..13814dee 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,7 +17,7 @@ cmake_minimum_required(VERSION 3.11) # for FetchContent project(torchopt LANGUAGES CXX) include(FetchContent) -set(PYBIND11_VERSION v2.11.1) +set(PYBIND11_VERSION v2.12.0) if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) diff --git a/Dockerfile b/Dockerfile index 7295af74..246a81e9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,7 @@ # $ docker build --target devel --tag torchopt-devel:latest . # -ARG cuda_docker_tag="11.7.1-cudnn8-devel-ubuntu22.04" +ARG cuda_docker_tag="12.1.0-cudnn8-devel-ubuntu22.04" FROM nvidia/cuda:"${cuda_docker_tag}" AS builder ENV DEBIAN_FRONTEND=noninteractive diff --git a/Makefile b/Makefile index 1c42fcd7..fd25d8d9 100644 --- a/Makefile +++ b/Makefile @@ -236,7 +236,11 @@ docker-devel: docker: docker-base docker-devel docker-run-base: docker-base - docker run --network=host --gpus=all -v /:/host -h ubuntu -it $(PROJECT_NAME):$(COMMIT_HASH) + docker run -it --network=host --gpus=all --shm-size=4gb -v /:/host -v "${PWD}:/workspace" \ + -h ubuntu -w /workspace $(PROJECT_NAME):$(COMMIT_HASH) docker-run-devel: docker-devel - docker run --network=host --gpus=all -v /:/host -h ubuntu -it $(PROJECT_NAME)-devel:$(COMMIT_HASH) + docker run -it --network=host --gpus=all --shm-size=4gb -v /:/host -v "${PWD}:/workspace" \ + -h ubuntu -w /workspace $(PROJECT_NAME)-devel:$(COMMIT_HASH) + +docker-run: docker-run-base diff --git a/README.md b/README.md index 2300f619..91d44a25 100644 --- a/README.md +++ b/README.md @@ -425,11 +425,11 @@ Then run the following command to install TorchOpt from PyPI ([![PyPI](https://i pip3 install torchopt ``` -If the minimum version of PyTorch is not satisfied, `pip` will install/upgrade it for you. Please be careful about the `torch` build for CPU / CUDA support (e.g. `cpu`, `cu116`, `cu117`). +If the minimum version of PyTorch is not satisfied, `pip` will install/upgrade it for you. Please be careful about the `torch` build for CPU / CUDA support (e.g. `cpu`, `cu118`, `cu121`). You may need to specify the extra index URL for the `torch` package: ```bash -pip3 install torchopt --extra-index-url https://download.pytorch.org/whl/cu117 +pip3 install torchopt --extra-index-url https://download.pytorch.org/whl/cu121 ``` See for more information about installing PyTorch. @@ -450,7 +450,7 @@ git clone https://github.com/metaopt/torchopt.git cd torchopt # You may need `CONDA_OVERRIDE_CUDA` if conda fails to detect the NVIDIA driver (e.g. 
in docker or WSL2) -CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe-minimal.yaml +CONDA_OVERRIDE_CUDA=12.1 conda env create --file conda-recipe-minimal.yaml conda activate torchopt make install-editable # or run `pip3 install --no-build-isolation --editable .` diff --git a/conda-recipe-minimal-cpu.yaml b/conda-recipe-minimal-cpu.yaml index c872924d..dda60369 100644 --- a/conda-recipe-minimal-cpu.yaml +++ b/conda-recipe-minimal-cpu.yaml @@ -26,11 +26,11 @@ channels: - conda-forge dependencies: - - python = 3.10 + - python = 3.11 - pip # Learning - - pytorch::pytorch >= 1.13 # sync with project.dependencies + - pytorch::pytorch >= 2.0 # sync with project.dependencies - pytorch::torchvision - pytorch::pytorch-mutex = *=*cpu* - pip: @@ -40,10 +40,10 @@ dependencies: - cmake >= 3.11 - make - cxx-compiler - - pybind11 >= 2.10.1 + - pybind11 >= 2.11.1 # Misc - optree >= 0.4.1 - - typing-extensions >= 4.0.0 + - typing-extensions - numpy - python-graphviz diff --git a/conda-recipe-minimal.yaml b/conda-recipe-minimal.yaml index 39ff72cc..7e28d2ef 100644 --- a/conda-recipe-minimal.yaml +++ b/conda-recipe-minimal.yaml @@ -15,41 +15,41 @@ # # Create virtual environment with command: # -# $ CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe-minimal.yaml +# $ CONDA_OVERRIDE_CUDA=12.1 conda env create --file conda-recipe-minimal.yaml # name: torchopt channels: - pytorch - - nvidia/label/cuda-11.7.1 + - nvidia/label/cuda-12.1.0 - defaults - conda-forge dependencies: - - python = 3.10 + - python = 3.11 - pip # Learning - - pytorch::pytorch >= 1.13 # sync with project.dependencies + - pytorch::pytorch >= 2.0 # sync with project.dependencies - pytorch::torchvision - pytorch::pytorch-mutex = *=*cuda* - pip: - torchviz # Device select - - nvidia/label/cuda-11.7.1::cuda-toolkit = 11.7 + - nvidia/label/cuda-12.1.0::cuda-toolkit = 12.1 # Build toolchain - cmake >= 3.11 - make - cxx-compiler - - nvidia/label/cuda-11.7.1::cuda-nvcc - - nvidia/label/cuda-11.7.1::cuda-cudart-dev - - pybind11 >= 2.10.1 + - nvidia/label/cuda-12.1.0::cuda-nvcc + - nvidia/label/cuda-12.1.0::cuda-cudart-dev + - pybind11 >= 2.11.1 # Misc - optree >= 0.4.1 - - typing-extensions >= 4.0.0 + - typing-extensions - numpy - python-graphviz diff --git a/conda-recipe.yaml b/conda-recipe.yaml index c82eb3c6..9753852b 100644 --- a/conda-recipe.yaml +++ b/conda-recipe.yaml @@ -15,49 +15,49 @@ # # Create virtual environment with command: # -# $ CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe.yaml +# $ CONDA_OVERRIDE_CUDA=12.1 conda env create --file conda-recipe.yaml # name: torchopt channels: - pytorch - - nvidia/label/cuda-11.7.1 + - nvidia/label/cuda-12.1.0 - defaults - conda-forge dependencies: - - python = 3.10 + - python = 3.11 - pip # Learning - - pytorch::pytorch >= 1.13 # sync with project.dependencies + - pytorch::pytorch >= 2.0 # sync with project.dependencies - pytorch::torchvision - pytorch::pytorch-mutex = *=*cuda* - pip: - torchviz - sphinxcontrib-katex # for documentation - - jax # for tutorials - - jaxlib # for tutorials - - optax # for tutorials - - jaxopt # for tests + - conda-forge::jax # for tutorials + - conda-forge::jaxlib # for tutorials + - conda-forge::optax # for tutorials + - conda-forge::jaxopt # for tests - tensorboard # for examples # Device select - - nvidia/label/cuda-11.7.1::cuda-toolkit = 11.7 + - nvidia/label/cuda-12.1.0::cuda-toolkit = 12.1 # Build toolchain - cmake >= 3.11 - make - cxx-compiler - - nvidia/label/cuda-11.7.1::cuda-nvcc - - 
nvidia/label/cuda-11.7.1::cuda-cudart-dev + - nvidia/label/cuda-12.1.0::cuda-nvcc + - nvidia/label/cuda-12.1.0::cuda-cudart-dev - patchelf >= 0.14 - - pybind11 >= 2.10.1 + - pybind11 >= 2.11.1 # Misc - optree >= 0.4.1 - - typing-extensions >= 4.0.0 + - typing-extensions - numpy - matplotlib-base - seaborn @@ -83,10 +83,10 @@ dependencies: - pytest - pytest-cov - pytest-xdist - - isort >= 5.11.0 - - conda-forge::black-jupyter >= 22.6.0 - - pylint >= 2.15.0 - - mypy >= 0.990 + - isort + - conda-forge::black-jupyter + - pylint + - mypy - flake8 - flake8-bugbear - flake8-comprehensions @@ -96,8 +96,8 @@ dependencies: - ruff - doc8 - pydocstyle - - clang-format >= 14 - - clang-tools >= 14 # clang-tidy - - cpplint + - conda-forge::clang-format >= 14 + - conda-forge::clang-tools >= 14 # clang-tidy + - conda-forge::cpplint - conda-forge::pre-commit - conda-forge::identify diff --git a/docs/conda-recipe.yaml b/docs/conda-recipe.yaml index b6f1b580..d7d2f288 100644 --- a/docs/conda-recipe.yaml +++ b/docs/conda-recipe.yaml @@ -15,23 +15,23 @@ # # Create virtual environment with command: # -# $ CONDA_OVERRIDE_CUDA=11.7 conda env create --file docs/conda-recipe.yaml +# $ CONDA_OVERRIDE_CUDA=12.1 conda env create --file docs/conda-recipe.yaml # name: torchopt-docs channels: - pytorch - - nvidia/label/cuda-11.7.1 + - nvidia/label/cuda-12.1.0 - defaults - conda-forge dependencies: - - python = 3.10 + - python = 3.11 - pip # Learning - - pytorch::pytorch >= 1.13 # sync with project.dependencies + - pytorch::pytorch >= 2.0 # sync with project.dependencies - pytorch::cpuonly - pytorch::pytorch-mutex = *=*cpu* - pip: @@ -42,13 +42,13 @@ dependencies: - cmake >= 3.11 - make - cxx-compiler - - nvidia/label/cuda-11.7.1::cuda-nvcc - - nvidia/label/cuda-11.7.1::cuda-cudart-dev - - pybind11 >= 2.10.1 + - nvidia/label/cuda-12.1.0::cuda-nvcc + - nvidia/label/cuda-12.1.0::cuda-cudart-dev + - pybind11 >= 2.11.1 # Misc - optree >= 0.4.1 - - typing-extensions >= 4.0.0 + - typing-extensions - numpy - matplotlib-base - seaborn diff --git a/docs/requirements.txt b/docs/requirements.txt index 82bf2d91..c9631b75 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,6 @@ --extra-index-url https://download.pytorch.org/whl/cpu # Sync with project.dependencies -torch >= 1.13 +torch >= 2.0 --requirement ../requirements.txt diff --git a/docs/source/developer/contributing.rst b/docs/source/developer/contributing.rst index 4e7dd355..e40a564a 100644 --- a/docs/source/developer/contributing.rst +++ b/docs/source/developer/contributing.rst @@ -17,7 +17,7 @@ Before contributing to TorchOpt, please follow the instructions below to setup. .. code-block:: bash # You may need `CONDA_OVERRIDE_CUDA` if conda fails to detect the NVIDIA driver (e.g. in docker or WSL2) - CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe.yaml + CONDA_OVERRIDE_CUDA=12.1 conda env create --file conda-recipe.yaml conda activate torchopt @@ -91,14 +91,14 @@ To build compatible **manylinux2014** (:pep:`599`) wheels for distribution, you pip3 install --upgrade cibuildwheel - export TEST_TORCH_SPECS="cpu cu116" # `torch` builds for testing - export CUDA_VERSION="11.7" # version of `nvcc` for compilation + export TEST_TORCH_SPECS="cpu cu118" # `torch` builds for testing + export CUDA_VERSION="12.1" # version of `nvcc` for compilation python3 -m cibuildwheel --platform=linux --output-dir=wheelhouse --config-file=pyproject.toml It will install the CUDA compiler with ``CUDA_VERSION`` in the build container. 
Then build wheel binaries for all supported CPython versions. The outputs will be placed in the ``wheelhouse`` directory. To build a wheel for a specific CPython version, you can use the |CIBW_BUILD|_ environment variable. -For example, the following command will build a wheel for Python 3.7: +For example, the following command will build a wheel for Python 3.8: .. code-block:: bash diff --git a/docs/source/index.rst b/docs/source/index.rst index ec9a3749..83602090 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -42,7 +42,7 @@ You can use the following commands with `conda ` cd torchopt # You may need `CONDA_OVERRIDE_CUDA` if conda fails to detect the NVIDIA driver (e.g. in docker or WSL2) - CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe-minimal.yaml + CONDA_OVERRIDE_CUDA=12.1 conda env create --file conda-recipe-minimal.yaml conda activate torchopt diff --git a/examples/requirements.txt b/examples/requirements.txt index 76bed365..48945c62 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -1,5 +1,5 @@ ---extra-index-url https://download.pytorch.org/whl/cu117 -torch >= 1.13 +--extra-index-url https://download.pytorch.org/whl/cu121 +torch >= 2.0 torchvision --requirement ../requirements.txt diff --git a/pyproject.toml b/pyproject.toml index f3a5645b..17d14b05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ [build-system] # Sync with project.dependencies -requires = ["setuptools", "torch >= 1.13", "numpy", "pybind11 >= 2.10.1"] +requires = ["setuptools", "torch >= 2.0", "numpy", "pybind11 >= 2.11.1"] build-backend = "setuptools.build_meta" [project] @@ -51,11 +51,11 @@ classifiers = [ ] dependencies = [ # See also build-system.requires and project.requires-python - "torch >= 1.13", + "torch >= 2.0", "optree >= 0.4.1", "numpy", "graphviz", - "typing-extensions >= 4.0.0", + "typing-extensions", ] dynamic = ["version"] @@ -68,9 +68,9 @@ Documentation = "https://torchopt.readthedocs.io" [project.optional-dependencies] lint = [ "isort", - "black[jupyter] >= 22.6.0", - "pylint[spelling] >= 2.15.0", - "mypy >= 0.990", + "black[jupyter]", + "pylint[spelling]", + "mypy", "flake8", "flake8-bugbear", "flake8-comprehensions", @@ -88,7 +88,7 @@ test = [ "pytest", "pytest-cov", "pytest-xdist", - "jax[cpu] >= 0.3; platform_system != 'Windows'", + "jax[cpu] >= 0.4; platform_system != 'Windows'", "jaxopt; platform_system != 'Windows'", "optax; platform_system != 'Windows'", ] @@ -113,8 +113,8 @@ build-verbosity = 3 environment.USE_FP16 = "ON" environment.CUDACXX = "/usr/local/cuda/bin/nvcc" environment.TORCH_CUDA_ARCH_LIST = "Common" -environment.DEFAULT_CUDA_VERSION = "11.7" -environment.DEFAULT_TEST_TORCH_SPECS = "cpu cu116" +environment.DEFAULT_CUDA_VERSION = "12.1" +environment.DEFAULT_TEST_TORCH_SPECS = "cpu cu118" environment-pass = ["CUDA_VERSION", "TEST_TORCH_SPECS"] container-engine = "docker" test-extras = ["test"] @@ -316,4 +316,5 @@ filterwarnings = [ 'ignore:jax\.numpy\.DeviceArray is deprecated\. 
Use jax\.Array\.:DeprecationWarning', 'ignore:.*functorch.*deprecate.*:UserWarning', 'ignore:.*Apple Paravirtual device.*:UserWarning', + 'ignore:.*NVML.*:UserWarning', ] diff --git a/requirements.txt b/requirements.txt index 961ddf73..a5151c36 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # Sync with project.dependencies -torch >= 1.13 +torch >= 2.0 optree >= 0.4.1 numpy graphviz -typing-extensions >= 4.0.0 +typing-extensions diff --git a/tests/helpers.py b/tests/helpers.py index 1b624ce9..0dc415d4 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -65,7 +65,7 @@ def parametrize(**argvalues) -> pytest.mark.parametrize: argvalues = list(itertools.product(*tuple(map(argvalues.get, arguments)))) ids = tuple( - '-'.join(f'{arg}({val})' for arg, val in zip(arguments, values)) for values in argvalues + '-'.join(f'{arg}({val!r})' for arg, val in zip(arguments, values)) for values in argvalues ) return pytest.mark.parametrize(arguments, argvalues, ids=ids) diff --git a/tests/requirements.txt b/tests/requirements.txt index 397ac350..ee54732b 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,9 +1,9 @@ ---extra-index-url https://download.pytorch.org/whl/cu117 -torch >= 1.13 +--extra-index-url https://download.pytorch.org/whl/cu121 +torch >= 2.0 --requirement ../requirements.txt -jax[cpu] >= 0.3; platform_system != 'Windows' +jax[cpu] >= 0.4; platform_system != 'Windows' jaxopt; platform_system != 'Windows' optax < 0.1.8a0; platform_system != 'Windows' and python_version < '3.9' optax >= 0.1.8; platform_system != 'Windows' and python_version >= '3.9' @@ -11,10 +11,10 @@ optax >= 0.1.8; platform_system != 'Windows' and python_version >= '3.9' pytest pytest-cov pytest-xdist -isort >= 5.11.0 -black[jupyter] >= 22.6.0 -pylint[spelling] >= 2.15.0 -mypy >= 0.990 +isort +black[jupyter] +pylint[spelling] +mypy flake8 flake8-bugbear flake8-comprehensions diff --git a/tutorials/requirements.txt b/tutorials/requirements.txt index ff5a5c42..e8a3be95 100644 --- a/tutorials/requirements.txt +++ b/tutorials/requirements.txt @@ -1,11 +1,11 @@ ---extra-index-url https://download.pytorch.org/whl/cu117 +--extra-index-url https://download.pytorch.org/whl/cu121 # Sync with project.dependencies -torch >= 1.13 +torch >= 2.0 torchvision --requirement ../requirements.txt ipykernel -jax[cpu] >= 0.3 +jax[cpu] >= 0.4 jaxopt optax From 330ef575eed39a24ead182d8fd853b776fc39f96 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 10 May 2024 18:30:38 +0800 Subject: [PATCH 20/26] refactor(setup.py): refactor build system (#214) --- CHANGELOG.md | 2 +- setup.py | 81 +++++++++++++++++++++++++++++----------------------- 2 files changed, 47 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 70b44628..6489696c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,7 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- +- Refactor the raw import statement in `setup.py` with `importlib` utilities by [@XuehaiPan](https://github.com/XuehaiPan) in [#214](https://github.com/metaopt/torchopt/pull/214). 
### Fixed diff --git a/setup.py b/setup.py index dc1103df..c50ba5ed 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,4 @@ +import contextlib import os import pathlib import platform @@ -5,22 +6,13 @@ import shutil import sys import sysconfig +from importlib.util import module_from_spec, spec_from_file_location -from setuptools import setup +from setuptools import Extension, setup +from setuptools.command.build_ext import build_ext -try: - from pybind11.setup_helpers import Pybind11Extension as Extension - from pybind11.setup_helpers import build_ext -except ImportError: - from setuptools import Extension - from setuptools.command.build_ext import build_ext - HERE = pathlib.Path(__file__).absolute().parent -VERSION_FILE = HERE / 'torchopt' / 'version.py' - -sys.path.insert(0, str(VERSION_FILE.parent)) -import version # noqa class CMakeExtension(Extension): @@ -47,7 +39,6 @@ def build_extension(self, ext): build_temp.mkdir(parents=True, exist_ok=True) config = 'Debug' if self.debug else 'Release' - cmake_args = [ f'-DCMAKE_BUILD_TYPE={config}', f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{config.upper()}={ext_path.parent}', @@ -83,13 +74,53 @@ def build_extension(self, ext): build_args.extend(['--target', ext.target, '--']) + cwd = os.getcwd() try: os.chdir(build_temp) self.spawn([cmake, ext.source_dir, *cmake_args]) if not self.dry_run: self.spawn([cmake, '--build', '.', *build_args]) finally: - os.chdir(HERE) + os.chdir(cwd) + + +@contextlib.contextmanager +def vcs_version(name, path): + path = pathlib.Path(path).absolute() + assert path.is_file() + module_spec = spec_from_file_location(name=name, location=path) + assert module_spec is not None + assert module_spec.loader is not None + module = sys.modules.get(name) + if module is None: + module = module_from_spec(module_spec) + sys.modules[name] = module + module_spec.loader.exec_module(module) + + if module.__release__: + yield module + return + + content = None + try: + try: + content = path.read_text(encoding='utf-8') + path.write_text( + data=re.sub( + r"""__version__\s*=\s*('[^']+'|"[^"]+")""", + f'__version__ = {module.__version__!r}', + string=content, + ), + encoding='utf-8', + ) + except OSError: + content = None + + yield module + finally: + if content is not None: + with path.open(mode='wt', encoding='utf-8', newline='') as file: + file.write(content) CIBUILDWHEEL = os.getenv('CIBUILDWHEEL', '0') == '1' @@ -112,29 +143,9 @@ def build_extension(self, ext): ext_kwargs.clear() -VERSION_CONTENT = None - -try: - if not version.__release__: - try: - VERSION_CONTENT = VERSION_FILE.read_text(encoding='utf-8') - VERSION_FILE.write_text( - data=re.sub( - r"""__version__\s*=\s*('[^']+'|"[^"]+")""", - f'__version__ = {version.__version__!r}', - string=VERSION_CONTENT, - ), - encoding='utf-8', - ) - except OSError: - VERSION_CONTENT = None - +with vcs_version(name='torchopt.version', path=(HERE / 'torchopt' / 'version.py')) as version: setup( name='torchopt', version=version.__version__, **ext_kwargs, ) -finally: - if VERSION_CONTENT is not None: - with VERSION_FILE.open(mode='wt', encoding='utf-8', newline='') as file: - file.write(VERSION_CONTENT) From d81999cd1faedb60d019249ed983e21dfe4f9622 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 10 May 2024 23:56:55 +0800 Subject: [PATCH 21/26] feat(workflows): enable Python 3.12 build (#216) --- .github/workflows/build.yml | 21 +++++++++++++++------ CHANGELOG.md | 2 +- Makefile | 5 +++-- pyproject.toml | 7 +++++-- 4 files changed, 24 insertions(+), 11 deletions(-) diff --git 
a/.github/workflows/build.yml b/.github/workflows/build.yml index 57857770..62497276 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -56,9 +56,12 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: "3.8 - 3.11" # sync with requires-python in pyproject.toml + python-version: "3.8 - 3.12" # sync with requires-python in pyproject.toml update-environment: true + - name: Install dependencies + run: python -m pip install --upgrade pip setuptools wheel build + - name: Set __release__ if: | startsWith(github.ref, 'refs/tags/') || @@ -69,9 +72,6 @@ jobs: - name: Print version run: python setup.py --version - - name: Install dependencies - run: python -m pip install --upgrade pip setuptools wheel build - - name: Build sdist and pure-Python wheel run: python -m build env: @@ -120,6 +120,9 @@ jobs: python-version: ${{ matrix.python-version }} update-environment: true + - name: Install dependencies + run: python -m pip install --upgrade pip setuptools wheel build + - name: Set __release__ if: | startsWith(github.ref, 'refs/tags/') || @@ -155,7 +158,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: ["3.9", "3.10", "3.11"] # sync with requires-python in pyproject.toml + python-version: ["3.9", "3.10", "3.11", "3.12"] # sync with requires-python in pyproject.toml fail-fast: false timeout-minutes: 60 steps: @@ -171,6 +174,9 @@ jobs: python-version: ${{ matrix.python-version }} update-environment: true + - name: Install dependencies + run: python -m pip install --upgrade pip setuptools wheel build + - name: Set __release__ if: | startsWith(github.ref, 'refs/tags/') || @@ -249,9 +255,12 @@ jobs: uses: actions/setup-python@v5 if: startsWith(github.ref, 'refs/tags/') with: - python-version: "3.8 - 3.11" # sync with requires-python in pyproject.toml + python-version: "3.8 - 3.12" # sync with requires-python in pyproject.toml update-environment: true + - name: Install dependencies + run: python -m pip install --upgrade pip setuptools wheel build + - name: Set __release__ if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' run: | diff --git a/CHANGELOG.md b/CHANGELOG.md index 6489696c..62234c25 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- +- Enable CI workflow to build CXX/CUDA extension for Python 3.12 by [@XuehaiPan](https://github.com/XuehaiPan) in [#216](https://github.com/metaopt/torchopt/pull/216). ### Changed diff --git a/Makefile b/Makefile index fd25d8d9..e9099f0c 100644 --- a/Makefile +++ b/Makefile @@ -22,7 +22,8 @@ install: install-editable: $(PYTHON) -m pip install --upgrade pip $(PYTHON) -m pip install --upgrade setuptools wheel - $(PYTHON) -m pip install torch numpy pybind11 cmake + $(PYTHON) -m pip install --upgrade pybind11 cmake + $(PYTHON) -m pip install torch numpy USE_FP16=ON TORCH_CUDA_ARCH_LIST=Auto $(PYTHON) -m pip install -vvv --no-build-isolation --editable . install-e: install-editable # alias @@ -114,7 +115,7 @@ addlicense-install: go-install pytest: test-install $(PYTHON) -m pytest --version cd tests && $(PYTHON) -c 'import $(PROJECT_PATH)' && \ - $(PYTHON) -m pytest --verbose --color=yes --durations=0 \ + $(PYTHON) -m pytest --verbose --color=yes \ --cov="$(PROJECT_PATH)" --cov-config=.coveragerc --cov-report=xml --cov-report=term-missing \ $(PYTESTOPTS) . 
diff --git a/pyproject.toml b/pyproject.toml index 17d14b05..ed93944a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ authors = [ license = { text = "Apache License, Version 2.0" } keywords = [ "PyTorch", - "functorch", + "FuncTorch", "JAX", "Meta-Learning", "Optimizer", @@ -38,6 +38,8 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", "Operating System :: Microsoft :: Windows", "Operating System :: POSIX :: Linux", "Operating System :: MacOS", @@ -179,7 +181,7 @@ test-command = """ line-length = 100 skip-string-normalization = true # Sync with requires-python -target-version = ["py38", "py39", "py310", "py311"] +target-version = ["py38"] [tool.isort] atomic = true @@ -314,6 +316,7 @@ filterwarnings = [ "error", 'ignore:Explicitly requested dtype float64 requested in .* is not available, and will be truncated to dtype float32\.:UserWarning', 'ignore:jax\.numpy\.DeviceArray is deprecated\. Use jax\.Array\.:DeprecationWarning', + 'ignore:(ast\.Str|ast\.NameConstant|Attribute s) is deprecated and will be removed in Python 3\.14:DeprecationWarning', 'ignore:.*functorch.*deprecate.*:UserWarning', 'ignore:.*Apple Paravirtual device.*:UserWarning', 'ignore:.*NVML.*:UserWarning', From b3f570c234101d493a32a579e1f45facc581e9aa Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 13 May 2024 14:26:22 +0800 Subject: [PATCH 22/26] deps(workflows): bump pypa/cibuildwheel from 2.17 to 2.18 (#217) Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 2.17 to 2.18. - [Release notes](https://github.com/pypa/cibuildwheel/releases) - [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md) - [Commits](https://github.com/pypa/cibuildwheel/compare/v2.17...v2.18) --- updated-dependencies: - dependency-name: pypa/cibuildwheel dependency-type: direct:production update-type: version-update:semver-minor ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/build.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 62497276..69318735 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -136,7 +136,7 @@ jobs: run: python .github/workflows/set_cibw_build.py - name: Build wheels - uses: pypa/cibuildwheel@v2.17 + uses: pypa/cibuildwheel@v2.18 env: CIBW_BUILD: ${{ env.CIBW_BUILD }} with: @@ -190,7 +190,7 @@ jobs: run: python .github/workflows/set_cibw_build.py - name: Build wheels - uses: pypa/cibuildwheel@v2.17 + uses: pypa/cibuildwheel@v2.18 env: CIBW_BUILD: ${{ env.CIBW_BUILD }} with: From a4cfc49f27fd623526db1cb13d823ac7ebc5ab41 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 24 May 2024 14:54:25 +0000 Subject: [PATCH 23/26] chore(pre-commit): update pre-commit hooks --- .pre-commit-config.yaml | 4 ++-- CODE_OF_CONDUCT.md | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 621632c9..4e930f5a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,7 +30,7 @@ repos: hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.4 + rev: v0.4.5 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] @@ -68,7 +68,7 @@ repos: ^docs/source/conf.py$ ) - repo: https://github.com/codespell-project/codespell - rev: v2.2.6 + rev: v2.3.0 hooks: - id: codespell additional_dependencies: [".[toml]"] diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index cc4a4e9a..4c96e694 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -5,7 +5,7 @@ We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender -identity and expression, level of experience, education, socio-economic status, +identity and expression, level of experience, education, socioeconomic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 
From d24792d1c96643deda3e9adf104b4a7befebd7e8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Jun 2024 02:07:46 +0800 Subject: [PATCH 24/26] chore(pre-commit): [pre-commit.ci] autoupdate (#224) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.4.5 → v0.4.7](https://github.com/astral-sh/ruff-pre-commit/compare/v0.4.5...v0.4.7) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4e930f5a..f8419466 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,7 +30,7 @@ repos: hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.5 + rev: v0.4.7 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] From 7a605a9c915ccfd10717f4488eccf53adbd463a2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 17 Jun 2024 17:30:15 +0800 Subject: [PATCH 25/26] deps(workflows): bump pypa/cibuildwheel from 2.18 to 2.19 (#225) * deps(workflows): bump pypa/cibuildwheel from 2.18 to 2.19 Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 2.18 to 2.19. - [Release notes](https://github.com/pypa/cibuildwheel/releases) - [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md) - [Commits](https://github.com/pypa/cibuildwheel/compare/v2.18...v2.19) --- updated-dependencies: - dependency-name: pypa/cibuildwheel dependency-type: direct:production update-type: version-update:semver-minor ... 
Signed-off-by: dependabot[bot] * chore(pre-commit): update pre-commit hooks * chore(CMakeLists.txt): set default `pybind11` tag to stable --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Xuehai Pan --- .github/workflows/build.yml | 4 +- .github/workflows/tests.yml | 8 +-- .pre-commit-config.yaml | 8 +-- CMakeLists.txt | 18 +++++-- pyproject.toml | 17 ++++-- tests/helpers.py | 7 ++- tests/test_alias.py | 7 ++- tests/test_implicit.py | 23 +++++--- tests/test_utils.py | 6 ++- torchopt/__init__.py | 66 +++++++++++------------ torchopt/accelerated_op/__init__.py | 9 ++-- torchopt/accelerated_op/_src/adam_op.py | 6 ++- torchopt/alias/__init__.py | 11 +++- torchopt/alias/adadelta.py | 7 ++- torchopt/alias/adam.py | 7 ++- torchopt/alias/adamax.py | 7 ++- torchopt/alias/adamw.py | 7 ++- torchopt/alias/radam.py | 7 ++- torchopt/alias/utils.py | 12 +++-- torchopt/base.py | 4 +- torchopt/clip.py | 7 ++- torchopt/combine.py | 7 ++- torchopt/diff/implicit/__init__.py | 2 +- torchopt/diff/implicit/decorator.py | 23 ++++---- torchopt/diff/implicit/nn/module.py | 10 ++-- torchopt/diff/zero_order/__init__.py | 2 +- torchopt/diff/zero_order/decorator.py | 31 +++++------ torchopt/diff/zero_order/nn/module.py | 7 ++- torchopt/distributed/api.py | 12 ++--- torchopt/distributed/autograd.py | 11 ++-- torchopt/distributed/world.py | 16 +++--- torchopt/hook.py | 13 +++-- torchopt/linalg/cg.py | 7 ++- torchopt/linalg/ns.py | 7 ++- torchopt/linalg/utils.py | 7 ++- torchopt/linear_solve/__init__.py | 2 +- torchopt/linear_solve/cg.py | 7 ++- torchopt/linear_solve/inv.py | 7 ++- torchopt/linear_solve/normal_cg.py | 7 ++- torchopt/linear_solve/utils.py | 7 ++- torchopt/nn/__init__.py | 4 +- torchopt/nn/module.py | 11 ++-- torchopt/nn/stateless.py | 10 ++-- torchopt/optim/adadelta.py | 11 ++-- torchopt/optim/adagrad.py | 11 ++-- torchopt/optim/adam.py | 11 ++-- torchopt/optim/adamax.py | 11 ++-- torchopt/optim/adamw.py | 11 ++-- torchopt/optim/func/base.py | 7 ++- torchopt/optim/meta/adadelta.py | 9 +++- torchopt/optim/meta/adagrad.py | 9 +++- torchopt/optim/meta/adam.py | 9 +++- torchopt/optim/meta/adamax.py | 9 +++- torchopt/optim/meta/adamw.py | 11 ++-- torchopt/optim/meta/radam.py | 9 +++- torchopt/optim/radam.py | 11 ++-- torchopt/pytree.py | 6 ++- torchopt/schedule/__init__.py | 2 +- torchopt/schedule/exponential_decay.py | 10 ++-- torchopt/schedule/polynomial.py | 9 +++- torchopt/transform/__init__.py | 14 ++--- torchopt/transform/add_decayed_weights.py | 15 +++--- torchopt/transform/nan_to_num.py | 9 +++- torchopt/transform/scale.py | 9 +++- torchopt/transform/scale_by_adadelta.py | 7 ++- torchopt/transform/scale_by_adam.py | 11 ++-- torchopt/transform/scale_by_adamax.py | 7 ++- torchopt/transform/scale_by_radam.py | 9 ++-- torchopt/transform/scale_by_rms.py | 7 ++- torchopt/transform/scale_by_rss.py | 7 ++- torchopt/transform/scale_by_schedule.py | 7 ++- torchopt/transform/scale_by_stddev.py | 7 ++- torchopt/transform/trace.py | 11 ++-- torchopt/transform/utils.py | 11 ++-- torchopt/typing.py | 52 +++++++++--------- torchopt/update.py | 9 +++- torchopt/utils.py | 12 ++--- torchopt/visual.py | 11 ++-- 78 files changed, 538 insertions(+), 291 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 69318735..99b553e4 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -136,7 +136,7 @@ jobs: run: python .github/workflows/set_cibw_build.py - name: Build wheels - uses: 
pypa/cibuildwheel@v2.18 + uses: pypa/cibuildwheel@v2.19 env: CIBW_BUILD: ${{ env.CIBW_BUILD }} with: @@ -190,7 +190,7 @@ jobs: run: python .github/workflows/set_cibw_build.py - name: Build wheels - uses: pypa/cibuildwheel@v2.18 + uses: pypa/cibuildwheel@v2.19 env: CIBW_BUILD: ${{ env.CIBW_BUILD }} with: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 24d06fe7..f156ffe3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -80,15 +80,15 @@ jobs: USE_FP16: "ON" TORCH_CUDA_ARCH_LIST: "Common" run: | - python -m pip install -vvv -e . + python -m pip install -vvv --editable . - name: Test with pytest run: | make pytest - name: Upload coverage to Codecov - if: runner.os == 'Linux' uses: codecov/codecov-action@v4 + if: ${{ matrix.os == 'ubuntu-latest' }} with: token: ${{ secrets.CODECOV_TOKEN }} file: ./tests/coverage.xml @@ -127,7 +127,7 @@ jobs: - name: Install TorchOpt run: | - python -m pip install -vvv -e . + python -m pip install -vvv --editable . env: TORCHOPT_NO_EXTENSIONS: "true" @@ -136,8 +136,8 @@ jobs: make pytest - name: Upload coverage to Codecov - if: runner.os == 'Linux' uses: codecov/codecov-action@v4 + if: ${{ matrix.os == 'ubuntu-latest' }} with: token: ${{ secrets.CODECOV_TOKEN }} file: ./tests/coverage.xml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f8419466..4814c681 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,11 +26,11 @@ repos: - id: debug-statements - id: double-quote-string-fixer - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v18.1.5 + rev: v18.1.6 hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.7 + rev: v0.4.9 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] @@ -43,7 +43,7 @@ repos: hooks: - id: black-jupyter - repo: https://github.com/asottile/pyupgrade - rev: v3.15.2 + rev: v3.16.0 hooks: - id: pyupgrade args: [--py38-plus] # sync with requires-python @@ -52,7 +52,7 @@ repos: ^examples/ ) - repo: https://github.com/pycqa/flake8 - rev: 7.0.0 + rev: 7.1.0 hooks: - id: flake8 additional_dependencies: diff --git a/CMakeLists.txt b/CMakeLists.txt index 13814dee..101ba3ec 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,7 +17,14 @@ cmake_minimum_required(VERSION 3.11) # for FetchContent project(torchopt LANGUAGES CXX) include(FetchContent) -set(PYBIND11_VERSION v2.12.0) + +set(THIRD_PARTY_DIR "${CMAKE_SOURCE_DIR}/third-party") +if(NOT DEFINED PYBIND11_VERSION AND NOT "$ENV{PYBIND11_VERSION}" STREQUAL "") + set(PYBIND11_VERSION "$ENV{PYBIND11_VERSION}") +endif() +if(NOT PYBIND11_VERSION) + set(PYBIND11_VERSION stable) +endif() if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) @@ -172,7 +179,7 @@ endif() system( STRIP OUTPUT_VARIABLE PYTHON_VERSION - COMMAND "${PYTHON_EXECUTABLE}" -c "print(__import__('platform').python_version())" + COMMAND "${PYTHON_EXECUTABLE}" -c "print('.'.join(map(str, __import__('sys').version_info[:3])))" ) message(STATUS "Use Python version: ${PYTHON_VERSION}") @@ -216,11 +223,12 @@ if("${PYBIND11_CMAKE_DIR}" STREQUAL "") GIT_REPOSITORY https://github.com/pybind/pybind11.git GIT_TAG "${PYBIND11_VERSION}" GIT_SHALLOW TRUE - SOURCE_DIR "${CMAKE_SOURCE_DIR}/third-party/pybind11" - BINARY_DIR "${CMAKE_SOURCE_DIR}/third-party/.cmake/pybind11/build" - STAMP_DIR "${CMAKE_SOURCE_DIR}/third-party/.cmake/pybind11/stamp" + SOURCE_DIR "${THIRD_PARTY_DIR}/pybind11" + BINARY_DIR "${THIRD_PARTY_DIR}/.cmake/pybind11/build" + STAMP_DIR "${THIRD_PARTY_DIR}/.cmake/pybind11/stamp" ) 
FetchContent_GetProperties(pybind11) + if(NOT pybind11_POPULATED) message(STATUS "Populating Git repository pybind11@${PYBIND11_VERSION} to third-party/pybind11...") FetchContent_MakeAvailable(pybind11) diff --git a/pyproject.toml b/pyproject.toml index ed93944a..d343e04a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -235,6 +235,7 @@ extend-exclude = ["examples"] select = [ "E", "W", # pycodestyle "F", # pyflakes + "C90", # mccabe "UP", # pyupgrade "ANN", # flake8-annotations "S", # flake8-bandit @@ -243,7 +244,10 @@ select = [ "COM", # flake8-commas "C4", # flake8-comprehensions "EXE", # flake8-executable + "FA", # flake8-future-annotations + "LOG", # flake8-logging "ISC", # flake8-implicit-str-concat + "INP", # flake8-no-pep420 "PIE", # flake8-pie "PYI", # flake8-pyi "Q", # flake8-quotes @@ -251,6 +255,10 @@ select = [ "RET", # flake8-return "SIM", # flake8-simplify "TID", # flake8-tidy-imports + "TCH", # flake8-type-checking + "PERF", # perflint + "FURB", # refurb + "TRY", # tryceratops "RUF", # ruff ] ignore = [ @@ -268,9 +276,9 @@ ignore = [ # S101: use of `assert` detected # internal use and may never raise at runtime "S101", - # PLR0402: use from {module} import {name} in lieu of alias - # use alias for import convention (e.g., `import torch.nn as nn`) - "PLR0402", + # TRY003: avoid specifying long messages outside the exception class + # long messages are necessary for clarity + "TRY003", ] typing-modules = ["torchopt.typing"] @@ -296,6 +304,9 @@ typing-modules = ["torchopt.typing"] "F401", # unused-import "F811", # redefined-while-unused ] +"docs/source/conf.py" = [ + "INP001", # flake8-no-pep420 +] [tool.ruff.lint.flake8-annotations] allow-star-arg-any = true diff --git a/tests/helpers.py b/tests/helpers.py index 0dc415d4..ca5aa443 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -20,7 +20,7 @@ import itertools import os import random -from typing import Iterable +from typing import TYPE_CHECKING, Iterable import numpy as np import pytest @@ -30,7 +30,10 @@ from torch.utils import data from torchopt import pytree -from torchopt.typing import TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree BATCH_SIZE = 64 diff --git a/tests/test_alias.py b/tests/test_alias.py index 58b5a328..3c42d7c8 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -15,7 +15,7 @@ from __future__ import annotations -from typing import Callable +from typing import TYPE_CHECKING, Callable import functorch import pytest @@ -26,7 +26,10 @@ import torchopt from torchopt import pytree from torchopt.alias.utils import _set_use_chain_flat -from torchopt.typing import TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree @helpers.parametrize( diff --git a/tests/test_implicit.py b/tests/test_implicit.py index ff0ba15c..6cccb716 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -18,7 +18,7 @@ import copy import re from collections import OrderedDict -from types import FunctionType +from typing import TYPE_CHECKING import functorch import numpy as np @@ -47,6 +47,10 @@ HAS_JAX = False +if TYPE_CHECKING: + from types import FunctionType + + BATCH_SIZE = 8 NUM_UPDATES = 3 @@ -123,7 +127,7 @@ def get_rr_dataset_torch() -> data.DataLoader: inner_lr=[2e-2, 2e-3], inner_update=[20, 50, 100], ) -def test_imaml_solve_normal_cg( +def test_imaml_solve_normal_cg( # noqa: C901 dtype: torch.dtype, lr: float, inner_lr: float, @@ -251,7 +255,7 @@ def outer_level(p, xs, ys): inner_update=[20, 50, 100], ns=[False, True], ) -def test_imaml_solve_inv( 
+def test_imaml_solve_inv( # noqa: C901 dtype: torch.dtype, lr: float, inner_lr: float, @@ -375,7 +379,12 @@ def outer_level(p, xs, ys): inner_lr=[2e-2, 2e-3], inner_update=[20, 50, 100], ) -def test_imaml_module(dtype: torch.dtype, lr: float, inner_lr: float, inner_update: int) -> None: +def test_imaml_module( # noqa: C901 + dtype: torch.dtype, + lr: float, + inner_lr: float, + inner_update: int, +) -> None: np_dtype = helpers.dtype_torch2numpy(dtype) jax_model, jax_params = get_model_jax(dtype=np_dtype) @@ -763,7 +772,7 @@ def solve(self): make_optimality_from_objective(MyModule2) -def test_module_abstract_methods() -> None: +def test_module_abstract_methods() -> None: # noqa: C901 class MyModule1(torchopt.nn.ImplicitMetaGradientModule): def objective(self): return torch.tensor(0.0) @@ -809,7 +818,7 @@ def solve(self): class MyModule5(torchopt.nn.ImplicitMetaGradientModule): @classmethod - def optimality(self): + def optimality(cls): return () def solve(self): @@ -846,7 +855,7 @@ def solve(self): class MyModule8(torchopt.nn.ImplicitMetaGradientModule): @classmethod - def objective(self): + def objective(cls): return () def solve(self): diff --git a/tests/test_utils.py b/tests/test_utils.py index 5215e7b3..57c35e47 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -13,6 +13,8 @@ # limitations under the License. # ============================================================================== +import operator + import torch import torchopt @@ -80,7 +82,7 @@ def test_module_clone() -> None: assert y.is_cuda -def test_extract_state_dict(): +def test_extract_state_dict(): # noqa: C901 fc = torch.nn.Linear(1, 1) state_dict = torchopt.extract_state_dict(fc, by='reference', device=torch.device('meta')) for param_dict in state_dict.params: @@ -121,7 +123,7 @@ def test_extract_state_dict(): loss = fc(torch.ones(1, 1)).sum() optim.step(loss) state_dict = torchopt.extract_state_dict(optim) - same = pytree.tree_map(lambda x, y: x is y, state_dict, tuple(optim.state_groups)) + same = pytree.tree_map(operator.is_, state_dict, tuple(optim.state_groups)) assert all(pytree.tree_flatten(same)[0]) diff --git a/torchopt/__init__.py b/torchopt/__init__.py index 5e568526..830072e3 100644 --- a/torchopt/__init__.py +++ b/torchopt/__init__.py @@ -81,50 +81,50 @@ __all__ = [ - 'accelerated_op_available', - 'adam', - 'adamax', - 'adadelta', - 'radam', - 'adamw', - 'adagrad', - 'rmsprop', - 'sgd', - 'clip_grad_norm', - 'nan_to_num', - 'register_hook', - 'chain', - 'Optimizer', 'SGD', - 'Adam', - 'AdaMax', - 'Adamax', 'AdaDelta', - 'Adadelta', - 'RAdam', - 'AdamW', 'AdaGrad', + 'AdaMax', + 'Adadelta', 'Adagrad', - 'RMSProp', - 'RMSprop', - 'MetaOptimizer', - 'MetaSGD', - 'MetaAdam', - 'MetaAdaMax', - 'MetaAdamax', + 'Adam', + 'AdamW', + 'Adamax', + 'FuncOptimizer', 'MetaAdaDelta', - 'MetaAdadelta', - 'MetaRAdam', - 'MetaAdamW', 'MetaAdaGrad', + 'MetaAdaMax', + 'MetaAdadelta', 'MetaAdagrad', + 'MetaAdam', + 'MetaAdamW', + 'MetaAdamax', + 'MetaOptimizer', + 'MetaRAdam', 'MetaRMSProp', 'MetaRMSprop', - 'FuncOptimizer', + 'MetaSGD', + 'Optimizer', + 'RAdam', + 'RMSProp', + 'RMSprop', + 'accelerated_op_available', + 'adadelta', + 'adagrad', + 'adam', + 'adamax', + 'adamw', 'apply_updates', + 'chain', + 'clip_grad_norm', 'extract_state_dict', - 'recover_state_dict', - 'stop_gradient', 'module_clone', 'module_detach_', + 'nan_to_num', + 'radam', + 'recover_state_dict', + 'register_hook', + 'rmsprop', + 'sgd', + 'stop_gradient', ] diff --git a/torchopt/accelerated_op/__init__.py 
b/torchopt/accelerated_op/__init__.py index 103b6fc0..90452046 100644 --- a/torchopt/accelerated_op/__init__.py +++ b/torchopt/accelerated_op/__init__.py @@ -16,12 +16,15 @@ from __future__ import annotations -from typing import Iterable +from typing import TYPE_CHECKING, Iterable import torch from torchopt.accelerated_op.adam_op import AdamOp -from torchopt.typing import Device + + +if TYPE_CHECKING: + from torchopt.typing import Device def is_available(devices: Device | Iterable[Device] | None = None) -> bool: @@ -42,6 +45,6 @@ def is_available(devices: Device | Iterable[Device] | None = None) -> bool: return False updates = torch.tensor(1.0, device=device) op(updates, updates, updates, 1) - return True except Exception: # noqa: BLE001 # pylint: disable=broad-except return False + return True diff --git a/torchopt/accelerated_op/_src/adam_op.py b/torchopt/accelerated_op/_src/adam_op.py index bc999766..d7f9796d 100644 --- a/torchopt/accelerated_op/_src/adam_op.py +++ b/torchopt/accelerated_op/_src/adam_op.py @@ -18,7 +18,11 @@ from __future__ import annotations -import torch +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + import torch def forward_( diff --git a/torchopt/alias/__init__.py b/torchopt/alias/__init__.py index 3cfb5b8b..5767c5d7 100644 --- a/torchopt/alias/__init__.py +++ b/torchopt/alias/__init__.py @@ -41,4 +41,13 @@ from torchopt.alias.sgd import sgd -__all__ = ['adagrad', 'radam', 'adam', 'adamax', 'adadelta', 'adamw', 'rmsprop', 'sgd'] +__all__ = [ + 'adadelta', + 'adagrad', + 'adam', + 'adamax', + 'adamw', + 'radam', + 'rmsprop', + 'sgd', +] diff --git a/torchopt/alias/adadelta.py b/torchopt/alias/adadelta.py index fb0b551a..910cb13e 100644 --- a/torchopt/alias/adadelta.py +++ b/torchopt/alias/adadelta.py @@ -16,6 +16,8 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from torchopt.alias.utils import ( _get_use_chain_flat, flip_sign_and_add_weight_decay, @@ -23,7 +25,10 @@ ) from torchopt.combine import chain from torchopt.transform import scale_by_adadelta -from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +if TYPE_CHECKING: + from torchopt.typing import GradientTransformation, ScalarOrSchedule __all__ = ['adadelta'] diff --git a/torchopt/alias/adam.py b/torchopt/alias/adam.py index 9419e908..0ae0eb8e 100644 --- a/torchopt/alias/adam.py +++ b/torchopt/alias/adam.py @@ -33,6 +33,8 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from torchopt.alias.utils import ( _get_use_chain_flat, flip_sign_and_add_weight_decay, @@ -40,7 +42,10 @@ ) from torchopt.combine import chain from torchopt.transform import scale_by_accelerated_adam, scale_by_adam -from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +if TYPE_CHECKING: + from torchopt.typing import GradientTransformation, ScalarOrSchedule __all__ = ['adam'] diff --git a/torchopt/alias/adamax.py b/torchopt/alias/adamax.py index f80c0c2f..3da16713 100644 --- a/torchopt/alias/adamax.py +++ b/torchopt/alias/adamax.py @@ -16,6 +16,8 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from torchopt.alias.utils import ( _get_use_chain_flat, flip_sign_and_add_weight_decay, @@ -23,7 +25,10 @@ ) from torchopt.combine import chain from torchopt.transform import scale_by_adamax -from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +if TYPE_CHECKING: + from torchopt.typing import GradientTransformation, ScalarOrSchedule __all__ = ['adamax'] diff --git a/torchopt/alias/adamw.py 
b/torchopt/alias/adamw.py index 38d4d5ac..2dc72ef1 100644 --- a/torchopt/alias/adamw.py +++ b/torchopt/alias/adamw.py @@ -33,7 +33,7 @@ from __future__ import annotations -from typing import Callable +from typing import TYPE_CHECKING, Callable from torchopt.alias.utils import ( _get_use_chain_flat, @@ -42,7 +42,10 @@ ) from torchopt.combine import chain from torchopt.transform import add_decayed_weights, scale_by_accelerated_adam, scale_by_adam -from torchopt.typing import GradientTransformation, OptState, Params, ScalarOrSchedule + + +if TYPE_CHECKING: + from torchopt.typing import GradientTransformation, OptState, Params, ScalarOrSchedule __all__ = ['adamw'] diff --git a/torchopt/alias/radam.py b/torchopt/alias/radam.py index 56d3d3d5..9e2880ee 100644 --- a/torchopt/alias/radam.py +++ b/torchopt/alias/radam.py @@ -16,6 +16,8 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from torchopt.alias.utils import ( _get_use_chain_flat, flip_sign_and_add_weight_decay, @@ -23,7 +25,10 @@ ) from torchopt.combine import chain from torchopt.transform import scale_by_radam -from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +if TYPE_CHECKING: + from torchopt.typing import GradientTransformation, ScalarOrSchedule __all__ = ['radam'] diff --git a/torchopt/alias/utils.py b/torchopt/alias/utils.py index 49f8784d..0f41e822 100644 --- a/torchopt/alias/utils.py +++ b/torchopt/alias/utils.py @@ -16,14 +16,18 @@ from __future__ import annotations import threading - -import torch +from typing import TYPE_CHECKING from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation, identity from torchopt.transform import scale, scale_by_schedule from torchopt.transform.utils import tree_map_flat, tree_map_flat_ -from torchopt.typing import Numeric, OptState, Params, ScalarOrSchedule, Updates + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import Numeric, OptState, Params, ScalarOrSchedule, Updates __all__ = ['flip_sign_and_add_weight_decay', 'scale_by_neg_lr'] @@ -68,7 +72,7 @@ def _flip_sign_and_add_weight_decay_flat( ) -def _flip_sign_and_add_weight_decay( +def _flip_sign_and_add_weight_decay( # noqa: C901 weight_decay: float = 0.0, maximize: bool = False, *, diff --git a/torchopt/base.py b/torchopt/base.py index 572708e2..81892e17 100644 --- a/torchopt/base.py +++ b/torchopt/base.py @@ -44,10 +44,10 @@ __all__ = [ + 'ChainedGradientTransformation', 'EmptyState', - 'UninitializedState', 'GradientTransformation', - 'ChainedGradientTransformation', + 'UninitializedState', 'identity', ] diff --git a/torchopt/clip.py b/torchopt/clip.py index 55ae83fc..d64afc58 100644 --- a/torchopt/clip.py +++ b/torchopt/clip.py @@ -19,11 +19,16 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import torch from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['clip_grad_norm'] diff --git a/torchopt/combine.py b/torchopt/combine.py index 158ec982..15345286 100644 --- a/torchopt/combine.py +++ b/torchopt/combine.py @@ -33,9 +33,14 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from torchopt import pytree from torchopt.base import ChainedGradientTransformation, GradientTransformation, identity -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates 
__all__ = ['chain', 'chain_flat'] diff --git a/torchopt/diff/implicit/__init__.py b/torchopt/diff/implicit/__init__.py index 21737015..4cff14c6 100644 --- a/torchopt/diff/implicit/__init__.py +++ b/torchopt/diff/implicit/__init__.py @@ -19,4 +19,4 @@ from torchopt.diff.implicit.nn import ImplicitMetaGradientModule -__all__ = ['custom_root', 'ImplicitMetaGradientModule'] +__all__ = ['ImplicitMetaGradientModule', 'custom_root'] diff --git a/torchopt/diff/implicit/decorator.py b/torchopt/diff/implicit/decorator.py index d3efda2c..11ba0153 100644 --- a/torchopt/diff/implicit/decorator.py +++ b/torchopt/diff/implicit/decorator.py @@ -37,20 +37,23 @@ import functools import inspect -from typing import Any, Callable, Dict, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, Sequence, Tuple import functorch import torch from torch.autograd import Function from torchopt import linear_solve, pytree -from torchopt.typing import ( - ListOfOptionalTensors, - ListOfTensors, - TensorOrTensors, - TupleOfOptionalTensors, - TupleOfTensors, -) + + +if TYPE_CHECKING: + from torchopt.typing import ( + ListOfOptionalTensors, + ListOfTensors, + TensorOrTensors, + TupleOfOptionalTensors, + TupleOfTensors, + ) __all__ = ['custom_root'] @@ -253,7 +256,7 @@ def _merge_tensor_and_others( # pylint: disable-next=too-many-arguments,too-many-statements -def _custom_root( +def _custom_root( # noqa: C901 solver_fn: Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]], optimality_fn: Callable[..., TensorOrTensors], solve: Callable[..., TensorOrTensors], @@ -271,7 +274,7 @@ def _custom_root( fn = getattr(reference_signature, 'subfn', reference_signature) reference_signature = inspect.signature(fn) - def make_custom_vjp_solver_fn( + def make_custom_vjp_solver_fn( # noqa: C901 solver_fn: Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]], kwarg_keys: Sequence[str], args_signs: tuple[tuple[int, int, type[tuple | list] | None], ...], diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py index 8719f675..6b214cb8 100644 --- a/torchopt/diff/implicit/nn/module.py +++ b/torchopt/diff/implicit/nn/module.py @@ -22,15 +22,19 @@ import functools import inspect import itertools -from typing import Any, Iterable +from typing import TYPE_CHECKING, Any, Iterable import functorch -import torch from torchopt.diff.implicit.decorator import custom_root from torchopt.nn.module import MetaGradientModule from torchopt.nn.stateless import reparametrize, swap_state -from torchopt.typing import LinearSolver, TupleOfTensors + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import LinearSolver, TupleOfTensors __all__ = ['ImplicitMetaGradientModule'] diff --git a/torchopt/diff/zero_order/__init__.py b/torchopt/diff/zero_order/__init__.py index f00e097a..4369f4e5 100644 --- a/torchopt/diff/zero_order/__init__.py +++ b/torchopt/diff/zero_order/__init__.py @@ -25,7 +25,7 @@ from torchopt.diff.zero_order.nn import ZeroOrderGradientModule -__all__ = ['zero_order', 'ZeroOrderGradientModule'] +__all__ = ['ZeroOrderGradientModule', 'zero_order'] class _CallableModule(_ModuleType): # pylint: disable=too-few-public-methods diff --git a/torchopt/diff/zero_order/decorator.py b/torchopt/diff/zero_order/decorator.py index b1126636..e498b43c 100644 --- a/torchopt/diff/zero_order/decorator.py +++ b/torchopt/diff/zero_order/decorator.py @@ -17,6 +17,7 @@ from __future__ import annotations import functools +import itertools from typing import Any, Callable, Literal, Sequence from 
typing_extensions import TypeAlias # Python 3.10+ @@ -43,7 +44,7 @@ def sample( return self.sample_fn(sample_shape) -def _zero_order_naive( # pylint: disable=too-many-statements +def _zero_order_naive( # noqa: C901 # pylint: disable=too-many-statements fn: Callable[..., torch.Tensor], distribution: Samplable, argnums: tuple[int, ...], @@ -51,7 +52,7 @@ def _zero_order_naive( # pylint: disable=too-many-statements sigma: float, ) -> Callable[..., torch.Tensor]: @functools.wraps(fn) - def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements + def apply(*args: Any) -> torch.Tensor: # noqa: C901 # pylint: disable=too-many-statements diff_params = [args[argnum] for argnum in argnums] flat_diff_params: list[Any] flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] @@ -81,7 +82,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: output = fn(*origin_args) if not isinstance(output, torch.Tensor): - raise RuntimeError('`output` must be a tensor.') + raise TypeError('`output` must be a tensor.') if output.ndim != 0: raise RuntimeError('`output` must be a scalar tensor.') ctx.save_for_backward(*flat_diff_params, *tensors) @@ -122,9 +123,9 @@ def add_perturbation( for _ in range(num_samples): noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params] - flat_noisy_params = [ - add_perturbation(t, n) for t, n in zip(flat_diff_params, noises) # type: ignore[arg-type] - ] + flat_noisy_params = list( + itertools.starmap(add_perturbation, zip(flat_diff_params, noises)), + ) noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment] diff_params_treespec, flat_noisy_params, @@ -149,7 +150,7 @@ def add_perturbation( return apply -def _zero_order_forward( # pylint: disable=too-many-statements +def _zero_order_forward( # noqa: C901 # pylint: disable=too-many-statements fn: Callable[..., torch.Tensor], distribution: Samplable, argnums: tuple[int, ...], @@ -157,7 +158,7 @@ def _zero_order_forward( # pylint: disable=too-many-statements sigma: float, ) -> Callable[..., torch.Tensor]: @functools.wraps(fn) - def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements + def apply(*args: Any) -> torch.Tensor: # noqa: C901 # pylint: disable=too-many-statements diff_params = [args[argnum] for argnum in argnums] flat_diff_params: list[Any] flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] @@ -187,7 +188,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: output = fn(*origin_args) if not isinstance(output, torch.Tensor): - raise RuntimeError('`output` must be a tensor.') + raise TypeError('`output` must be a tensor.') if output.ndim != 0: raise RuntimeError('`output` must be a scalar tensor.') ctx.save_for_backward(*flat_diff_params, *tensors, output) @@ -226,9 +227,9 @@ def add_perturbation(tensor: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: for _ in range(num_samples): noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params] - flat_noisy_params = [ - add_perturbation(t, n) for t, n in zip(flat_diff_params, noises) # type: ignore[arg-type] - ] + flat_noisy_params = list( + itertools.starmap(add_perturbation, zip(flat_diff_params, noises)), + ) noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment] diff_params_treespec, flat_noisy_params, @@ -254,7 +255,7 @@ def add_perturbation(tensor: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: return apply -def 
_zero_order_antithetic( # pylint: disable=too-many-statements +def _zero_order_antithetic( # noqa: C901 # pylint: disable=too-many-statements fn: Callable[..., torch.Tensor], distribution: Samplable, argnums: tuple[int, ...], @@ -262,7 +263,7 @@ def _zero_order_antithetic( # pylint: disable=too-many-statements sigma: float, ) -> Callable[..., torch.Tensor]: @functools.wraps(fn) - def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements + def apply(*args: Any) -> torch.Tensor: # noqa: C901 # pylint: disable=too-many-statements diff_params = [args[argnum] for argnum in argnums] flat_diff_params: list[Any] flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] @@ -292,7 +293,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: output = fn(*origin_args) if not isinstance(output, torch.Tensor): - raise RuntimeError('`output` must be a tensor.') + raise TypeError('`output` must be a tensor.') if output.ndim != 0: raise RuntimeError('`output` must be a scalar tensor.') ctx.save_for_backward(*flat_diff_params, *tensors) diff --git a/torchopt/diff/zero_order/nn/module.py b/torchopt/diff/zero_order/nn/module.py index 7ac12bb4..eeddabeb 100644 --- a/torchopt/diff/zero_order/nn/module.py +++ b/torchopt/diff/zero_order/nn/module.py @@ -20,14 +20,17 @@ import abc import functools -from typing import Any, Sequence +from typing import TYPE_CHECKING, Any, Sequence import torch import torch.nn as nn from torchopt.diff.zero_order.decorator import Method, Samplable, zero_order from torchopt.nn.stateless import reparametrize -from torchopt.typing import Numeric, TupleOfTensors + + +if TYPE_CHECKING: + from torchopt.typing import Numeric, TupleOfTensors __all__ = ['ZeroOrderGradientModule'] diff --git a/torchopt/distributed/api.py b/torchopt/distributed/api.py index 117af9ab..97be682f 100644 --- a/torchopt/distributed/api.py +++ b/torchopt/distributed/api.py @@ -42,15 +42,15 @@ __all__ = [ 'TensorDimensionPartitioner', - 'dim_partitioner', 'batch_partitioner', + 'dim_partitioner', 'mean_reducer', - 'sum_reducer', - 'remote_async_call', - 'remote_sync_call', 'parallelize', 'parallelize_async', 'parallelize_sync', + 'remote_async_call', + 'remote_sync_call', + 'sum_reducer', ] @@ -107,7 +107,7 @@ def __init__( self.workers = workers # pylint: disable-next=too-many-branches,too-many-locals - def __call__( + def __call__( # noqa: C901 self, *args: Any, **kwargs: Any, @@ -310,7 +310,7 @@ def remote_async_call( elif callable(partitioner): partitions = partitioner(*args, **kwargs) # type: ignore[assignment] else: - raise ValueError(f'Invalid partitioner: {partitioner!r}.') + raise TypeError(f'Invalid partitioner: {partitioner!r}.') futures = [] for rank, worker_args, worker_kwargs in partitions: diff --git a/torchopt/distributed/autograd.py b/torchopt/distributed/autograd.py index f7da4f46..71afdb86 100644 --- a/torchopt/distributed/autograd.py +++ b/torchopt/distributed/autograd.py @@ -17,15 +17,18 @@ from __future__ import annotations from threading import Lock +from typing import TYPE_CHECKING import torch import torch.distributed.autograd as autograd from torch.distributed.autograd import context -from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors +if TYPE_CHECKING: + from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors -__all__ = ['is_available', 'context'] + +__all__ = ['context', 'is_available'] LOCK = Lock() @@ -121,7 +124,7 @@ def grad( for p in inputs: try: grads.append(all_local_grads[p]) - 
except KeyError as ex: + except KeyError as ex: # noqa: PERF203 if not allow_unused: raise RuntimeError( 'One of the differentiated Tensors appears to not have been used in the ' @@ -131,4 +134,4 @@ def grad( return tuple(grads) - __all__ += ['DistAutogradContext', 'get_gradients', 'backward', 'grad'] + __all__ += ['DistAutogradContext', 'backward', 'get_gradients', 'grad'] diff --git a/torchopt/distributed/world.py b/torchopt/distributed/world.py index a61280c5..610e52a0 100644 --- a/torchopt/distributed/world.py +++ b/torchopt/distributed/world.py @@ -26,19 +26,19 @@ __all__ = [ - 'get_world_info', - 'get_world_rank', - 'get_rank', - 'get_world_size', + 'auto_init_rpc', + 'barrier', 'get_local_rank', 'get_local_world_size', + 'get_rank', 'get_worker_id', - 'barrier', - 'auto_init_rpc', - 'on_rank', + 'get_world_info', + 'get_world_rank', + 'get_world_size', 'not_on_rank', - 'rank_zero_only', + 'on_rank', 'rank_non_zero_only', + 'rank_zero_only', ] diff --git a/torchopt/hook.py b/torchopt/hook.py index b51e29eb..c11b92f6 100644 --- a/torchopt/hook.py +++ b/torchopt/hook.py @@ -16,16 +16,19 @@ from __future__ import annotations -from typing import Callable - -import torch +from typing import TYPE_CHECKING, Callable from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation -from torchopt.typing import OptState, Params, Updates -__all__ = ['zero_nan_hook', 'nan_to_num_hook', 'register_hook'] +if TYPE_CHECKING: + import torch + + from torchopt.typing import OptState, Params, Updates + + +__all__ = ['nan_to_num_hook', 'register_hook', 'zero_nan_hook'] def zero_nan_hook(g: torch.Tensor) -> torch.Tensor: diff --git a/torchopt/linalg/cg.py b/torchopt/linalg/cg.py index a82ff877..1096a5af 100644 --- a/torchopt/linalg/cg.py +++ b/torchopt/linalg/cg.py @@ -36,14 +36,17 @@ from __future__ import annotations from functools import partial -from typing import Callable +from typing import TYPE_CHECKING, Callable import torch from torchopt import pytree from torchopt.linalg.utils import cat_shapes, normalize_matvec from torchopt.pytree import tree_vdot_real -from torchopt.typing import TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree __all__ = ['cg'] diff --git a/torchopt/linalg/ns.py b/torchopt/linalg/ns.py index b049a5ad..5fc8d478 100644 --- a/torchopt/linalg/ns.py +++ b/torchopt/linalg/ns.py @@ -19,13 +19,16 @@ from __future__ import annotations import functools -from typing import Callable +from typing import TYPE_CHECKING, Callable import torch from torchopt import pytree from torchopt.linalg.utils import normalize_matvec -from torchopt.typing import TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree __all__ = ['ns', 'ns_inv'] diff --git a/torchopt/linalg/utils.py b/torchopt/linalg/utils.py index a5ac765d..bbcc80aa 100644 --- a/torchopt/linalg/utils.py +++ b/torchopt/linalg/utils.py @@ -17,12 +17,15 @@ from __future__ import annotations import itertools -from typing import Callable +from typing import TYPE_CHECKING, Callable import torch from torchopt import pytree -from torchopt.typing import TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree def cat_shapes(tree: TensorTree) -> tuple[int, ...]: diff --git a/torchopt/linear_solve/__init__.py b/torchopt/linear_solve/__init__.py index 2d61eb6d..43ca1da0 100644 --- a/torchopt/linear_solve/__init__.py +++ b/torchopt/linear_solve/__init__.py @@ -36,4 +36,4 @@ from torchopt.linear_solve.normal_cg import solve_normal_cg -__all__ = ['solve_cg', 
'solve_normal_cg', 'solve_inv'] +__all__ = ['solve_cg', 'solve_inv', 'solve_normal_cg'] diff --git a/torchopt/linear_solve/cg.py b/torchopt/linear_solve/cg.py index f4127639..23814cc2 100644 --- a/torchopt/linear_solve/cg.py +++ b/torchopt/linear_solve/cg.py @@ -36,11 +36,14 @@ from __future__ import annotations import functools -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable from torchopt import linalg from torchopt.linear_solve.utils import make_ridge_matvec -from torchopt.typing import LinearSolver, TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import LinearSolver, TensorTree __all__ = ['solve_cg'] diff --git a/torchopt/linear_solve/inv.py b/torchopt/linear_solve/inv.py index f37be8c5..4dbe1542 100644 --- a/torchopt/linear_solve/inv.py +++ b/torchopt/linear_solve/inv.py @@ -36,13 +36,16 @@ from __future__ import annotations import functools -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable import torch from torchopt import linalg, pytree from torchopt.linear_solve.utils import make_ridge_matvec, materialize_matvec -from torchopt.typing import LinearSolver, TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import LinearSolver, TensorTree __all__ = ['solve_inv'] diff --git a/torchopt/linear_solve/normal_cg.py b/torchopt/linear_solve/normal_cg.py index 405ab43c..a5af49b2 100644 --- a/torchopt/linear_solve/normal_cg.py +++ b/torchopt/linear_solve/normal_cg.py @@ -36,11 +36,14 @@ from __future__ import annotations import functools -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable from torchopt import linalg from torchopt.linear_solve.utils import make_normal_matvec, make_ridge_matvec, make_rmatvec -from torchopt.typing import LinearSolver, TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import LinearSolver, TensorTree __all__ = ['solve_normal_cg'] diff --git a/torchopt/linear_solve/utils.py b/torchopt/linear_solve/utils.py index 5e4bf7bd..9d1b8779 100644 --- a/torchopt/linear_solve/utils.py +++ b/torchopt/linear_solve/utils.py @@ -33,12 +33,15 @@ from __future__ import annotations -from typing import Callable +from typing import TYPE_CHECKING, Callable import functorch from torchopt import pytree -from torchopt.typing import TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree def make_rmatvec( diff --git a/torchopt/nn/__init__.py b/torchopt/nn/__init__.py index 7665f201..b55e49d7 100644 --- a/torchopt/nn/__init__.py +++ b/torchopt/nn/__init__.py @@ -21,10 +21,10 @@ __all__ = [ - 'MetaGradientModule', 'ImplicitMetaGradientModule', + 'MetaGradientModule', 'ZeroOrderGradientModule', - 'reparametrize', 'reparameterize', + 'reparametrize', 'swap_state', ] diff --git a/torchopt/nn/module.py b/torchopt/nn/module.py index 419afb6a..8c40f58a 100644 --- a/torchopt/nn/module.py +++ b/torchopt/nn/module.py @@ -17,14 +17,17 @@ from __future__ import annotations from collections import OrderedDict -from typing import Any, Iterator, NamedTuple +from typing import TYPE_CHECKING, Any, Iterator, NamedTuple from typing_extensions import Self # Python 3.11+ import torch import torch.nn as nn from torchopt import pytree -from torchopt.typing import TensorContainer + + +if TYPE_CHECKING: + from torchopt.typing import TensorContainer class MetaInputsContainer(NamedTuple): @@ -61,7 +64,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=unused """Initialize a new module instance.""" super().__init__() - def __getattr__(self, name: str) -> 
torch.Tensor | nn.Module: + def __getattr__(self, name: str) -> torch.Tensor | nn.Module: # noqa: C901 """Get an attribute of the module.""" if '_parameters' in self.__dict__: _parameters = self.__dict__['_parameters'] @@ -86,7 +89,7 @@ def __getattr__(self, name: str) -> torch.Tensor | nn.Module: raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") # pylint: disable-next=too-many-branches,too-many-statements - def __setattr__(self, name: str, value: torch.Tensor | nn.Module) -> None: + def __setattr__(self, name: str, value: torch.Tensor | nn.Module) -> None: # noqa: C901 """Set an attribute of the module.""" def remove_from(*dicts_or_sets: dict[str, Any] | set[str]) -> None: diff --git a/torchopt/nn/stateless.py b/torchopt/nn/stateless.py index d3437d0d..c7f92b86 100644 --- a/torchopt/nn/stateless.py +++ b/torchopt/nn/stateless.py @@ -17,13 +17,15 @@ from __future__ import annotations import contextlib -from typing import Generator, Iterable +from typing import TYPE_CHECKING, Generator, Iterable -import torch -import torch.nn as nn +if TYPE_CHECKING: + import torch + import torch.nn as nn -__all__ = ['swap_state', 'reparametrize', 'reparameterize'] + +__all__ = ['reparameterize', 'reparametrize', 'swap_state'] MISSING: torch.Tensor = object() # type: ignore[assignment] diff --git a/torchopt/optim/adadelta.py b/torchopt/optim/adadelta.py index a64e00e4..600b69c5 100644 --- a/torchopt/optim/adadelta.py +++ b/torchopt/optim/adadelta.py @@ -16,13 +16,16 @@ from __future__ import annotations -from typing import Iterable - -import torch +from typing import TYPE_CHECKING, Iterable from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import ScalarOrSchedule __all__ = ['AdaDelta', 'Adadelta'] diff --git a/torchopt/optim/adagrad.py b/torchopt/optim/adagrad.py index 277b7105..06091281 100644 --- a/torchopt/optim/adagrad.py +++ b/torchopt/optim/adagrad.py @@ -16,13 +16,16 @@ from __future__ import annotations -from typing import Iterable - -import torch +from typing import TYPE_CHECKING, Iterable from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import ScalarOrSchedule __all__ = ['AdaGrad', 'Adagrad'] diff --git a/torchopt/optim/adam.py b/torchopt/optim/adam.py index 6ff68a69..555af22e 100644 --- a/torchopt/optim/adam.py +++ b/torchopt/optim/adam.py @@ -16,13 +16,16 @@ from __future__ import annotations -from typing import Iterable - -import torch +from typing import TYPE_CHECKING, Iterable from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import ScalarOrSchedule __all__ = ['Adam'] diff --git a/torchopt/optim/adamax.py b/torchopt/optim/adamax.py index f693723c..e4996e85 100644 --- a/torchopt/optim/adamax.py +++ b/torchopt/optim/adamax.py @@ -16,13 +16,16 @@ from __future__ import annotations -from typing import Iterable - -import torch +from typing import TYPE_CHECKING, Iterable from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import ScalarOrSchedule __all__ = ['AdaMax', 'Adamax'] diff --git a/torchopt/optim/adamw.py b/torchopt/optim/adamw.py index 463f245f..a60061ea 
100644 --- a/torchopt/optim/adamw.py +++ b/torchopt/optim/adamw.py @@ -16,13 +16,16 @@ from __future__ import annotations -from typing import Callable, Iterable - -import torch +from typing import TYPE_CHECKING, Callable, Iterable from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import OptState, Params, ScalarOrSchedule + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import OptState, Params, ScalarOrSchedule __all__ = ['AdamW'] diff --git a/torchopt/optim/func/base.py b/torchopt/optim/func/base.py index 7bb27877..fa287f04 100644 --- a/torchopt/optim/func/base.py +++ b/torchopt/optim/func/base.py @@ -16,13 +16,18 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import torch from torchopt.base import GradientTransformation, UninitializedState -from torchopt.typing import OptState, Params from torchopt.update import apply_updates +if TYPE_CHECKING: + from torchopt.typing import OptState, Params + + __all__ = ['FuncOptimizer'] diff --git a/torchopt/optim/meta/adadelta.py b/torchopt/optim/meta/adadelta.py index 49bdf23c..eb386ae3 100644 --- a/torchopt/optim/meta/adadelta.py +++ b/torchopt/optim/meta/adadelta.py @@ -16,11 +16,16 @@ from __future__ import annotations -import torch.nn as nn +from typing import TYPE_CHECKING from torchopt import alias from torchopt.optim.meta.base import MetaOptimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import ScalarOrSchedule __all__ = ['MetaAdaDelta', 'MetaAdadelta'] diff --git a/torchopt/optim/meta/adagrad.py b/torchopt/optim/meta/adagrad.py index 58d913aa..129c1338 100644 --- a/torchopt/optim/meta/adagrad.py +++ b/torchopt/optim/meta/adagrad.py @@ -16,11 +16,16 @@ from __future__ import annotations -import torch.nn as nn +from typing import TYPE_CHECKING from torchopt import alias from torchopt.optim.meta.base import MetaOptimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import ScalarOrSchedule __all__ = ['MetaAdaGrad', 'MetaAdagrad'] diff --git a/torchopt/optim/meta/adam.py b/torchopt/optim/meta/adam.py index bac71790..7a78ea7f 100644 --- a/torchopt/optim/meta/adam.py +++ b/torchopt/optim/meta/adam.py @@ -16,11 +16,16 @@ from __future__ import annotations -import torch.nn as nn +from typing import TYPE_CHECKING from torchopt import alias from torchopt.optim.meta.base import MetaOptimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import ScalarOrSchedule __all__ = ['MetaAdam'] diff --git a/torchopt/optim/meta/adamax.py b/torchopt/optim/meta/adamax.py index 568a46f7..d6b40427 100644 --- a/torchopt/optim/meta/adamax.py +++ b/torchopt/optim/meta/adamax.py @@ -16,11 +16,16 @@ from __future__ import annotations -import torch.nn as nn +from typing import TYPE_CHECKING from torchopt import alias from torchopt.optim.meta.base import MetaOptimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import ScalarOrSchedule __all__ = ['MetaAdaMax', 'MetaAdamax'] diff --git a/torchopt/optim/meta/adamw.py b/torchopt/optim/meta/adamw.py index 05387b77..62864582 100644 --- a/torchopt/optim/meta/adamw.py +++ b/torchopt/optim/meta/adamw.py @@ -16,13 +16,16 @@ from __future__ import annotations -from typing import Callable - -import torch.nn as nn +from typing import TYPE_CHECKING, Callable from torchopt 
import alias from torchopt.optim.meta.base import MetaOptimizer -from torchopt.typing import OptState, Params, ScalarOrSchedule + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import OptState, Params, ScalarOrSchedule __all__ = ['MetaAdamW'] diff --git a/torchopt/optim/meta/radam.py b/torchopt/optim/meta/radam.py index a32670d0..bb07b5ba 100644 --- a/torchopt/optim/meta/radam.py +++ b/torchopt/optim/meta/radam.py @@ -16,11 +16,16 @@ from __future__ import annotations -import torch.nn as nn +from typing import TYPE_CHECKING from torchopt import alias from torchopt.optim.meta.base import MetaOptimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import ScalarOrSchedule __all__ = ['MetaRAdam'] diff --git a/torchopt/optim/radam.py b/torchopt/optim/radam.py index bba8c0d4..20e9dd22 100644 --- a/torchopt/optim/radam.py +++ b/torchopt/optim/radam.py @@ -16,13 +16,16 @@ from __future__ import annotations -from typing import Iterable - -import torch +from typing import TYPE_CHECKING, Iterable from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import ScalarOrSchedule __all__ = ['RAdam'] diff --git a/torchopt/pytree.py b/torchopt/pytree.py index 6adea0e8..53abc2d2 100644 --- a/torchopt/pytree.py +++ b/torchopt/pytree.py @@ -18,7 +18,7 @@ import functools import operator -from typing import Callable +from typing import TYPE_CHECKING, Callable import optree import optree.typing as typing # pylint: disable=unused-import @@ -26,7 +26,9 @@ import torch.distributed.rpc as rpc from optree import * # pylint: disable=wildcard-import,unused-wildcard-import -from torchopt.typing import Future, RRef, Scalar, T, TensorTree + +if TYPE_CHECKING: + from torchopt.typing import Future, RRef, Scalar, T, TensorTree __all__ = [ diff --git a/torchopt/schedule/__init__.py b/torchopt/schedule/__init__.py index 8e5545a4..d3d3eff5 100644 --- a/torchopt/schedule/__init__.py +++ b/torchopt/schedule/__init__.py @@ -35,4 +35,4 @@ from torchopt.schedule.polynomial import linear_schedule, polynomial_schedule -__all__ = ['exponential_decay', 'polynomial_schedule', 'linear_schedule'] +__all__ = ['exponential_decay', 'linear_schedule', 'polynomial_schedule'] diff --git a/torchopt/schedule/exponential_decay.py b/torchopt/schedule/exponential_decay.py index 0925e164..c19c54b9 100644 --- a/torchopt/schedule/exponential_decay.py +++ b/torchopt/schedule/exponential_decay.py @@ -31,11 +31,15 @@ # ============================================================================== """Exponential learning rate decay.""" +from __future__ import annotations + import logging import math -from typing import Optional +from typing import TYPE_CHECKING + -from torchopt.typing import Numeric, Scalar, Schedule +if TYPE_CHECKING: + from torchopt.typing import Numeric, Scalar, Schedule __all__ = ['exponential_decay'] @@ -48,7 +52,7 @@ def exponential_decay( transition_begin: int = 0, transition_steps: int = 1, staircase: bool = False, - end_value: Optional[float] = None, + end_value: float | None = None, ) -> Schedule: """Construct a schedule with either continuous or discrete exponential decay. 
diff --git a/torchopt/schedule/polynomial.py b/torchopt/schedule/polynomial.py index 2482f769..d2a5160c 100644 --- a/torchopt/schedule/polynomial.py +++ b/torchopt/schedule/polynomial.py @@ -31,15 +31,20 @@ # ============================================================================== """Polynomial learning rate schedules.""" +from __future__ import annotations + import logging +from typing import TYPE_CHECKING import numpy as np import torch -from torchopt.typing import Numeric, Scalar, Schedule + +if TYPE_CHECKING: + from torchopt.typing import Numeric, Scalar, Schedule -__all__ = ['polynomial_schedule', 'linear_schedule'] +__all__ = ['linear_schedule', 'polynomial_schedule'] def polynomial_schedule( diff --git a/torchopt/transform/__init__.py b/torchopt/transform/__init__.py index adef5596..fa59a43b 100644 --- a/torchopt/transform/__init__.py +++ b/torchopt/transform/__init__.py @@ -46,18 +46,18 @@ __all__ = [ - 'trace', - 'scale', - 'scale_by_schedule', 'add_decayed_weights', 'masked', + 'nan_to_num', + 'scale', + 'scale_by_accelerated_adam', + 'scale_by_adadelta', 'scale_by_adam', 'scale_by_adamax', - 'scale_by_adadelta', 'scale_by_radam', - 'scale_by_accelerated_adam', - 'scale_by_rss', 'scale_by_rms', + 'scale_by_rss', + 'scale_by_schedule', 'scale_by_stddev', - 'nan_to_num', + 'trace', ] diff --git a/torchopt/transform/add_decayed_weights.py b/torchopt/transform/add_decayed_weights.py index 950682cf..0cb67837 100644 --- a/torchopt/transform/add_decayed_weights.py +++ b/torchopt/transform/add_decayed_weights.py @@ -34,17 +34,20 @@ from __future__ import annotations -from typing import Any, Callable, NamedTuple - -import torch +from typing import TYPE_CHECKING, Any, Callable, NamedTuple from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation, identity from torchopt.transform.utils import tree_map_flat, tree_map_flat_ -from torchopt.typing import OptState, Params, Updates -__all__ = ['masked', 'add_decayed_weights'] +if TYPE_CHECKING: + import torch + + from torchopt.typing import OptState, Params, Updates + + +__all__ = ['add_decayed_weights', 'masked'] class MaskedState(NamedTuple): @@ -189,7 +192,7 @@ def _add_decayed_weights_flat( ) -def _add_decayed_weights( +def _add_decayed_weights( # noqa: C901 weight_decay: float = 0.0, mask: OptState | Callable[[Params], OptState] | None = None, *, diff --git a/torchopt/transform/nan_to_num.py b/torchopt/transform/nan_to_num.py index d3530853..740df1b0 100644 --- a/torchopt/transform/nan_to_num.py +++ b/torchopt/transform/nan_to_num.py @@ -16,11 +16,16 @@ from __future__ import annotations -import torch +from typing import TYPE_CHECKING from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import OptState, Params, Updates def nan_to_num( diff --git a/torchopt/transform/scale.py b/torchopt/transform/scale.py index 493b7196..2b492bdf 100644 --- a/torchopt/transform/scale.py +++ b/torchopt/transform/scale.py @@ -33,12 +33,17 @@ from __future__ import annotations -import torch +from typing import TYPE_CHECKING from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation from torchopt.transform.utils import tree_map_flat, tree_map_flat_ -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import OptState, Params, Updates __all__ = ['scale'] diff --git 
a/torchopt/transform/scale_by_adadelta.py b/torchopt/transform/scale_by_adadelta.py index f389d293..6d05e5dd 100644 --- a/torchopt/transform/scale_by_adadelta.py +++ b/torchopt/transform/scale_by_adadelta.py @@ -19,14 +19,17 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, update_moment -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['scale_by_adadelta'] diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py index b08c6a14..d45d1eb2 100644 --- a/torchopt/transform/scale_by_adam.py +++ b/torchopt/transform/scale_by_adam.py @@ -35,7 +35,7 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch @@ -43,10 +43,13 @@ from torchopt.accelerated_op import AdamOp from torchopt.base import GradientTransformation from torchopt.transform.utils import inc_count, tree_map_flat, update_moment -from torchopt.typing import OptState, Params, Updates -__all__ = ['scale_by_adam', 'scale_by_accelerated_adam'] +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates + + +__all__ = ['scale_by_accelerated_adam', 'scale_by_adam'] TRIPLE_PYTREE_SPEC = pytree.tree_structure((0, 1, 2), none_is_leaf=True) # type: ignore[arg-type] @@ -277,7 +280,7 @@ def _scale_by_accelerated_adam_flat( # pylint: disable-next=too-many-arguments -def _scale_by_accelerated_adam( +def _scale_by_accelerated_adam( # noqa: C901 b1: float = 0.9, b2: float = 0.999, eps: float = 1e-8, diff --git a/torchopt/transform/scale_by_adamax.py b/torchopt/transform/scale_by_adamax.py index f11ed311..cfacbf35 100644 --- a/torchopt/transform/scale_by_adamax.py +++ b/torchopt/transform/scale_by_adamax.py @@ -19,14 +19,17 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, update_moment -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['scale_by_adamax'] diff --git a/torchopt/transform/scale_by_radam.py b/torchopt/transform/scale_by_radam.py index fad32b13..95f26149 100644 --- a/torchopt/transform/scale_by_radam.py +++ b/torchopt/transform/scale_by_radam.py @@ -20,14 +20,17 @@ from __future__ import annotations import math -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, update_moment -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['scale_by_radam'] @@ -89,7 +92,7 @@ def _scale_by_radam_flat( ) -def _scale_by_radam( +def _scale_by_radam( # noqa: C901 b1: float = 0.9, b2: float = 0.999, eps: float = 1e-6, diff --git a/torchopt/transform/scale_by_rms.py b/torchopt/transform/scale_by_rms.py index 4ee67ed0..f2141388 100644 --- a/torchopt/transform/scale_by_rms.py +++ b/torchopt/transform/scale_by_rms.py @@ -33,14 +33,17 @@ from __future__ import annotations -from typing import NamedTuple 
+from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, tree_map_flat_, update_moment -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['scale_by_rms'] diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index 9bc97206..642b2e5c 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -33,14 +33,17 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, update_moment -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['scale_by_rss'] diff --git a/torchopt/transform/scale_by_schedule.py b/torchopt/transform/scale_by_schedule.py index 48f3f271..499e2adb 100644 --- a/torchopt/transform/scale_by_schedule.py +++ b/torchopt/transform/scale_by_schedule.py @@ -33,14 +33,17 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import inc_count, tree_map_flat, tree_map_flat_ -from torchopt.typing import Numeric, OptState, Params, Schedule, SequenceOfTensors, Updates + + +if TYPE_CHECKING: + from torchopt.typing import Numeric, OptState, Params, Schedule, SequenceOfTensors, Updates __all__ = ['scale_by_schedule'] diff --git a/torchopt/transform/scale_by_stddev.py b/torchopt/transform/scale_by_stddev.py index 6b99f31a..5a3e6655 100644 --- a/torchopt/transform/scale_by_stddev.py +++ b/torchopt/transform/scale_by_stddev.py @@ -35,14 +35,17 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, tree_map_flat_, update_moment -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['scale_by_stddev'] diff --git a/torchopt/transform/trace.py b/torchopt/transform/trace.py index 9bf37e2f..219cbbec 100644 --- a/torchopt/transform/trace.py +++ b/torchopt/transform/trace.py @@ -35,14 +35,17 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation, identity from torchopt.transform.utils import tree_map_flat, tree_map_flat_ -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['trace'] @@ -101,7 +104,7 @@ def _trace_flat( ) -def _trace( +def _trace( # noqa: C901 momentum: float = 0.9, dampening: float = 0.0, nesterov: bool = False, @@ -136,7 +139,7 @@ def init_fn(params: Params) -> OptState: first_call = True - def update_fn( + def update_fn( # noqa: C901 updates: Updates, state: OptState, *, diff --git a/torchopt/transform/utils.py b/torchopt/transform/utils.py index ec4e51c1..9b38d561 100644 --- a/torchopt/transform/utils.py +++ 
b/torchopt/transform/utils.py @@ -34,15 +34,18 @@ from __future__ import annotations from collections import deque -from typing import Any, Callable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Sequence import torch from torchopt import pytree -from torchopt.typing import TensorTree, Updates -__all__ = ['tree_map_flat', 'tree_map_flat_', 'inc_count', 'update_moment'] +if TYPE_CHECKING: + from torchopt.typing import TensorTree, Updates + + +__all__ = ['inc_count', 'tree_map_flat', 'tree_map_flat_', 'update_moment'] INT64_MAX = torch.iinfo(torch.int64).max @@ -161,7 +164,7 @@ def _update_moment_flat( # pylint: disable-next=too-many-arguments -def _update_moment( +def _update_moment( # noqa: C901 updates: Updates, moments: TensorTree, decay: float, diff --git a/torchopt/typing.py b/torchopt/typing.py index 60d11e0e..fcd888fb 100644 --- a/torchopt/typing.py +++ b/torchopt/typing.py @@ -14,6 +14,8 @@ # ============================================================================== """Typing utilities.""" +from __future__ import annotations + import abc from typing import ( Callable, @@ -45,39 +47,39 @@ __all__ = [ - 'GradientTransformation', 'ChainedGradientTransformation', + 'Device', + 'Distribution', 'EmptyState', - 'UninitializedState', - 'Params', - 'Updates', + 'Future', + 'GradientTransformation', + 'LinearSolver', + 'ListOfOptionalTensors', + 'ListOfTensors', + 'ModuleTensorContainers', + 'Numeric', 'OptState', + 'OptionalTensor', + 'OptionalTensorOrOptionalTensors', + 'OptionalTensorTree', + 'Params', + 'PyTree', + 'Samplable', + 'SampleFunc', 'Scalar', - 'Numeric', - 'Schedule', 'ScalarOrSchedule', - 'PyTree', - 'Tensor', - 'OptionalTensor', - 'ListOfTensors', - 'TupleOfTensors', + 'Schedule', + 'SequenceOfOptionalTensors', 'SequenceOfTensors', + 'Size', + 'Tensor', + 'TensorContainer', 'TensorOrTensors', 'TensorTree', - 'ListOfOptionalTensors', 'TupleOfOptionalTensors', - 'SequenceOfOptionalTensors', - 'OptionalTensorOrOptionalTensors', - 'OptionalTensorTree', - 'TensorContainer', - 'ModuleTensorContainers', - 'Future', - 'LinearSolver', - 'Device', - 'Size', - 'Distribution', - 'SampleFunc', - 'Samplable', + 'TupleOfTensors', + 'UninitializedState', + 'Updates', ] T = TypeVar('T') @@ -138,7 +140,7 @@ class Samplable(Protocol): # pylint: disable=too-few-public-methods def sample( self, sample_shape: Size = Size(), # noqa: B008 # pylint: disable=unused-argument - ) -> Union[Tensor, Sequence[Numeric]]: + ) -> Tensor | Sequence[Numeric]: # pylint: disable-next=line-too-long """Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.""" raise NotImplementedError diff --git a/torchopt/update.py b/torchopt/update.py index 8636d7a4..3f2d71fe 100644 --- a/torchopt/update.py +++ b/torchopt/update.py @@ -33,10 +33,15 @@ from __future__ import annotations -import torch +from typing import TYPE_CHECKING from torchopt import pytree -from torchopt.typing import Params, Updates + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import Params, Updates __all__ = ['apply_updates'] diff --git a/torchopt/utils.py b/torchopt/utils.py index c067d570..5f9202a3 100644 --- a/torchopt/utils.py +++ b/torchopt/utils.py @@ -34,11 +34,11 @@ __all__ = [ 'ModuleState', - 'stop_gradient', 'extract_state_dict', - 'recover_state_dict', 'module_clone', 'module_detach_', + 'recover_state_dict', + 'stop_gradient', ] @@ -115,7 +115,7 @@ def extract_state_dict( # pylint: 
disable-next=too-many-arguments,too-many-branches,too-many-locals -def extract_state_dict( +def extract_state_dict( # noqa: C901 target: nn.Module | MetaOptimizer, *, by: CopyMode = 'reference', @@ -272,7 +272,7 @@ def get_variable(t: torch.Tensor | None) -> torch.Tensor | None: return pytree.tree_map(get_variable, state) # type: ignore[arg-type,return-value] - raise RuntimeError(f'Unexpected class of {target}') + raise TypeError(f'Unexpected class of {target}') def extract_module_containers( @@ -346,7 +346,7 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor: state = cast(Sequence[OptState], state) target.load_state_dict(state) else: - raise RuntimeError(f'Unexpected class of {target}') + raise TypeError(f'Unexpected class of {target}') @overload @@ -383,7 +383,7 @@ def module_clone( # pylint: disable-next=too-many-locals -def module_clone( +def module_clone( # noqa: C901 target: nn.Module | MetaOptimizer | TensorTree, *, by: CopyMode = 'reference', diff --git a/torchopt/visual.py b/torchopt/visual.py index d7885889..7638d7ec 100644 --- a/torchopt/visual.py +++ b/torchopt/visual.py @@ -19,16 +19,19 @@ from __future__ import annotations -from typing import Any, Generator, Iterable, Mapping, cast +from typing import TYPE_CHECKING, Any, Generator, Iterable, Mapping, cast import torch from graphviz import Digraph from torchopt import pytree -from torchopt.typing import TensorTree from torchopt.utils import ModuleState +if TYPE_CHECKING: + from torchopt.typing import TensorTree + + __all__ = ['make_dot', 'resize_graph'] @@ -69,7 +72,7 @@ def truncate(s: str) -> str: # pylint: disable=invalid-name # pylint: disable-next=too-many-branches,too-many-statements,too-many-locals -def make_dot( +def make_dot( # noqa: C901 var: TensorTree, params: ( Mapping[str, torch.Tensor] @@ -153,7 +156,7 @@ def get_var_name_with_flag(var: torch.Tensor) -> str | None: return f'{param_map[var][0]}\n{size_to_str(param_map[var][1].size())}' return None - def add_nodes(fn: Any) -> None: # pylint: disable=too-many-branches + def add_nodes(fn: Any) -> None: # noqa: C901 # pylint: disable=too-many-branches assert not isinstance(fn, torch.Tensor) if fn in seen: return From 960fb0aa53f27bad1af0bdf018b5ad4be077949a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Jul 2024 17:18:44 +0800 Subject: [PATCH 26/26] chore(pre-commit): [pre-commit.ci] autoupdate (#226) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/pre-commit/mirrors-clang-format: v18.1.6 → v18.1.8](https://github.com/pre-commit/mirrors-clang-format/compare/v18.1.6...v18.1.8) - [github.com/astral-sh/ruff-pre-commit: v0.4.9 → v0.5.0](https://github.com/astral-sh/ruff-pre-commit/compare/v0.4.9...v0.5.0) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 4 ++-- torchopt/version.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4814c681..7ab860a5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,11 +26,11 @@ repos: - id: debug-statements - id: double-quote-string-fixer - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v18.1.6 + rev: v18.1.8 hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.9 + rev: v0.5.0 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/torchopt/version.py 
b/torchopt/version.py index a1618caf..9fdcac9b 100644 --- a/torchopt/version.py +++ b/torchopt/version.py @@ -25,8 +25,8 @@ try: prefix, sep, suffix = ( - subprocess.check_output( - ['git', 'describe', '--abbrev=7'], # noqa: S603,S607 + subprocess.check_output( # noqa: S603 + ['git', 'describe', '--abbrev=7'], # noqa: S607 cwd=os.path.dirname(os.path.abspath(__file__)), stderr=subprocess.DEVNULL, text=True, @@ -40,7 +40,7 @@ if sep: version_prefix, dot, version_tail = prefix.rpartition('.') prefix = f'{version_prefix}{dot}{int(version_tail) + 1}' - __version__ = sep.join((prefix, suffix)) + __version__ = f'{prefix}{sep}{suffix}' del version_prefix, dot, version_tail else: __version__ = prefix