diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 0a6c4d6e..99b553e4 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -37,8 +37,8 @@ concurrency:
   cancel-in-progress: ${{ github.event_name == 'pull_request' }}

 env:
-  CUDA_VERSION: "11.7"
-  TEST_TORCH_SPECS: "cpu cu116"
+  CUDA_VERSION: "12.1"
+  TEST_TORCH_SPECS: "cpu cu118"

 jobs:
   build:
@@ -48,17 +48,20 @@ jobs:
     timeout-minutes: 60
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           submodules: "recursive"
           fetch-depth: 0

       - name: Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
-          python-version: "3.8 - 3.11" # sync with requires-python in pyproject.toml
+          python-version: "3.8 - 3.12" # sync with requires-python in pyproject.toml
           update-environment: true

+      - name: Install dependencies
+        run: python -m pip install --upgrade pip setuptools wheel build
+
       - name: Set __release__
         if: |
           startsWith(github.ref, 'refs/tags/') ||
@@ -69,16 +72,13 @@ jobs:
       - name: Print version
         run: python setup.py --version

-      - name: Install dependencies
-        run: python -m pip install --upgrade pip setuptools wheel build
-
       - name: Build sdist and pure-Python wheel
         run: python -m build
         env:
           TORCHOPT_NO_EXTENSIONS: "true"

       - name: Upload artifact
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           name: build
           path: dist/*
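The sdist/pure-Python wheel step above relies on the `TORCHOPT_NO_EXTENSIONS` environment variable. A minimal sketch of how such an env gate works in a `setup.py`, modeled on the `CIBUILDWHEEL = os.getenv('CIBUILDWHEEL', '0') == '1'` check and the `ext_kwargs.clear()` fragment visible in the setup.py hunks later in this diff; the exact values accepted for the flag are an assumption:

```python
# Hypothetical sketch of an env-gated extension toggle; the accepted
# spellings of TORCHOPT_NO_EXTENSIONS are an assumption, not torchopt's
# verbatim logic.
import os

from setuptools import setup

# Placeholder for the CMake-driven C++/CUDA extension configuration.
ext_kwargs = {'ext_modules': [], 'cmdclass': {}}

if os.getenv('TORCHOPT_NO_EXTENSIONS', '').lower() in ('1', 'true', 'yes'):
    ext_kwargs.clear()  # drop the extension -> pure-Python sdist/wheel

setup(name='torchopt-demo', **ext_kwargs)
```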
@@ -97,28 +97,32 @@ jobs:
         make pytest

   build-wheels-py38:
-    name: Build wheels for Python ${{ matrix.python-version }} on ubuntu-latest
-    runs-on: ubuntu-latest
+    name: Build wheels for Python ${{ matrix.python-version }} on ${{ matrix.os }}
+    runs-on: ${{ matrix.os }}
     needs: [build]
     if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
     strategy:
       matrix:
+        os: [ubuntu-latest]
         python-version: ["3.8"] # sync with requires-python in pyproject.toml
       fail-fast: false
     timeout-minutes: 60
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           submodules: "recursive"
           fetch-depth: 0

       - name: Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
           update-environment: true

+      - name: Install dependencies
+        run: python -m pip install --upgrade pip setuptools wheel build
+
       - name: Set __release__
         if: |
           startsWith(github.ref, 'refs/tags/') ||
@@ -132,7 +136,7 @@ jobs:
         run: python .github/workflows/set_cibw_build.py

       - name: Build wheels
-        uses: pypa/cibuildwheel@v2.15
+        uses: pypa/cibuildwheel@v2.19
         env:
           CIBW_BUILD: ${{ env.CIBW_BUILD }}
         with:
@@ -140,35 +144,39 @@ jobs:
           output-dir: wheelhouse
           config-file: "{package}/pyproject.toml"

-      - uses: actions/upload-artifact@v3
+      - uses: actions/upload-artifact@v4
         with:
-          name: wheels-py38
+          name: wheels-${{ matrix.python-version }}-${{ matrix.os }}
           path: wheelhouse/*.whl
           if-no-files-found: error

   build-wheels:
-    name: Build wheels for Python ${{ matrix.python-version }} on ubuntu-latest
-    runs-on: ubuntu-latest
+    name: Build wheels for Python ${{ matrix.python-version }} on ${{ matrix.os }}
+    runs-on: ${{ matrix.os }}
     needs: [build, build-wheels-py38]
     if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
     strategy:
       matrix:
-        python-version: ["3.9", "3.10", "3.11"] # sync with requires-python in pyproject.toml
+        os: [ubuntu-latest]
+        python-version: ["3.9", "3.10", "3.11", "3.12"] # sync with requires-python in pyproject.toml
       fail-fast: false
     timeout-minutes: 60
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           submodules: "recursive"
           fetch-depth: 0

       - name: Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
           update-environment: true

+      - name: Install dependencies
+        run: python -m pip install --upgrade pip setuptools wheel build
+
       - name: Set __release__
         if: |
           startsWith(github.ref, 'refs/tags/') ||
@@ -182,7 +190,7 @@ jobs:
         run: python .github/workflows/set_cibw_build.py

       - name: Build wheels
-        uses: pypa/cibuildwheel@v2.15
+        uses: pypa/cibuildwheel@v2.19
         env:
           CIBW_BUILD: ${{ env.CIBW_BUILD }}
         with:
@@ -190,15 +198,47 @@ jobs:
           output-dir: wheelhouse
           config-file: "{package}/pyproject.toml"

-      - uses: actions/upload-artifact@v3
+      - uses: actions/upload-artifact@v4
         with:
-          name: wheels
+          name: wheels-${{ matrix.python-version }}-${{ matrix.os }}
           path: wheelhouse/*.whl
           if-no-files-found: error

-  publish:
+  list-artifacts:
+    name: List artifacts
     runs-on: ubuntu-latest
     needs: [build, build-wheels-py38, build-wheels]
+    if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
+    timeout-minutes: 15
+    steps:
+      - name: Download built sdist
+        uses: actions/download-artifact@v4
+        with:
+          # unpacks default artifact into dist/
+          # if `name: artifact` is omitted, the action will create extra parent dir
+          name: build
+          path: dist
+
+      - name: Download built wheels
+        uses: actions/download-artifact@v4
+        with:
+          pattern: wheels-*
+          path: dist
+          merge-multiple: true
+
+      - name: List distributions
+        run: ls -lh dist/*
+
+      - name: Upload artifacts
+        uses: actions/upload-artifact@v4
+        with:
+          name: artifacts
+          path: dist/*
+          if-no-files-found: error
+
+  publish:
+    runs-on: ubuntu-latest
+    needs: [list-artifacts]
     if: |
       github.repository == 'metaopt/torchopt' && github.event_name != 'pull_request' &&
       (github.event_name != 'workflow_dispatch' || github.event.inputs.task == 'build-and-publish') &&
@@ -206,18 +246,21 @@ jobs:
     timeout-minutes: 15
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           submodules: "recursive"
           fetch-depth: 0

       - name: Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         if: startsWith(github.ref, 'refs/tags/')
         with:
-          python-version: "3.8 - 3.11" # sync with requires-python in pyproject.toml
+          python-version: "3.8 - 3.12" # sync with requires-python in pyproject.toml
           update-environment: true

+      - name: Install dependencies
+        run: python -m pip install --upgrade pip setuptools wheel build
+
       - name: Set __release__
         if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch'
         run: |
@@ -236,28 +279,12 @@ jobs:
           exit 1
         fi

-      - name: Download built sdist
-        uses: actions/download-artifact@v3
-        with:
-          # unpacks default artifact into dist/
-          # if `name: artifact` is omitted, the action will create extra parent dir
-          name: build
-          path: dist
-
-      - name: Download built wheels
-        uses: actions/download-artifact@v3
-        with:
-          # unpacks default artifact into dist/
-          # if `name: artifact` is omitted, the action will create extra parent dir
-          name: wheels-py38
-          path: dist
-
-      - name: Download built wheels
-        uses: actions/download-artifact@v3
+      - name: Download built artifacts
+        uses: actions/download-artifact@v4
         with:
           # unpacks default artifact into dist/
           # if `name: artifact` is omitted, the action will create extra parent dir
-          name: wheels
+          name: artifacts
           path: dist

       - name: List distributions
@@ -269,10 +296,10 @@ jobs:
         with:
           user: __token__
           password: ${{ secrets.TESTPYPI_UPLOAD_TOKEN }}
-          repository_url: https://test.pypi.org/legacy/
+          repository-url: https://test.pypi.org/legacy/
           verbose: true
-          print_hash: true
-          skip_existing: true
+          print-hash: true
+          skip-existing: true

       - name: Publish to PyPI
         if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch'
@@ -281,5 +308,5 @@ jobs:
           user: __token__
           password: ${{ secrets.PYPI_UPLOAD_TOKEN }}
           verbose: true
-          print_hash: true
-          skip_existing: true
+          print-hash: true
+          skip-existing: true
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index b338b149..472d5967 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -16,7 +16,7 @@ concurrency:
   cancel-in-progress: ${{ github.event_name == 'pull_request' }}

 env:
-  CUDA_VERSION: "11.7"
+  CUDA_VERSION: "12.1"

 jobs:
   lint:
@@ -24,15 +24,15 @@ jobs:
     timeout-minutes: 30
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           submodules: "recursive"
           fetch-depth: 1

-      - name: Set up Python 3.8
-        uses: actions/setup-python@v4
+      - name: Set up Python 3.9
+        uses: actions/setup-python@v5
         with:
-          python-version: "3.8"
+          python-version: "3.9"
           update-environment: true

       - name: Setup CUDA Toolkit
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 4f6fad50..f156ffe3 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -27,7 +27,7 @@ concurrency:
   cancel-in-progress: ${{ github.event_name == 'pull_request' }}

 env:
-  CUDA_VERSION: "11.7"
+  CUDA_VERSION: "12.1"

 jobs:
   test:
@@ -36,13 +36,13 @@ jobs:
     timeout-minutes: 60
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           submodules: "recursive"
           fetch-depth: 1

       - name: Set up Python 3.8
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: "3.8" # the lowest version we support (sync with requires-python in pyproject.toml)
           update-environment: true
@@ -80,15 +80,15 @@ jobs:
           USE_FP16: "ON"
           TORCH_CUDA_ARCH_LIST: "Common"
         run: |
-          python -m pip install -vvv -e .
+          python -m pip install -vvv --editable .

       - name: Test with pytest
         run: |
           make pytest

       - name: Upload coverage to Codecov
-        if: runner.os == 'Linux'
-        uses: codecov/codecov-action@v3
+        uses: codecov/codecov-action@v4
+        if: ${{ matrix.os == 'ubuntu-latest' }}
         with:
           token: ${{ secrets.CODECOV_TOKEN }}
           file: ./tests/coverage.xml
@@ -106,13 +106,13 @@ jobs:
     fail-fast: false
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           submodules: "recursive"
           fetch-depth: 1

       - name: Set up Python 3.8
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: "3.8" # the lowest version we support (sync with requires-python in pyproject.toml)
           update-environment: true
@@ -127,7 +127,7 @@ jobs:

       - name: Install TorchOpt
         run: |
-          python -m pip install -vvv -e .
+          python -m pip install -vvv --editable .
         env:
           TORCHOPT_NO_EXTENSIONS: "true"

@@ -136,8 +136,8 @@ jobs:
           make pytest

       - name: Upload coverage to Codecov
-        if: runner.os == 'Linux'
-        uses: codecov/codecov-action@v3
+        uses: codecov/codecov-action@v4
+        if: ${{ matrix.os == 'ubuntu-latest' }}
         with:
           token: ${{ secrets.CODECOV_TOKEN }}
           file: ./tests/coverage.xml
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e5c37d40..7ab860a5 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -9,7 +9,7 @@ ci:
 default_stages: [commit, push, manual]
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v4.6.0
     hooks:
       - id: check-symlinks
       - id: destroyed-symlinks
@@ -26,24 +26,24 @@ repos:
       - id: debug-statements
      - id: double-quote-string-fixer
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v16.0.6
+    rev: v18.1.8
     hooks:
       - id: clang-format
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.0.284
+    rev: v0.5.0
     hooks:
       - id: ruff
         args: [--fix, --exit-non-zero-on-fix]
   - repo: https://github.com/PyCQA/isort
-    rev: 5.12.0
+    rev: 5.13.2
     hooks:
       - id: isort
   - repo: https://github.com/psf/black
-    rev: 23.7.0
+    rev: 24.4.2
     hooks:
       - id: black-jupyter
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.10.1
+    rev: v3.16.0
     hooks:
       - id: pyupgrade
         args: [--py38-plus] # sync with requires-python
@@ -52,7 +52,7 @@ repos:
           ^examples/
         )
   - repo: https://github.com/pycqa/flake8
-    rev: 6.1.0
+    rev: 7.1.0
     hooks:
       - id: flake8
         additional_dependencies:
@@ -68,7 +68,7 @@ repos:
           ^docs/source/conf.py$
         )
   - repo: https://github.com/codespell-project/codespell
-    rev: v2.2.5
+    rev: v2.3.0
     hooks:
       - id: codespell
         additional_dependencies: [".[toml]"]
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index 6a9c387e..c014fa97 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -7,9 +7,9 @@ version: 2

 # Set the version of Python and other tools you might need
 build:
-  os: ubuntu-20.04
+  os: ubuntu-lts-latest
   tools:
-    python: mambaforge-4.10
+    python: mambaforge-latest
   jobs:
     post_install:
       - python -m pip install --upgrade pip setuptools
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 24e4eea4..62234c25 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,11 +13,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Added

--
+- Enable CI workflow to build CXX/CUDA extension for Python 3.12 by [@XuehaiPan](https://github.com/XuehaiPan) in [#216](https://github.com/metaopt/torchopt/pull/216).

 ### Changed

--
+- Refactor the raw import statement in `setup.py` with `importlib` utilities by [@XuehaiPan](https://github.com/XuehaiPan) in [#214](https://github.com/metaopt/torchopt/pull/214).

 ### Fixed

@@ -25,7 +25,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Removed

--
+- Drop PyTorch 1.x support by [@XuehaiPan](https://github.com/XuehaiPan) in [#215](https://github.com/metaopt/torchopt/pull/215).
+
+------
+
+## [0.7.3] - 2023-11-10
+
+### Changed
+
+- Set minimal C++ standard to C++17 by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195).
+
+### Fixed
+
+- Fix `optree` compatibility for multi-tree-map with `None` values by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195).

 ------

@@ -195,7 +207,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ------

-[Unreleased]: https://github.com/metaopt/torchopt/compare/v0.7.2...HEAD
+[Unreleased]: https://github.com/metaopt/torchopt/compare/v0.7.3...HEAD
+[0.7.3]: https://github.com/metaopt/torchopt/compare/v0.7.2...v0.7.3
 [0.7.2]: https://github.com/metaopt/torchopt/compare/v0.7.1...v0.7.2
 [0.7.1]: https://github.com/metaopt/torchopt/compare/v0.7.0...v0.7.1
 [0.7.0]: https://github.com/metaopt/torchopt/compare/v0.6.0...v0.7.0
diff --git a/CITATION.cff b/CITATION.cff
index 965b6a7f..3c6098bf 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -32,7 +32,7 @@ authors:
     family-names: Yang
     affiliation: Peking University
     email: yaodong.yang@pku.edu.cn
-version: 0.7.2
-date-released: "2023-08-18"
+version: 0.7.3
+date-released: "2023-11-10"
 license: Apache-2.0
 repository-code: "https://github.com/metaopt/torchopt"
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3b091a22..101ba3ec 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,13 +17,20 @@ cmake_minimum_required(VERSION 3.11)  # for FetchContent
 project(torchopt LANGUAGES CXX)

 include(FetchContent)
-set(PYBIND11_VERSION v2.10.3)
+
+set(THIRD_PARTY_DIR "${CMAKE_SOURCE_DIR}/third-party")
+
+if(NOT DEFINED PYBIND11_VERSION AND NOT "$ENV{PYBIND11_VERSION}" STREQUAL "")
+  set(PYBIND11_VERSION "$ENV{PYBIND11_VERSION}")
+endif()
+if(NOT PYBIND11_VERSION)
+  set(PYBIND11_VERSION stable)
+endif()

 if(NOT CMAKE_BUILD_TYPE)
   set(CMAKE_BUILD_TYPE Release)
 endif()

-set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)

 find_package(Threads REQUIRED) # -pthread
@@ -172,7 +179,7 @@ endif()
 system(
   STRIP OUTPUT_VARIABLE PYTHON_VERSION
-  COMMAND "${PYTHON_EXECUTABLE}" -c "print(__import__('platform').python_version())"
+  COMMAND "${PYTHON_EXECUTABLE}" -c "print('.'.join(map(str, __import__('sys').version_info[:3])))"
 )
 message(STATUS "Use Python version: ${PYTHON_VERSION}")
@@ -216,11 +223,12 @@ if("${PYBIND11_CMAKE_DIR}" STREQUAL "")
     GIT_REPOSITORY https://github.com/pybind/pybind11.git
     GIT_TAG "${PYBIND11_VERSION}"
     GIT_SHALLOW TRUE
-    SOURCE_DIR "${CMAKE_SOURCE_DIR}/third-party/pybind11"
-    BINARY_DIR "${CMAKE_SOURCE_DIR}/third-party/.cmake/pybind11/build"
-    STAMP_DIR "${CMAKE_SOURCE_DIR}/third-party/.cmake/pybind11/stamp"
+    SOURCE_DIR "${THIRD_PARTY_DIR}/pybind11"
+    BINARY_DIR "${THIRD_PARTY_DIR}/.cmake/pybind11/build"
+    STAMP_DIR "${THIRD_PARTY_DIR}/.cmake/pybind11/stamp"
   )
   FetchContent_GetProperties(pybind11)
+
   if(NOT pybind11_POPULATED)
     message(STATUS "Populating Git repository pybind11@${PYBIND11_VERSION} to third-party/pybind11...")
     FetchContent_MakeAvailable(pybind11)
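The `PYTHON_VERSION` probe in the CMakeLists.txt hunk above switches from `platform.python_version()` to joining `sys.version_info[:3]`. A likely motivation (my inference, not stated in the diff) is that the two differ on pre-release interpreters, which matters now that the build targets Python 3.12, which was still in pre-release when this change landed:

```python
import platform
import sys

# On a pre-release interpreter such as CPython 3.12.0rc2:
print(platform.python_version())                 # -> '3.12.0rc2'
print('.'.join(map(str, sys.version_info[:3])))  # -> '3.12.0'
# The joined form is always a plain X.Y.Z string, which is simpler
# for the CMake probe to parse and compare.
```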
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
index cc4a4e9a..4c96e694 100644
--- a/CODE_OF_CONDUCT.md
+++ b/CODE_OF_CONDUCT.md
@@ -5,7 +5,7 @@
 We as members, contributors, and leaders pledge to make participation in our
 community a harassment-free experience for everyone, regardless of age, body
 size, visible or invisible disability, ethnicity, sex characteristics, gender
-identity and expression, level of experience, education, socio-economic status,
+identity and expression, level of experience, education, socioeconomic status,
 nationality, personal appearance, race, religion, or sexual identity
 and orientation.
diff --git a/Dockerfile b/Dockerfile
index 7295af74..246a81e9 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -7,7 +7,7 @@
 #   $ docker build --target devel --tag torchopt-devel:latest .
 #

-ARG cuda_docker_tag="11.7.1-cudnn8-devel-ubuntu22.04"
+ARG cuda_docker_tag="12.1.0-cudnn8-devel-ubuntu22.04"
 FROM nvidia/cuda:"${cuda_docker_tag}" AS builder

 ENV DEBIAN_FRONTEND=noninteractive
diff --git a/LICENSE b/LICENSE
index 8d26c203..185e0144 100644
--- a/LICENSE
+++ b/LICENSE
@@ -187,7 +187,7 @@
       same "printed page" as the copyright notice for easier
       identification within third-party archives.

-   Copyright [2022-2023] [MetaOPT Team. All Rights Reserved.]
+   Copyright [2022-2024] [MetaOPT Team. All Rights Reserved.]

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
diff --git a/Makefile b/Makefile
index 0f7dd74e..e9099f0c 100644
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,4 @@
-print-% : ; @echo $* = $($*)
+print-%: ; @echo $* = $($*)
 PROJECT_NAME = torchopt
 COPYRIGHT = "MetaOPT Team. All Rights Reserved."
 PROJECT_PATH = $(PROJECT_NAME)
@@ -22,7 +22,8 @@ install:
 install-editable:
	$(PYTHON) -m pip install --upgrade pip
	$(PYTHON) -m pip install --upgrade setuptools wheel
-	$(PYTHON) -m pip install torch numpy pybind11
+	$(PYTHON) -m pip install --upgrade pybind11 cmake
+	$(PYTHON) -m pip install torch numpy
	USE_FP16=ON TORCH_CUDA_ARCH_LIST=Auto $(PYTHON) -m pip install -vvv --no-build-isolation --editable .

 install-e: install-editable  # alias
@@ -112,8 +113,9 @@ addlicense-install: go-install
 # Tests

 pytest: test-install
+	$(PYTHON) -m pytest --version
	cd tests && $(PYTHON) -c 'import $(PROJECT_PATH)' && \
-	$(PYTHON) -m pytest --verbose --color=yes --durations=0 \
+	$(PYTHON) -m pytest --verbose --color=yes \
		--cov="$(PROJECT_PATH)" --cov-config=.coveragerc --cov-report=xml --cov-report=term-missing \
		$(PYTESTOPTS) .

@@ -122,30 +124,39 @@ test: pytest
 # Python linters

 pylint: pylint-install
+	$(PYTHON) -m pylint --version
	$(PYTHON) -m pylint $(PROJECT_PATH)

 flake8: flake8-install
+	$(PYTHON) -m flake8 --version
	$(PYTHON) -m flake8 --count --show-source --statistics

 py-format: py-format-install
+	$(PYTHON) -m isort --version
+	$(PYTHON) -m black --version
	$(PYTHON) -m isort --project $(PROJECT_PATH) --check $(PYTHON_FILES) && \
	$(PYTHON) -m black --check $(PYTHON_FILES) tutorials

 ruff: ruff-install
+	$(PYTHON) -m ruff --version
	$(PYTHON) -m ruff check .

 ruff-fix: ruff-install
+	$(PYTHON) -m ruff --version
	$(PYTHON) -m ruff check . --fix --exit-non-zero-on-fix

 mypy: mypy-install
+	$(PYTHON) -m mypy --version
	$(PYTHON) -m mypy $(PROJECT_PATH) --install-types --non-interactive

 pre-commit: pre-commit-install
+	$(PYTHON) -m pre_commit --version
	$(PYTHON) -m pre_commit run --all-files

 # C++ linters

 cmake-configure: cmake-install
+	cmake --version
	cmake -S . -B cmake-build-debug \
		-DCMAKE_BUILD_TYPE=Debug \
		-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
@@ -157,13 +168,16 @@ cmake-build: cmake-configure
 cmake: cmake-build

 cpplint: cpplint-install
+	$(PYTHON) -m cpplint --version
	$(PYTHON) -m cpplint $(CXX_FILES) $(CUDA_FILES)

 clang-format: clang-format-install
+	$(CLANG_FORMAT) --version
	$(CLANG_FORMAT) --style=file -i $(CXX_FILES) $(CUDA_FILES) -n --Werror

 clang-tidy: clang-tidy-install cmake-configure
-	clang-tidy -p=cmake-build-debug $(CXX_FILES)
+	clang-tidy --version
+	clang-tidy --extra-arg="-v" -p=cmake-build-debug $(CXX_FILES)

 # Documentation

@@ -223,7 +237,11 @@ docker-devel:
 docker: docker-base docker-devel

 docker-run-base: docker-base
-	docker run --network=host --gpus=all -v /:/host -h ubuntu -it $(PROJECT_NAME):$(COMMIT_HASH)
+	docker run -it --network=host --gpus=all --shm-size=4gb -v /:/host -v "${PWD}:/workspace" \
+		-h ubuntu -w /workspace $(PROJECT_NAME):$(COMMIT_HASH)

 docker-run-devel: docker-devel
-	docker run --network=host --gpus=all -v /:/host -h ubuntu -it $(PROJECT_NAME)-devel:$(COMMIT_HASH)
+	docker run -it --network=host --gpus=all --shm-size=4gb -v /:/host -v "${PWD}:/workspace" \
+		-h ubuntu -w /workspace $(PROJECT_NAME)-devel:$(COMMIT_HASH)
+
+docker-run: docker-run-base
diff --git a/README.md b/README.md
index ee1905ab..91d44a25 100644
--- a/README.md
+++ b/README.md
@@ -425,11 +425,11 @@ Then run the following command to install TorchOpt from PyPI ([![PyPI](https://i
 pip3 install torchopt
 ```

-If the minimum version of PyTorch is not satisfied, `pip` will install/upgrade it for you. Please be careful about the `torch` build for CPU / CUDA support (e.g. `cpu`, `cu116`, `cu117`).
+If the minimum version of PyTorch is not satisfied, `pip` will install/upgrade it for you. Please be careful about the `torch` build for CPU / CUDA support (e.g. `cpu`, `cu118`, `cu121`).
 You may need to specify the extra index URL for the `torch` package:

 ```bash
-pip3 install torchopt --extra-index-url https://download.pytorch.org/whl/cu117
+pip3 install torchopt --extra-index-url https://download.pytorch.org/whl/cu121
 ```

 See https://pytorch.org for more information about installing PyTorch.
@@ -450,7 +450,7 @@ git clone https://github.com/metaopt/torchopt.git
 cd torchopt

 # You may need `CONDA_OVERRIDE_CUDA` if conda fails to detect the NVIDIA driver (e.g. in docker or WSL2)
-CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe-minimal.yaml
+CONDA_OVERRIDE_CUDA=12.1 conda env create --file conda-recipe-minimal.yaml

 conda activate torchopt
 make install-editable  # or run `pip3 install --no-build-isolation --editable .`
@@ -469,11 +469,15 @@ See [CHANGELOG.md](CHANGELOG.md).
 If you find TorchOpt useful, please cite it in your publications.

 ```bibtex
-@article{torchopt,
+@article{JMLR:TorchOpt,
+  author  = {Jie Ren* and Xidong Feng* and Bo Liu* and Xuehai Pan* and Yao Fu and Luo Mai and Yaodong Yang},
   title   = {TorchOpt: An Efficient Library for Differentiable Optimization},
-  author  = {Ren, Jie and Feng, Xidong and Liu, Bo and Pan, Xuehai and Fu, Yao and Mai, Luo and Yang, Yaodong},
-  journal = {arXiv preprint arXiv:2211.06934},
-  year    = {2022}
+  journal = {Journal of Machine Learning Research},
+  year    = {2023},
+  volume  = {24},
+  number  = {367},
+  pages   = {1--14},
+  url     = {http://jmlr.org/papers/v24/23-0191.html}
 }
 ```
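With the README's extra index URL switched to `cu121`, it is worth verifying which `torch` build actually got installed. A quick check using standard PyTorch attributes (the version strings shown are illustrative):

```python
import torch

# Sanity-check the installed build after
#   pip3 install torchopt --extra-index-url https://download.pytorch.org/whl/cu121
print(torch.__version__)          # e.g. '2.3.1+cu121' for a CUDA 12.1 build
print(torch.version.cuda)         # e.g. '12.1'; None on a CPU-only build
print(torch.cuda.is_available())  # True only with a working driver
```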
diff --git a/conda-recipe-minimal-cpu.yaml b/conda-recipe-minimal-cpu.yaml
index 0404f10c..dda60369 100644
--- a/conda-recipe-minimal-cpu.yaml
+++ b/conda-recipe-minimal-cpu.yaml
@@ -1,4 +1,4 @@
-# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -26,11 +26,11 @@ channels:
   - conda-forge

 dependencies:
-  - python = 3.10
+  - python = 3.11
   - pip

   # Learning
-  - pytorch::pytorch >= 1.13  # sync with project.dependencies
+  - pytorch::pytorch >= 2.0  # sync with project.dependencies
   - pytorch::torchvision
   - pytorch::pytorch-mutex = *=*cpu*
   - pip:
@@ -40,10 +40,10 @@ dependencies:
   - cmake >= 3.11
   - make
   - cxx-compiler
-  - pybind11 >= 2.10.1
+  - pybind11 >= 2.11.1

   # Misc
   - optree >= 0.4.1
-  - typing-extensions >= 4.0.0
+  - typing-extensions
   - numpy
   - python-graphviz
diff --git a/conda-recipe-minimal.yaml b/conda-recipe-minimal.yaml
index c3d155b8..7e28d2ef 100644
--- a/conda-recipe-minimal.yaml
+++ b/conda-recipe-minimal.yaml
@@ -1,4 +1,4 @@
-# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,41 +15,41 @@
 #
 #   Create virtual environment with command:
 #
-#   $ CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe-minimal.yaml
+#   $ CONDA_OVERRIDE_CUDA=12.1 conda env create --file conda-recipe-minimal.yaml
 #

 name: torchopt

 channels:
   - pytorch
-  - nvidia/label/cuda-11.7.1
+  - nvidia/label/cuda-12.1.0
   - defaults
   - conda-forge

 dependencies:
-  - python = 3.10
+  - python = 3.11
   - pip

   # Learning
-  - pytorch::pytorch >= 1.13  # sync with project.dependencies
+  - pytorch::pytorch >= 2.0  # sync with project.dependencies
   - pytorch::torchvision
   - pytorch::pytorch-mutex = *=*cuda*
   - pip:
       - torchviz

   # Device select
-  - nvidia/label/cuda-11.7.1::cuda-toolkit = 11.7
+  - nvidia/label/cuda-12.1.0::cuda-toolkit = 12.1

   # Build toolchain
   - cmake >= 3.11
   - make
   - cxx-compiler
-  - nvidia/label/cuda-11.7.1::cuda-nvcc
-  - nvidia/label/cuda-11.7.1::cuda-cudart-dev
-  - pybind11 >= 2.10.1
+  - nvidia/label/cuda-12.1.0::cuda-nvcc
+  - nvidia/label/cuda-12.1.0::cuda-cudart-dev
+  - pybind11 >= 2.11.1

   # Misc
   - optree >= 0.4.1
-  - typing-extensions >= 4.0.0
+  - typing-extensions
   - numpy
   - python-graphviz
diff --git a/conda-recipe.yaml b/conda-recipe.yaml
index 997f11c5..9753852b 100644
--- a/conda-recipe.yaml
+++ b/conda-recipe.yaml
@@ -1,4 +1,4 @@
-# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,49 +15,49 @@
 #
 #   Create virtual environment with command:
 #
-#   $ CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe.yaml
+#   $ CONDA_OVERRIDE_CUDA=12.1 conda env create --file conda-recipe.yaml
 #

 name: torchopt

 channels:
   - pytorch
-  - nvidia/label/cuda-11.7.1
+  - nvidia/label/cuda-12.1.0
   - defaults
   - conda-forge

 dependencies:
-  - python = 3.10
+  - python = 3.11
   - pip

   # Learning
-  - pytorch::pytorch >= 1.13  # sync with project.dependencies
+  - pytorch::pytorch >= 2.0  # sync with project.dependencies
   - pytorch::torchvision
   - pytorch::pytorch-mutex = *=*cuda*
   - pip:
       - torchviz
       - sphinxcontrib-katex  # for documentation
-  - jax  # for tutorials
-  - jaxlib  # for tutorials
-  - optax  # for tutorials
-  - jaxopt  # for tests
+  - conda-forge::jax  # for tutorials
+  - conda-forge::jaxlib  # for tutorials
+  - conda-forge::optax  # for tutorials
+  - conda-forge::jaxopt  # for tests
   - tensorboard  # for examples

   # Device select
-  - nvidia/label/cuda-11.7.1::cuda-toolkit = 11.7
+  - nvidia/label/cuda-12.1.0::cuda-toolkit = 12.1

   # Build toolchain
   - cmake >= 3.11
   - make
   - cxx-compiler
-  - nvidia/label/cuda-11.7.1::cuda-nvcc
-  - nvidia/label/cuda-11.7.1::cuda-cudart-dev
+  - nvidia/label/cuda-12.1.0::cuda-nvcc
+  - nvidia/label/cuda-12.1.0::cuda-cudart-dev
   - patchelf >= 0.14
-  - pybind11 >= 2.10.1
+  - pybind11 >= 2.11.1

   # Misc
   - optree >= 0.4.1
-  - typing-extensions >= 4.0.0
+  - typing-extensions
   - numpy
   - matplotlib-base
   - seaborn
@@ -77,17 +77,16 @@ dependencies:
   - hunspell-en
   - myst-nb
   - ipykernel
-  - pandoc
   - docutils

   # Testing
   - pytest
   - pytest-cov
   - pytest-xdist
-  - isort >= 5.11.0
-  - conda-forge::black-jupyter >= 22.6.0
-  - pylint >= 2.15.0
-  - mypy >= 0.990
+  - isort
+  - conda-forge::black-jupyter
+  - pylint
+  - mypy
   - flake8
   - flake8-bugbear
   - flake8-comprehensions
@@ -97,8 +96,8 @@ dependencies:
   - ruff
   - doc8
   - pydocstyle
-  - clang-format >= 14
-  - clang-tools >= 14  # clang-tidy
-  - cpplint
+  - conda-forge::clang-format >= 14
+  - conda-forge::clang-tools >= 14  # clang-tidy
+  - conda-forge::cpplint
   - conda-forge::pre-commit
   - conda-forge::identify
diff --git a/docs/conda-recipe.yaml b/docs/conda-recipe.yaml
index 9a14af3f..d7d2f288 100644
--- a/docs/conda-recipe.yaml
+++ b/docs/conda-recipe.yaml
@@ -1,4 +1,4 @@
-# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,23 +15,23 @@
 #
 #   Create virtual environment with command:
 #
-#   $ CONDA_OVERRIDE_CUDA=11.7 conda env create --file docs/conda-recipe.yaml
+#   $ CONDA_OVERRIDE_CUDA=12.1 conda env create --file docs/conda-recipe.yaml
 #

 name: torchopt-docs

 channels:
   - pytorch
-  - nvidia/label/cuda-11.7.1
+  - nvidia/label/cuda-12.1.0
   - defaults
   - conda-forge

 dependencies:
-  - python = 3.10
+  - python = 3.11
   - pip

   # Learning
-  - pytorch::pytorch >= 1.13  # sync with project.dependencies
+  - pytorch::pytorch >= 2.0  # sync with project.dependencies
   - pytorch::cpuonly
   - pytorch::pytorch-mutex = *=*cpu*
   - pip:
@@ -42,13 +42,13 @@ dependencies:
   - cmake >= 3.11
   - make
   - cxx-compiler
-  - nvidia/label/cuda-11.7.1::cuda-nvcc
-  - nvidia/label/cuda-11.7.1::cuda-cudart-dev
-  - pybind11 >= 2.10.1
+  - nvidia/label/cuda-12.1.0::cuda-nvcc
+  - nvidia/label/cuda-12.1.0::cuda-cudart-dev
+  - pybind11 >= 2.11.1

   # Misc
   - optree >= 0.4.1
-  - typing-extensions >= 4.0.0
+  - typing-extensions
   - numpy
   - matplotlib-base
   - seaborn
@@ -67,5 +67,4 @@ dependencies:
   - hunspell-en
   - myst-nb
   - ipykernel
-  - pandoc
   - docutils
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 655c64ff..c9631b75 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,20 +1,20 @@
 --extra-index-url https://download.pytorch.org/whl/cpu
 # Sync with project.dependencies
-torch >= 1.13
+torch >= 2.0

 --requirement ../requirements.txt

-sphinx >= 5.2.1
+sphinx >= 5.2.1, < 7.0.0a0
+sphinxcontrib-bibtex >= 2.4
+sphinx-autodoc-typehints >= 1.20
+myst-nb >= 0.15
+
 sphinx-autoapi
 sphinx-autobuild
 sphinx-copybutton
 sphinx-rtd-theme
 sphinxcontrib-katex
-sphinxcontrib-bibtex
-sphinx-autodoc-typehints >= 1.19.2
 IPython
 ipykernel
-pandoc
-myst-nb
 docutils
 matplotlib
diff --git a/docs/source/_static/css/style.css b/docs/source/_static/css/style.css
index 15b8d2e2..60f33012 100644
--- a/docs/source/_static/css/style.css
+++ b/docs/source/_static/css/style.css
@@ -1,5 +1,5 @@
 /**
- * Copyright 2022 MetaOPT Team. All Rights Reserved.
+ * Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
diff --git a/docs/source/conf.py b/docs/source/conf.py
index f5d206c7..a4f23533 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -1,4 +1,4 @@
-# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -66,7 +66,7 @@ def filter(self, record: logging.LogRecord) -> bool:
 # -- Project information -------------------------------------------------------

 project = 'TorchOpt'
-copyright = '2022-2023 MetaOPT Team'
+copyright = '2022-2024 MetaOPT Team'
 author = 'TorchOpt Contributors'

 # The full version, including alpha/beta/rc tags
diff --git a/docs/source/developer/contributing.rst b/docs/source/developer/contributing.rst
index 4e7dd355..e40a564a 100644
--- a/docs/source/developer/contributing.rst
+++ b/docs/source/developer/contributing.rst
@@ -17,7 +17,7 @@ Before contributing to TorchOpt, please follow the instructions below to setup.

 .. code-block:: bash

     # You may need `CONDA_OVERRIDE_CUDA` if conda fails to detect the NVIDIA driver (e.g. in docker or WSL2)
-    CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe.yaml
+    CONDA_OVERRIDE_CUDA=12.1 conda env create --file conda-recipe.yaml

     conda activate torchopt
@@ -91,14 +91,14 @@ To build compatible **manylinux2014** (:pep:`599`) wheels for distribution, you

     pip3 install --upgrade cibuildwheel

-    export TEST_TORCH_SPECS="cpu cu116"  # `torch` builds for testing
-    export CUDA_VERSION="11.7"           # version of `nvcc` for compilation
+    export TEST_TORCH_SPECS="cpu cu118"  # `torch` builds for testing
+    export CUDA_VERSION="12.1"           # version of `nvcc` for compilation
     python3 -m cibuildwheel --platform=linux --output-dir=wheelhouse --config-file=pyproject.toml

 It will install the CUDA compiler with ``CUDA_VERSION`` in the build container. Then build wheel binaries for all supported CPython versions. The outputs will be placed in the ``wheelhouse`` directory.

 To build a wheel for a specific CPython version, you can use the |CIBW_BUILD|_ environment variable.
-For example, the following command will build a wheel for Python 3.7:
+For example, the following command will build a wheel for Python 3.8:

 .. code-block:: bash
diff --git a/docs/source/explicit_diff/explicit_diff.rst b/docs/source/explicit_diff/explicit_diff.rst
index 9445adb8..28e06f77 100644
--- a/docs/source/explicit_diff/explicit_diff.rst
+++ b/docs/source/explicit_diff/explicit_diff.rst
@@ -59,7 +59,7 @@ For PyTorch-like API (e.g., ``step()``), we designed a base class :class:`torcho
     torchopt.MetaAdagrad
     torchopt.MetaAdam
     torchopt.MetaAdamW
-    torchopt.AdaMax
+    torchopt.MetaAdaMax
     torchopt.MetaAdamax
     torchopt.MetaRAdam
     torchopt.MetaRMSProp
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 02fab843..83602090 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -42,7 +42,7 @@ You can use the following commands with `conda `
     cd torchopt

     # You may need `CONDA_OVERRIDE_CUDA` if conda fails to detect the NVIDIA driver (e.g. in docker or WSL2)
-    CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe-minimal.yaml
+    CONDA_OVERRIDE_CUDA=12.1 conda env create --file conda-recipe-minimal.yaml

     conda activate torchopt
@@ -118,11 +118,15 @@ If you find TorchOpt useful, please cite it in your publications.

 .. code-block:: bibtex

-    @article{torchopt,
+    @article{JMLR:TorchOpt,
+      author  = {Jie Ren* and Xidong Feng* and Bo Liu* and Xuehai Pan* and Yao Fu and Luo Mai and Yaodong Yang},
       title   = {TorchOpt: An Efficient Library for Differentiable Optimization},
-      author  = {Ren, Jie and Feng, Xidong and Liu, Bo and Pan, Xuehai and Fu, Yao and Mai, Luo and Yang, Yaodong},
-      journal = {arXiv preprint arXiv:2211.06934},
-      year    = {2022}
+      journal = {Journal of Machine Learning Research},
+      year    = {2023},
+      volume  = {24},
+      number  = {367},
+      pages   = {1--14},
+      url     = {http://jmlr.org/papers/v24/23-0191.html}
     }
diff --git a/examples/FuncTorch/maml_omniglot_vmap.py b/examples/FuncTorch/maml_omniglot_vmap.py
index e1cfe95e..2f42e050 100644
--- a/examples/FuncTorch/maml_omniglot_vmap.py
+++ b/examples/FuncTorch/maml_omniglot_vmap.py
@@ -1,4 +1,4 @@
-# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/FuncTorch/parallel_train_torchopt.py b/examples/FuncTorch/parallel_train_torchopt.py
index f28bded7..523515e1 100644
--- a/examples/FuncTorch/parallel_train_torchopt.py
+++ b/examples/FuncTorch/parallel_train_torchopt.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/L2R/helpers/argument.py b/examples/L2R/helpers/argument.py
index 5df9f314..7db6c982 100644
--- a/examples/L2R/helpers/argument.py
+++ b/examples/L2R/helpers/argument.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/L2R/helpers/model.py b/examples/L2R/helpers/model.py
index dbde0e8d..877ad50a 100644
--- a/examples/L2R/helpers/model.py
+++ b/examples/L2R/helpers/model.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/L2R/helpers/utils.py b/examples/L2R/helpers/utils.py
index 7e95ca6f..ade64236 100644
--- a/examples/L2R/helpers/utils.py
+++ b/examples/L2R/helpers/utils.py
@@ -1,4 +1,4 @@
-# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/L2R/l2r.py b/examples/L2R/l2r.py
index 64990976..a0ae764b 100644
--- a/examples/L2R/l2r.py
+++ b/examples/L2R/l2r.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/LOLA/helpers/agent.py b/examples/LOLA/helpers/agent.py
index a8f8ee31..78946ee7 100644
--- a/examples/LOLA/helpers/agent.py
+++ b/examples/LOLA/helpers/agent.py
@@ -1,4 +1,4 @@
-# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/LOLA/helpers/argument.py b/examples/LOLA/helpers/argument.py
index 39618134..ad53c056 100644
--- a/examples/LOLA/helpers/argument.py
+++ b/examples/LOLA/helpers/argument.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/LOLA/helpers/env.py b/examples/LOLA/helpers/env.py
index f496276e..e1576a7d 100644
--- a/examples/LOLA/helpers/env.py
+++ b/examples/LOLA/helpers/env.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/LOLA/helpers/utils.py b/examples/LOLA/helpers/utils.py
index 20f67be5..4dd436ec 100644
--- a/examples/LOLA/helpers/utils.py
+++ b/examples/LOLA/helpers/utils.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/LOLA/lola_dice.py b/examples/LOLA/lola_dice.py
index 6dbaaf24..485de894 100644
--- a/examples/LOLA/lola_dice.py
+++ b/examples/LOLA/lola_dice.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/LOLA/visualize.py b/examples/LOLA/visualize.py
index 6dc54ddf..7af19b21 100755
--- a/examples/LOLA/visualize.py
+++ b/examples/LOLA/visualize.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3

-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/MAML-RL/func_maml.py b/examples/MAML-RL/func_maml.py
index f3a00642..475c1b12 100644
--- a/examples/MAML-RL/func_maml.py
+++ b/examples/MAML-RL/func_maml.py
@@ -1,4 +1,4 @@
-# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/MAML-RL/helpers/__init__.py b/examples/MAML-RL/helpers/__init__.py
index 9855e0b3..31d45c37 100644
--- a/examples/MAML-RL/helpers/__init__.py
+++ b/examples/MAML-RL/helpers/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/MAML-RL/helpers/policy.py b/examples/MAML-RL/helpers/policy.py
index 0a43a5b1..0bf8e188 100644
--- a/examples/MAML-RL/helpers/policy.py
+++ b/examples/MAML-RL/helpers/policy.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/MAML-RL/helpers/policy_torchrl.py b/examples/MAML-RL/helpers/policy_torchrl.py
index 91bdb269..5bcb3863 100644
--- a/examples/MAML-RL/helpers/policy_torchrl.py
+++ b/examples/MAML-RL/helpers/policy_torchrl.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/MAML-RL/helpers/tabular_mdp.py b/examples/MAML-RL/helpers/tabular_mdp.py
index f8feb7b7..0d8f5f7f 100644
--- a/examples/MAML-RL/helpers/tabular_mdp.py
+++ b/examples/MAML-RL/helpers/tabular_mdp.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/MAML-RL/maml.py b/examples/MAML-RL/maml.py
index 42fddbac..0cb57a92 100644
--- a/examples/MAML-RL/maml.py
+++ b/examples/MAML-RL/maml.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/MAML-RL/maml_torchrl.py b/examples/MAML-RL/maml_torchrl.py
index 225f73bc..56db91ef 100644
--- a/examples/MAML-RL/maml_torchrl.py
+++ b/examples/MAML-RL/maml_torchrl.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/MGRL/mgrl.py b/examples/MGRL/mgrl.py
index 49eb79c4..bdc4b140 100644
--- a/examples/MGRL/mgrl.py
+++ b/examples/MGRL/mgrl.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/distributed/few-shot/maml_omniglot.py b/examples/distributed/few-shot/maml_omniglot.py
index 24601dfa..f840e65e 100644
--- a/examples/distributed/few-shot/maml_omniglot.py
+++ b/examples/distributed/few-shot/maml_omniglot.py
@@ -1,4 +1,4 @@
-# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/distributed/few-shot/maml_omniglot_local_loader.py b/examples/distributed/few-shot/maml_omniglot_local_loader.py
index d7413770..fb737d4f 100644
--- a/examples/distributed/few-shot/maml_omniglot_local_loader.py
+++ b/examples/distributed/few-shot/maml_omniglot_local_loader.py
@@ -1,4 +1,4 @@
-# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/few-shot/maml_omniglot.py b/examples/few-shot/maml_omniglot.py
index d798aa1d..7f7f67fe 100644
--- a/examples/few-shot/maml_omniglot.py
+++ b/examples/few-shot/maml_omniglot.py
@@ -1,4 +1,4 @@
-# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/iMAML/imaml_omniglot.py b/examples/iMAML/imaml_omniglot.py
index 8a6960ba..1db08427 100644
--- a/examples/iMAML/imaml_omniglot.py
+++ b/examples/iMAML/imaml_omniglot.py
@@ -1,4 +1,4 @@
-# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/iMAML/imaml_omniglot_functional.py b/examples/iMAML/imaml_omniglot_functional.py
index 60fd4108..7bc1e9da 100644
--- a/examples/iMAML/imaml_omniglot_functional.py
+++ b/examples/iMAML/imaml_omniglot_functional.py
@@ -1,4 +1,4 @@
-# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/examples/requirements.txt b/examples/requirements.txt
index 76bed365..48945c62 100644
--- a/examples/requirements.txt
+++ b/examples/requirements.txt
@@ -1,5 +1,5 @@
---extra-index-url https://download.pytorch.org/whl/cu117
-torch >= 1.13
+--extra-index-url https://download.pytorch.org/whl/cu121
+torch >= 2.0
 torchvision
 --requirement ../requirements.txt
diff --git a/examples/visualize.py b/examples/visualize.py
index 5e08267f..067fa511 100644
--- a/examples/visualize.py
+++ b/examples/visualize.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/include/adam_op/adam_op.h b/include/adam_op/adam_op.h
index a49b0a06..2d0abcd3 100644
--- a/include/adam_op/adam_op.h
+++ b/include/adam_op/adam_op.h
@@ -1,4 +1,4 @@
-// Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+// Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
diff --git a/include/adam_op/adam_op_impl_cpu.h b/include/adam_op/adam_op_impl_cpu.h
index 37aba528..4d54377e 100644
--- a/include/adam_op/adam_op_impl_cpu.h
+++ b/include/adam_op/adam_op_impl_cpu.h
@@ -1,4 +1,4 @@
-// Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+// Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
diff --git a/include/adam_op/adam_op_impl_cuda.cuh b/include/adam_op/adam_op_impl_cuda.cuh
index 6e661564..17002b36 100644
--- a/include/adam_op/adam_op_impl_cuda.cuh
+++ b/include/adam_op/adam_op_impl_cuda.cuh
@@ -1,4 +1,4 @@
-// Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+// Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
diff --git a/include/common.h b/include/common.h
index 65f9ef33..256b0ca1 100644
--- a/include/common.h
+++ b/include/common.h
@@ -1,4 +1,4 @@
-// Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+// Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
diff --git a/include/utils.h b/include/utils.h
index 0ef98539..cefabfac 100644
--- a/include/utils.h
+++ b/include/utils.h
@@ -1,4 +1,4 @@
-// Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+// Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
diff --git a/pyproject.toml b/pyproject.toml
index 47424855..d343e04a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -2,7 +2,7 @@
 [build-system]
 # Sync with project.dependencies
-requires = ["setuptools", "torch >= 1.13", "numpy", "pybind11 >= 2.10.1"]
+requires = ["setuptools", "torch >= 2.0", "numpy", "pybind11 >= 2.11.1"]
 build-backend = "setuptools.build_meta"

 [project]
@@ -22,7 +22,7 @@ authors = [
 license = { text = "Apache License, Version 2.0" }
 keywords = [
     "PyTorch",
-    "functorch",
+    "FuncTorch",
     "JAX",
     "Meta-Learning",
     "Optimizer",
@@ -38,6 +38,8 @@ classifiers = [
     "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: Implementation :: CPython",
     "Operating System :: Microsoft :: Windows",
     "Operating System :: POSIX :: Linux",
     "Operating System :: MacOS",
@@ -51,11 +53,11 @@ classifiers = [
 ]
 dependencies = [
     # See also build-system.requires and project.requires-python
-    "torch >= 1.13",
+    "torch >= 2.0",
     "optree >= 0.4.1",
     "numpy",
     "graphviz",
-    "typing-extensions >= 4.0.0",
+    "typing-extensions",
 ]
 dynamic = ["version"]

@@ -68,9 +70,9 @@ Documentation = "https://torchopt.readthedocs.io"
 [project.optional-dependencies]
 lint = [
     "isort",
-    "black[jupyter] >= 22.6.0",
-    "pylint[spelling] >= 2.15.0",
-    "mypy >= 0.990",
+    "black[jupyter]",
+    "pylint[spelling]",
+    "mypy",
     "flake8",
     "flake8-bugbear",
     "flake8-comprehensions",
@@ -88,7 +90,7 @@ test = [
     "pytest",
     "pytest-cov",
     "pytest-xdist",
-    "jax[cpu] >= 0.3; platform_system != 'Windows'",
+    "jax[cpu] >= 0.4; platform_system != 'Windows'",
     "jaxopt; platform_system != 'Windows'",
     "optax; platform_system != 'Windows'",
 ]
@@ -113,8 +115,8 @@ build-verbosity = 3
 environment.USE_FP16 = "ON"
 environment.CUDACXX = "/usr/local/cuda/bin/nvcc"
 environment.TORCH_CUDA_ARCH_LIST = "Common"
-environment.DEFAULT_CUDA_VERSION = "11.7"
-environment.DEFAULT_TEST_TORCH_SPECS = "cpu cu116"
+environment.DEFAULT_CUDA_VERSION = "12.1"
+environment.DEFAULT_TEST_TORCH_SPECS = "cpu cu118"
 environment-pass = ["CUDA_VERSION", "TEST_TORCH_SPECS"]
 container-engine = "docker"
 test-extras = ["test"]
@@ -176,11 +178,10 @@ test-command = """
 # Linter tools #################################################################

 [tool.black]
-safe = true
 line-length = 100
 skip-string-normalization = true
 # Sync with requires-python
-target-version = ["py38", "py39", "py310", "py311"]
+target-version = ["py38"]

 [tool.isort]
 atomic = true
@@ -194,15 +195,15 @@ multi_line_output = 3

 [tool.mypy]
 # Sync with requires-python
-python_version = 3.8
+python_version = "3.8"
 pretty = true
 show_error_codes = true
 show_error_context = true
 show_traceback = true
 allow_redefinition = true
 check_untyped_defs = true
-disallow_incomplete_defs = false
-disallow_untyped_defs = false
+disallow_incomplete_defs = true
+disallow_untyped_defs = true
 ignore_missing_imports = true
 no_implicit_optional = true
 strict_equality = true
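Flipping `disallow_incomplete_defs` and `disallow_untyped_defs` to `true` makes mypy reject unannotated or partially annotated functions throughout `torchopt`. An illustrative pair, using standard mypy semantics and hypothetical function names:

```python
def scale(value, factor=2.0):  # now an error: untyped (disallow_untyped_defs)
    return value * factor


def half(value: float, factor=0.5):  # now an error: partially typed (disallow_incomplete_defs)
    return value * factor


def scale_ok(value: float, factor: float = 2.0) -> float:  # accepted
    return value * factor
```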
@@ -226,12 +227,15 @@ ignore-words = "docs/source/spelling_wordlist.txt"
 # Sync with requires-python
 target-version = "py38"
 line-length = 100
-show-source = true
+output-format = "full"
 src = ["torchopt", "tests"]
 extend-exclude = ["examples"]
+
+[tool.ruff.lint]
 select = [
     "E", "W",  # pycodestyle
     "F",       # pyflakes
+    "C90",     # mccabe
     "UP",      # pyupgrade
     "ANN",     # flake8-annotations
     "S",       # flake8-bandit
@@ -240,7 +244,10 @@ select = [
     "COM",     # flake8-commas
     "C4",      # flake8-comprehensions
     "EXE",     # flake8-executable
+    "FA",      # flake8-future-annotations
+    "LOG",     # flake8-logging
     "ISC",     # flake8-implicit-str-concat
+    "INP",     # flake8-no-pep420
     "PIE",     # flake8-pie
     "PYI",     # flake8-pyi
     "Q",       # flake8-quotes
@@ -248,6 +255,10 @@ select = [
     "RET",     # flake8-return
     "SIM",     # flake8-simplify
     "TID",     # flake8-tidy-imports
+    "TCH",     # flake8-type-checking
+    "PERF",    # perflint
+    "FURB",    # refurb
+    "TRY",     # tryceratops
     "RUF",     # ruff
 ]
 ignore = [
@@ -265,13 +276,13 @@ ignore = [
     # S101: use of `assert` detected
     # internal use and may never raise at runtime
     "S101",
-    # PLR0402: use from {module} import {name} in lieu of alias
-    # use alias for import convention (e.g., `import torch.nn as nn`)
-    "PLR0402",
+    # TRY003: avoid specifying long messages outside the exception class
+    # long messages are necessary for clarity
+    "TRY003",
 ]
 typing-modules = ["torchopt.typing"]

-[tool.ruff.per-file-ignores]
+[tool.ruff.lint.per-file-ignores]
 "__init__.py" = [
     "F401",  # unused-import
 ]
@@ -293,19 +304,22 @@ typing-modules = ["torchopt.typing"]
     "F401",  # unused-import
     "F811",  # redefined-while-unused
 ]
+"docs/source/conf.py" = [
+    "INP001",  # flake8-no-pep420
+]

-[tool.ruff.flake8-annotations]
+[tool.ruff.lint.flake8-annotations]
 allow-star-arg-any = true

-[tool.ruff.flake8-quotes]
+[tool.ruff.lint.flake8-quotes]
 docstring-quotes = "double"
 multiline-quotes = "double"
 inline-quotes = "single"

-[tool.ruff.flake8-tidy-imports]
+[tool.ruff.lint.flake8-tidy-imports]
 ban-relative-imports = "all"

-[tool.ruff.pylint]
+[tool.ruff.lint.pylint]
 allow-magic-value-types = ["int", "str", "float"]

 [tool.pytest.ini_options]
@@ -313,5 +327,8 @@ filterwarnings = [
     "error",
     'ignore:Explicitly requested dtype float64 requested in .* is not available, and will be truncated to dtype float32\.:UserWarning',
     'ignore:jax\.numpy\.DeviceArray is deprecated\. Use jax\.Array\.:DeprecationWarning',
+    'ignore:(ast\.Str|ast\.NameConstant|Attribute s) is deprecated and will be removed in Python 3\.14:DeprecationWarning',
     'ignore:.*functorch.*deprecate.*:UserWarning',
+    'ignore:.*Apple Paravirtual device.*:UserWarning',
+    'ignore:.*NVML.*:UserWarning',
 ]
diff --git a/requirements.txt b/requirements.txt
index 961ddf73..a5151c36 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 # Sync with project.dependencies
-torch >= 1.13
+torch >= 2.0
 optree >= 0.4.1
 numpy
 graphviz
-typing-extensions >= 4.0.0
+typing-extensions
diff --git a/setup.py b/setup.py
index dc1103df..c50ba5ed 100644
--- a/setup.py
+++ b/setup.py
@@ -1,3 +1,4 @@
+import contextlib
 import os
 import pathlib
 import platform
@@ -5,22 +6,13 @@
 import shutil
 import sys
 import sysconfig
+from importlib.util import module_from_spec, spec_from_file_location

-from setuptools import setup
+from setuptools import Extension, setup
+from setuptools.command.build_ext import build_ext

-try:
-    from pybind11.setup_helpers import Pybind11Extension as Extension
-    from pybind11.setup_helpers import build_ext
-except ImportError:
-    from setuptools import Extension
-    from setuptools.command.build_ext import build_ext

 HERE = pathlib.Path(__file__).absolute().parent
-VERSION_FILE = HERE / 'torchopt' / 'version.py'
-
-sys.path.insert(0, str(VERSION_FILE.parent))
-import version  # noqa


 class CMakeExtension(Extension):
@@ -47,7 +39,6 @@ def build_extension(self, ext):
         build_temp.mkdir(parents=True, exist_ok=True)

         config = 'Debug' if self.debug else 'Release'
-
         cmake_args = [
             f'-DCMAKE_BUILD_TYPE={config}',
             f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{config.upper()}={ext_path.parent}',
@@ -83,13 +74,53 @@ def build_extension(self, ext):
             build_args.extend(['--target', ext.target, '--'])

+        cwd = os.getcwd()
         try:
             os.chdir(build_temp)
             self.spawn([cmake, ext.source_dir, *cmake_args])
             if not self.dry_run:
                 self.spawn([cmake, '--build', '.', *build_args])
         finally:
-            os.chdir(HERE)
+            os.chdir(cwd)
+
+
+@contextlib.contextmanager
+def vcs_version(name, path):
+    path = pathlib.Path(path).absolute()
+    assert path.is_file()
+    module_spec = spec_from_file_location(name=name, location=path)
+    assert module_spec is not None
+    assert module_spec.loader is not None
+    module = sys.modules.get(name)
+    if module is None:
+        module = module_from_spec(module_spec)
+        sys.modules[name] = module
+    module_spec.loader.exec_module(module)
+
+    if module.__release__:
+        yield module
+        return
+
+    content = None
+    try:
+        try:
+            content = path.read_text(encoding='utf-8')
+            path.write_text(
+                data=re.sub(
+                    r"""__version__\s*=\s*('[^']+'|"[^"]+")""",
+                    f'__version__ = {module.__version__!r}',
+                    string=content,
+                ),
+                encoding='utf-8',
+            )
+        except OSError:
+            content = None
+
+        yield module
+    finally:
+        if content is not None:
+            with path.open(mode='wt', encoding='utf-8', newline='') as file:
+                file.write(content)


 CIBUILDWHEEL = os.getenv('CIBUILDWHEEL', '0') == '1'
@@ -112,29 +143,9 @@
     ext_kwargs.clear()


-VERSION_CONTENT = None
-
-try:
-    if not version.__release__:
-        try:
-            VERSION_CONTENT = VERSION_FILE.read_text(encoding='utf-8')
-            VERSION_FILE.write_text(
-                data=re.sub(
-                    r"""__version__\s*=\s*('[^']+'|"[^"]+")""",
-                    f'__version__ = {version.__version__!r}',
-                    string=VERSION_CONTENT,
-                ),
-                encoding='utf-8',
-            )
-        except OSError:
-            VERSION_CONTENT = None
-
+with vcs_version(name='torchopt.version', path=(HERE / 'torchopt' / 'version.py')) as version:
     setup(
         name='torchopt',
         version=version.__version__,
         **ext_kwargs,
     )
-finally:
-    if VERSION_CONTENT is not None:
-        with VERSION_FILE.open(mode='wt', encoding='utf-8', newline='') as file:
-            file.write(VERSION_CONTENT)
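The `vcs_version` context manager above replaces the old `sys.path.insert` plus bare `import version` with `importlib` machinery. Condensed from the hunk itself, the loading pattern it relies on looks like this (helper name `load_module` is mine, for illustration):

```python
# Load a module from an explicit file path without mutating sys.path.
import sys
from importlib.util import module_from_spec, spec_from_file_location


def load_module(name: str, path: str):
    spec = spec_from_file_location(name=name, location=path)
    assert spec is not None and spec.loader is not None
    module = sys.modules.get(name)
    if module is None:
        module = module_from_spec(spec)
        sys.modules[name] = module  # register before exec, as setup.py does
    spec.loader.exec_module(module)
    return module


# Usage mirroring setup.py:
#   version = load_module('torchopt.version', 'torchopt/version.py')
#   print(version.__version__)
```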
**ext_kwargs, ) -finally: - if VERSION_CONTENT is not None: - with VERSION_FILE.open(mode='wt', encoding='utf-8', newline='') as file: - file.write(VERSION_CONTENT) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2f4ae731..30d18335 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/adam_op/adam_op.cpp b/src/adam_op/adam_op.cpp index 08c9fb74..47f5d7f1 100644 --- a/src/adam_op/adam_op.cpp +++ b/src/adam_op/adam_op.cpp @@ -1,4 +1,4 @@ -// Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2024 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/src/adam_op/adam_op_impl_cpu.cpp b/src/adam_op/adam_op_impl_cpu.cpp index 1135206d..9c460685 100644 --- a/src/adam_op/adam_op_impl_cpu.cpp +++ b/src/adam_op/adam_op_impl_cpu.cpp @@ -1,4 +1,4 @@ -// Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2024 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/src/adam_op/adam_op_impl_cuda.cu b/src/adam_op/adam_op_impl_cuda.cu index ea1526a6..a12eca4f 100644 --- a/src/adam_op/adam_op_impl_cuda.cu +++ b/src/adam_op/adam_op_impl_cuda.cu @@ -1,4 +1,4 @@ -// Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2024 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/src/extension.cpp b/src/extension.cpp index 45880bf6..3d3d52b3 100644 --- a/src/extension.cpp +++ b/src/extension.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2024 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/tests/conftest.py b/tests/conftest.py index eaa734b2..bb2b1cf2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/helpers.py b/tests/helpers.py index 50451496..ca5aa443 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
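The hunks below (tests/helpers.py here, and most torchopt modules later in this patch) all apply the same fix for the newly selected TCH (flake8-type-checking) rules from the pyproject.toml hunk above: imports needed only for annotations move under an `if TYPE_CHECKING:` guard. A minimal sketch of the pattern, reusing the `TensorTree` name from these hunks:

from __future__ import annotations  # annotations become lazy strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by static type checkers only; skipped at runtime, so the
    # import adds no startup cost and cannot create an import cycle.
    from torchopt.typing import TensorTree


def leaves(tree: TensorTree) -> list:  # annotation is never evaluated
    ...

This only works together with `from __future__ import annotations` (or string annotations), which every module touched this way already imports.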
@@ -20,7 +20,7 @@ import itertools import os import random -from typing import Iterable +from typing import TYPE_CHECKING, Iterable import numpy as np import pytest @@ -30,7 +30,10 @@ from torch.utils import data from torchopt import pytree -from torchopt.typing import TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree BATCH_SIZE = 64 @@ -65,7 +68,7 @@ def parametrize(**argvalues) -> pytest.mark.parametrize: argvalues = list(itertools.product(*tuple(map(argvalues.get, arguments)))) ids = tuple( - '-'.join(f'{arg}({val})' for arg, val in zip(arguments, values)) for values in argvalues + '-'.join(f'{arg}({val!r})' for arg, val in zip(arguments, values)) for values in argvalues ) return pytest.mark.parametrize(arguments, argvalues, ids=ids) diff --git a/tests/requirements.txt b/tests/requirements.txt index 87c994e1..ee54732b 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,19 +1,20 @@ ---extra-index-url https://download.pytorch.org/whl/cu117 -torch >= 1.13 +--extra-index-url https://download.pytorch.org/whl/cu121 +torch >= 2.0 --requirement ../requirements.txt -jax[cpu] >= 0.3; platform_system != 'Windows' +jax[cpu] >= 0.4; platform_system != 'Windows' jaxopt; platform_system != 'Windows' -optax; platform_system != 'Windows' +optax < 0.1.8a0; platform_system != 'Windows' and python_version < '3.9' +optax >= 0.1.8; platform_system != 'Windows' and python_version >= '3.9' pytest pytest-cov pytest-xdist -isort >= 5.11.0 -black[jupyter] >= 22.6.0 -pylint[spelling] >= 2.15.0 -mypy >= 0.990 +isort +black[jupyter] +pylint[spelling] +mypy flake8 flake8-bugbear flake8-comprehensions diff --git a/tests/test_accelerated_op.py b/tests/test_accelerated_op.py index 6cb45ca0..668c9b9a 100644 --- a/tests/test_accelerated_op.py +++ b/tests/test_accelerated_op.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_alias.py b/tests/test_alias.py index aef35b96..3c42d7c8 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ from __future__ import annotations -from typing import Callable +from typing import TYPE_CHECKING, Callable import functorch import pytest @@ -26,7 +26,10 @@ import torchopt from torchopt import pytree from torchopt.alias.utils import _set_use_chain_flat -from torchopt.typing import TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree @helpers.parametrize( diff --git a/tests/test_clip.py b/tests/test_clip.py index 0b191cfe..2614781e 100644 --- a/tests/test_clip.py +++ b/tests/test_clip.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_combine.py b/tests/test_combine.py index 39b3e37f..1a026b9e 100644 --- a/tests/test_combine.py +++ b/tests/test_combine.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. 
All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_hook.py b/tests/test_hook.py index 1f3024c7..e89bb178 100644 --- a/tests/test_hook.py +++ b/tests/test_hook.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_implicit.py b/tests/test_implicit.py index 61623a17..6cccb716 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,7 +18,7 @@ import copy import re from collections import OrderedDict -from types import FunctionType +from typing import TYPE_CHECKING import functorch import numpy as np @@ -47,6 +47,10 @@ HAS_JAX = False +if TYPE_CHECKING: + from types import FunctionType + + BATCH_SIZE = 8 NUM_UPDATES = 3 @@ -123,7 +127,7 @@ def get_rr_dataset_torch() -> data.DataLoader: inner_lr=[2e-2, 2e-3], inner_update=[20, 50, 100], ) -def test_imaml_solve_normal_cg( +def test_imaml_solve_normal_cg( # noqa: C901 dtype: torch.dtype, lr: float, inner_lr: float, @@ -251,7 +255,7 @@ def outer_level(p, xs, ys): inner_update=[20, 50, 100], ns=[False, True], ) -def test_imaml_solve_inv( +def test_imaml_solve_inv( # noqa: C901 dtype: torch.dtype, lr: float, inner_lr: float, @@ -375,7 +379,12 @@ def outer_level(p, xs, ys): inner_lr=[2e-2, 2e-3], inner_update=[20, 50, 100], ) -def test_imaml_module(dtype: torch.dtype, lr: float, inner_lr: float, inner_update: int) -> None: +def test_imaml_module( # noqa: C901 + dtype: torch.dtype, + lr: float, + inner_lr: float, + inner_update: int, +) -> None: np_dtype = helpers.dtype_torch2numpy(dtype) jax_model, jax_params = get_model_jax(dtype=np_dtype) @@ -763,7 +772,7 @@ def solve(self): make_optimality_from_objective(MyModule2) -def test_module_abstract_methods() -> None: +def test_module_abstract_methods() -> None: # noqa: C901 class MyModule1(torchopt.nn.ImplicitMetaGradientModule): def objective(self): return torch.tensor(0.0) @@ -809,7 +818,7 @@ def solve(self): class MyModule5(torchopt.nn.ImplicitMetaGradientModule): @classmethod - def optimality(self): + def optimality(cls): return () def solve(self): @@ -846,7 +855,7 @@ def solve(self): class MyModule8(torchopt.nn.ImplicitMetaGradientModule): @classmethod - def objective(self): + def objective(cls): return () def solve(self): diff --git a/tests/test_import.py b/tests/test_import.py index f7523756..04d0ebbb 100644 --- a/tests/test_import.py +++ b/tests/test_import.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 7758b7db..c5b07618 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_meta_optim.py b/tests/test_meta_optim.py index 61f8a7ad..55712bdf 100644 --- a/tests/test_meta_optim.py +++ b/tests/test_meta_optim.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_nn.py b/tests/test_nn.py index 8e89bdb5..f77c20ec 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_optim.py b/tests/test_optim.py index dc3941d9..1257054f 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_pytree.py b/tests/test_pytree.py index d82d81f2..6ee2939b 100644 --- a/tests/test_pytree.py +++ b/tests/test_pytree.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_schedule.py b/tests/test_schedule.py index 1fdc4669..e4c0ac0a 100644 --- a/tests/test_schedule.py +++ b/tests/test_schedule.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_transform.py b/tests/test_transform.py index 9598386d..0a7bd498 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_utils.py b/tests/test_utils.py index d1be7c6f..57c35e47 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +13,8 @@ # limitations under the License. 
# ============================================================================== +import operator + import torch import torchopt @@ -80,7 +82,7 @@ def test_module_clone() -> None: assert y.is_cuda -def test_extract_state_dict(): +def test_extract_state_dict(): # noqa: C901 fc = torch.nn.Linear(1, 1) state_dict = torchopt.extract_state_dict(fc, by='reference', device=torch.device('meta')) for param_dict in state_dict.params: @@ -121,7 +123,7 @@ def test_extract_state_dict(): loss = fc(torch.ones(1, 1)).sum() optim.step(loss) state_dict = torchopt.extract_state_dict(optim) - same = pytree.tree_map(lambda x, y: x is y, state_dict, tuple(optim.state_groups)) + same = pytree.tree_map(operator.is_, state_dict, tuple(optim.state_groups)) assert all(pytree.tree_flatten(same)[0]) diff --git a/tests/test_zero_order.py b/tests/test_zero_order.py index 61f75f9a..65642559 100644 --- a/tests/test_zero_order.py +++ b/tests/test_zero_order.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/_C/adam_op.pyi b/torchopt/_C/adam_op.pyi index 04f141fd..5ef572aa 100644 --- a/torchopt/_C/adam_op.pyi +++ b/torchopt/_C/adam_op.pyi @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/__init__.py b/torchopt/__init__.py index a089f3dc..830072e3 100644 --- a/torchopt/__init__.py +++ b/torchopt/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -81,50 +81,50 @@ __all__ = [ - 'accelerated_op_available', - 'adam', - 'adamax', - 'adadelta', - 'radam', - 'adamw', - 'adagrad', - 'rmsprop', - 'sgd', - 'clip_grad_norm', - 'nan_to_num', - 'register_hook', - 'chain', - 'Optimizer', 'SGD', - 'Adam', - 'AdaMax', - 'Adamax', 'AdaDelta', - 'Adadelta', - 'RAdam', - 'AdamW', 'AdaGrad', + 'AdaMax', + 'Adadelta', 'Adagrad', - 'RMSProp', - 'RMSprop', - 'MetaOptimizer', - 'MetaSGD', - 'MetaAdam', - 'MetaAdaMax', - 'MetaAdamax', + 'Adam', + 'AdamW', + 'Adamax', + 'FuncOptimizer', 'MetaAdaDelta', - 'MetaAdadelta', - 'MetaRAdam', - 'MetaAdamW', 'MetaAdaGrad', + 'MetaAdaMax', + 'MetaAdadelta', 'MetaAdagrad', + 'MetaAdam', + 'MetaAdamW', + 'MetaAdamax', + 'MetaOptimizer', + 'MetaRAdam', 'MetaRMSProp', 'MetaRMSprop', - 'FuncOptimizer', + 'MetaSGD', + 'Optimizer', + 'RAdam', + 'RMSProp', + 'RMSprop', + 'accelerated_op_available', + 'adadelta', + 'adagrad', + 'adam', + 'adamax', + 'adamw', 'apply_updates', + 'chain', + 'clip_grad_norm', 'extract_state_dict', - 'recover_state_dict', - 'stop_gradient', 'module_clone', 'module_detach_', + 'nan_to_num', + 'radam', + 'recover_state_dict', + 'register_hook', + 'rmsprop', + 'sgd', + 'stop_gradient', ] diff --git a/torchopt/accelerated_op/__init__.py b/torchopt/accelerated_op/__init__.py index 3ac943e3..90452046 100644 --- a/torchopt/accelerated_op/__init__.py +++ b/torchopt/accelerated_op/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. 
All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,12 +16,15 @@ from __future__ import annotations -from typing import Iterable +from typing import TYPE_CHECKING, Iterable import torch from torchopt.accelerated_op.adam_op import AdamOp -from torchopt.typing import Device + + +if TYPE_CHECKING: + from torchopt.typing import Device def is_available(devices: Device | Iterable[Device] | None = None) -> bool: @@ -42,6 +45,6 @@ def is_available(devices: Device | Iterable[Device] | None = None) -> bool: return False updates = torch.tensor(1.0, device=device) op(updates, updates, updates, 1) - return True except Exception: # noqa: BLE001 # pylint: disable=broad-except return False + return True diff --git a/torchopt/accelerated_op/_src/__init__.py b/torchopt/accelerated_op/_src/__init__.py index bbf0b4cd..8c2f7b03 100644 --- a/torchopt/accelerated_op/_src/__init__.py +++ b/torchopt/accelerated_op/_src/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/accelerated_op/_src/adam_op.py b/torchopt/accelerated_op/_src/adam_op.py index c8fc8898..d7f9796d 100644 --- a/torchopt/accelerated_op/_src/adam_op.py +++ b/torchopt/accelerated_op/_src/adam_op.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,7 +18,11 @@ from __future__ import annotations -import torch +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + import torch def forward_( diff --git a/torchopt/accelerated_op/adam_op.py b/torchopt/accelerated_op/adam_op.py index d6f9e9f9..43ac26cd 100644 --- a/torchopt/accelerated_op/adam_op.py +++ b/torchopt/accelerated_op/adam_op.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/alias/__init__.py b/torchopt/alias/__init__.py index 3ea721c4..5767c5d7 100644 --- a/torchopt/alias/__init__.py +++ b/torchopt/alias/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -41,4 +41,13 @@ from torchopt.alias.sgd import sgd -__all__ = ['adagrad', 'radam', 'adam', 'adamax', 'adadelta', 'adamw', 'rmsprop', 'sgd'] +__all__ = [ + 'adadelta', + 'adagrad', + 'adam', + 'adamax', + 'adamw', + 'radam', + 'rmsprop', + 'sgd', +] diff --git a/torchopt/alias/adadelta.py b/torchopt/alias/adadelta.py index 2e3640f2..910cb13e 100644 --- a/torchopt/alias/adadelta.py +++ b/torchopt/alias/adadelta.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
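The `is_available` change above is a TRY300-style fix from the newly selected TRY (tryceratops) rules: the success `return` moves out of the `try` body, so the broad `except` only guards the code that can actually fail. A self-contained sketch of the resulting shape, with `probe_op` as a hypothetical stand-in for the accelerated-op call:

def probe_op() -> None:  # hypothetical stand-in; may raise on unsupported devices
    ...


def is_available() -> bool:
    try:
        probe_op()
    except Exception:  # noqa: BLE001  # pylint: disable=broad-except
        return False
    return True  # reached only if the probe ran without raising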
@@ -16,6 +16,8 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from torchopt.alias.utils import ( _get_use_chain_flat, flip_sign_and_add_weight_decay, @@ -23,7 +25,10 @@ ) from torchopt.combine import chain from torchopt.transform import scale_by_adadelta -from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +if TYPE_CHECKING: + from torchopt.typing import GradientTransformation, ScalarOrSchedule __all__ = ['adadelta'] diff --git a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py index 3f983c38..6fdb4aa3 100644 --- a/torchopt/alias/adagrad.py +++ b/torchopt/alias/adagrad.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/alias/adam.py b/torchopt/alias/adam.py index dc889285..0ae0eb8e 100644 --- a/torchopt/alias/adam.py +++ b/torchopt/alias/adam.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,6 +33,8 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from torchopt.alias.utils import ( _get_use_chain_flat, flip_sign_and_add_weight_decay, @@ -40,7 +42,10 @@ ) from torchopt.combine import chain from torchopt.transform import scale_by_accelerated_adam, scale_by_adam -from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +if TYPE_CHECKING: + from torchopt.typing import GradientTransformation, ScalarOrSchedule __all__ = ['adam'] diff --git a/torchopt/alias/adamax.py b/torchopt/alias/adamax.py index ffa19e37..3da16713 100644 --- a/torchopt/alias/adamax.py +++ b/torchopt/alias/adamax.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from torchopt.alias.utils import ( _get_use_chain_flat, flip_sign_and_add_weight_decay, @@ -23,7 +25,10 @@ ) from torchopt.combine import chain from torchopt.transform import scale_by_adamax -from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +if TYPE_CHECKING: + from torchopt.typing import GradientTransformation, ScalarOrSchedule __all__ = ['adamax'] diff --git a/torchopt/alias/adamw.py b/torchopt/alias/adamw.py index e8bed2ab..2dc72ef1 100644 --- a/torchopt/alias/adamw.py +++ b/torchopt/alias/adamw.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -33,7 +33,7 @@ from __future__ import annotations -from typing import Callable +from typing import TYPE_CHECKING, Callable from torchopt.alias.utils import ( _get_use_chain_flat, @@ -42,7 +42,10 @@ ) from torchopt.combine import chain from torchopt.transform import add_decayed_weights, scale_by_accelerated_adam, scale_by_adam -from torchopt.typing import GradientTransformation, OptState, Params, ScalarOrSchedule + + +if TYPE_CHECKING: + from torchopt.typing import GradientTransformation, OptState, Params, ScalarOrSchedule __all__ = ['adamw'] diff --git a/torchopt/alias/radam.py b/torchopt/alias/radam.py index 230c1151..9e2880ee 100644 --- a/torchopt/alias/radam.py +++ b/torchopt/alias/radam.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from torchopt.alias.utils import ( _get_use_chain_flat, flip_sign_and_add_weight_decay, @@ -23,7 +25,10 @@ ) from torchopt.combine import chain from torchopt.transform import scale_by_radam -from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +if TYPE_CHECKING: + from torchopt.typing import GradientTransformation, ScalarOrSchedule __all__ = ['radam'] diff --git a/torchopt/alias/rmsprop.py b/torchopt/alias/rmsprop.py index 96092548..612e4f45 100644 --- a/torchopt/alias/rmsprop.py +++ b/torchopt/alias/rmsprop.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/alias/sgd.py b/torchopt/alias/sgd.py index 6fb3c6db..6d5935bc 100644 --- a/torchopt/alias/sgd.py +++ b/torchopt/alias/sgd.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -44,6 +44,7 @@ __all__ = ['sgd'] +# pylint: disable-next=too-many-arguments def sgd( lr: ScalarOrSchedule, momentum: float = 0.0, diff --git a/torchopt/alias/utils.py b/torchopt/alias/utils.py index 5c8dc97a..0f41e822 100644 --- a/torchopt/alias/utils.py +++ b/torchopt/alias/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
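For orientation, the alias modules above (adadelta, adagrad, adam, adamax, adamw, radam, rmsprop, sgd) are the functional layer: each composes elementary transformations via `chain` and returns a `GradientTransformation` with an init/update pair. A hedged usage sketch (the hyper-parameter values and the `inplace=False` call are illustrative, not taken from this patch):

import torch
import torchopt

optimizer = torchopt.adadelta(lr=1e-3)   # functional alias built with chain(...)
params = {'w': torch.zeros(3)}           # any pytree of tensors
state = optimizer.init(params)
grads = {'w': torch.full((3,), 0.1)}
updates, state = optimizer.update(grads, state, params=params)
new_params = torchopt.apply_updates(params, updates, inplace=False)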
@@ -16,14 +16,18 @@ from __future__ import annotations import threading - -import torch +from typing import TYPE_CHECKING from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation, identity from torchopt.transform import scale, scale_by_schedule from torchopt.transform.utils import tree_map_flat, tree_map_flat_ -from torchopt.typing import Numeric, OptState, Params, ScalarOrSchedule, Updates + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import Numeric, OptState, Params, ScalarOrSchedule, Updates __all__ = ['flip_sign_and_add_weight_decay', 'scale_by_neg_lr'] @@ -68,7 +72,7 @@ def _flip_sign_and_add_weight_decay_flat( ) -def _flip_sign_and_add_weight_decay( +def _flip_sign_and_add_weight_decay( # noqa: C901 weight_decay: float = 0.0, maximize: bool = False, *, @@ -108,19 +112,21 @@ def update_fn( if inplace: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if g.requires_grad: return g.add_(p, alpha=weight_decay) return g.add_(p.data, alpha=weight_decay) - updates = tree_map_(f, updates, params) + tree_map_(f, params, updates) else: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: - return g.add(p, alpha=weight_decay) + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.add(p, alpha=weight_decay) if g is not None else g - updates = tree_map(f, updates, params) + updates = tree_map(f, params, updates) return updates, state @@ -139,7 +145,7 @@ def update_fn( def f(g: torch.Tensor) -> torch.Tensor: return g.neg_() - updates = tree_map_(f, updates) + tree_map_(f, updates) else: @@ -166,19 +172,21 @@ def update_fn( if inplace: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if g.requires_grad: return g.neg_().add_(p, alpha=weight_decay) return g.neg_().add_(p.data, alpha=weight_decay) - updates = tree_map_(f, updates, params) + tree_map_(f, params, updates) else: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: - return g.neg().add_(p, alpha=weight_decay) + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.neg().add_(p, alpha=weight_decay) if g is not None else g - updates = tree_map(f, updates, params) + updates = tree_map(f, params, updates) return updates, state diff --git a/torchopt/base.py b/torchopt/base.py index cab2b49f..81892e17 100644 --- a/torchopt/base.py +++ b/torchopt/base.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
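The alias/utils.py hunks above change the update callbacks in two ways: the callback signature becomes `(param, grad)` to match the new `tree_map(f, params, updates)` argument order, and `None` gradients are passed through instead of crashing. The in-place branch also now calls `tree_map_(f, params, updates)` purely for its side effect, without rebinding `updates`. A runnable sketch of the out-of-place weight-decay callback:

from __future__ import annotations

import torch

weight_decay = 1e-4  # illustrative value


def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
    # None marks a parameter that received no gradient; leave it untouched.
    return g.add(p, alpha=weight_decay) if g is not None else g


p = torch.ones(3)
print(f(p, torch.zeros(3)))  # a tensor of 1e-4 values
print(f(p, None))            # None passes straight through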
@@ -44,10 +44,10 @@ __all__ = [ + 'ChainedGradientTransformation', 'EmptyState', - 'UninitializedState', 'GradientTransformation', - 'ChainedGradientTransformation', + 'UninitializedState', 'identity', ] @@ -164,9 +164,11 @@ def __new__(cls, *transformations: GradientTransformation) -> Self: """Create a new chained gradient transformation.""" transformations = tuple( itertools.chain.from_iterable( - t.transformations - if isinstance(t, ChainedGradientTransformation) - else ((t,) if not isinstance(t, IdentityGradientTransformation) else ()) + ( + t.transformations + if isinstance(t, ChainedGradientTransformation) + else ((t,) if not isinstance(t, IdentityGradientTransformation) else ()) + ) for t in transformations ), ) diff --git a/torchopt/clip.py b/torchopt/clip.py index eda4bef3..d64afc58 100644 --- a/torchopt/clip.py +++ b/torchopt/clip.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,11 +19,16 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import torch from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['clip_grad_norm'] diff --git a/torchopt/combine.py b/torchopt/combine.py index fc1a7152..15345286 100644 --- a/torchopt/combine.py +++ b/torchopt/combine.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,9 +33,14 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from torchopt import pytree from torchopt.base import ChainedGradientTransformation, GradientTransformation, identity -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['chain', 'chain_flat'] diff --git a/torchopt/diff/__init__.py b/torchopt/diff/__init__.py index 984841ed..194512f5 100644 --- a/torchopt/diff/__init__.py +++ b/torchopt/diff/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/diff/implicit/__init__.py b/torchopt/diff/implicit/__init__.py index 4e50b615..4cff14c6 100644 --- a/torchopt/diff/implicit/__init__.py +++ b/torchopt/diff/implicit/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
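The reflowed `ChainedGradientTransformation.__new__` above is a formatting-only change, but it is worth noting what the preserved logic does: nested chains are spliced in and identity transformations are dropped. A hedged sketch of that behavior (direct construction shown for illustration; user code would normally go through `torchopt.chain`):

import torchopt
from torchopt.base import ChainedGradientTransformation, identity

inner = ChainedGradientTransformation(torchopt.sgd(lr=0.1))
outer = ChainedGradientTransformation(identity(), inner)
# identity() links are filtered out and inner.transformations are inlined,
# so outer holds a flat tuple of transformations rather than nested chains.
print(outer.transformations)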
@@ -19,4 +19,4 @@ from torchopt.diff.implicit.nn import ImplicitMetaGradientModule -__all__ = ['custom_root', 'ImplicitMetaGradientModule'] +__all__ = ['ImplicitMetaGradientModule', 'custom_root'] diff --git a/torchopt/diff/implicit/decorator.py b/torchopt/diff/implicit/decorator.py index 03720a49..11ba0153 100644 --- a/torchopt/diff/implicit/decorator.py +++ b/torchopt/diff/implicit/decorator.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -37,20 +37,23 @@ import functools import inspect -from typing import Any, Callable, Dict, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, Sequence, Tuple import functorch import torch from torch.autograd import Function from torchopt import linear_solve, pytree -from torchopt.typing import ( - ListOfOptionalTensors, - ListOfTensors, - TensorOrTensors, - TupleOfOptionalTensors, - TupleOfTensors, -) + + +if TYPE_CHECKING: + from torchopt.typing import ( + ListOfOptionalTensors, + ListOfTensors, + TensorOrTensors, + TupleOfOptionalTensors, + TupleOfTensors, + ) __all__ = ['custom_root'] @@ -253,7 +256,7 @@ def _merge_tensor_and_others( # pylint: disable-next=too-many-arguments,too-many-statements -def _custom_root( +def _custom_root( # noqa: C901 solver_fn: Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]], optimality_fn: Callable[..., TensorOrTensors], solve: Callable[..., TensorOrTensors], @@ -271,7 +274,7 @@ def _custom_root( fn = getattr(reference_signature, 'subfn', reference_signature) reference_signature = inspect.signature(fn) - def make_custom_vjp_solver_fn( + def make_custom_vjp_solver_fn( # noqa: C901 solver_fn: Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]], kwarg_keys: Sequence[str], args_signs: tuple[tuple[int, int, type[tuple | list] | None], ...], @@ -279,7 +282,7 @@ def make_custom_vjp_solver_fn( # pylint: disable-next=missing-class-docstring,abstract-method class ImplicitMetaGradient(Function): @staticmethod - def forward( # type: ignore[override] # pylint: disable=arguments-differ + def forward( # pylint: disable=arguments-differ ctx: Any, *flat_args: Any, ) -> tuple[Any, ...]: diff --git a/torchopt/diff/implicit/nn/__init__.py b/torchopt/diff/implicit/nn/__init__.py index 5bc7aa8d..e91ef8ed 100644 --- a/torchopt/diff/implicit/nn/__init__.py +++ b/torchopt/diff/implicit/nn/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py index a72e5304..6b214cb8 100644 --- a/torchopt/diff/implicit/nn/module.py +++ b/torchopt/diff/implicit/nn/module.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -22,15 +22,19 @@ import functools import inspect import itertools -from typing import Any, Iterable +from typing import TYPE_CHECKING, Any, Iterable import functorch -import torch from torchopt.diff.implicit.decorator import custom_root from torchopt.nn.module import MetaGradientModule from torchopt.nn.stateless import reparametrize, swap_state -from torchopt.typing import LinearSolver, TupleOfTensors + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import LinearSolver, TupleOfTensors __all__ = ['ImplicitMetaGradientModule'] diff --git a/torchopt/diff/zero_order/__init__.py b/torchopt/diff/zero_order/__init__.py index b621ffdc..4369f4e5 100644 --- a/torchopt/diff/zero_order/__init__.py +++ b/torchopt/diff/zero_order/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,7 +25,7 @@ from torchopt.diff.zero_order.nn import ZeroOrderGradientModule -__all__ = ['zero_order', 'ZeroOrderGradientModule'] +__all__ = ['ZeroOrderGradientModule', 'zero_order'] class _CallableModule(_ModuleType): # pylint: disable=too-few-public-methods diff --git a/torchopt/diff/zero_order/decorator.py b/torchopt/diff/zero_order/decorator.py index f63f0574..e498b43c 100644 --- a/torchopt/diff/zero_order/decorator.py +++ b/torchopt/diff/zero_order/decorator.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ from __future__ import annotations import functools +import itertools from typing import Any, Callable, Literal, Sequence from typing_extensions import TypeAlias # Python 3.10+ @@ -43,7 +44,7 @@ def sample( return self.sample_fn(sample_shape) -def _zero_order_naive( # pylint: disable=too-many-statements +def _zero_order_naive( # noqa: C901 # pylint: disable=too-many-statements fn: Callable[..., torch.Tensor], distribution: Samplable, argnums: tuple[int, ...], @@ -51,7 +52,7 @@ def _zero_order_naive( # pylint: disable=too-many-statements sigma: float, ) -> Callable[..., torch.Tensor]: @functools.wraps(fn) - def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements + def apply(*args: Any) -> torch.Tensor: # noqa: C901 # pylint: disable=too-many-statements diff_params = [args[argnum] for argnum in argnums] flat_diff_params: list[Any] flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] @@ -81,7 +82,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: output = fn(*origin_args) if not isinstance(output, torch.Tensor): - raise RuntimeError('`output` must be a tensor.') + raise TypeError('`output` must be a tensor.') if output.ndim != 0: raise RuntimeError('`output` must be a scalar tensor.') ctx.save_for_backward(*flat_diff_params, *tensors) @@ -122,9 +123,9 @@ def add_perturbation( for _ in range(num_samples): noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params] - flat_noisy_params = [ - add_perturbation(t, n) for t, n in zip(flat_diff_params, noises) # type: ignore[arg-type] - ] + flat_noisy_params = list( + itertools.starmap(add_perturbation, zip(flat_diff_params, noises)), + ) noisy_params: list[Any] = pytree.tree_unflatten( # type: 
ignore[assignment] diff_params_treespec, flat_noisy_params, @@ -149,7 +150,7 @@ def add_perturbation( return apply -def _zero_order_forward( # pylint: disable=too-many-statements +def _zero_order_forward( # noqa: C901 # pylint: disable=too-many-statements fn: Callable[..., torch.Tensor], distribution: Samplable, argnums: tuple[int, ...], @@ -157,7 +158,7 @@ def _zero_order_forward( # pylint: disable=too-many-statements sigma: float, ) -> Callable[..., torch.Tensor]: @functools.wraps(fn) - def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements + def apply(*args: Any) -> torch.Tensor: # noqa: C901 # pylint: disable=too-many-statements diff_params = [args[argnum] for argnum in argnums] flat_diff_params: list[Any] flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] @@ -187,7 +188,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: output = fn(*origin_args) if not isinstance(output, torch.Tensor): - raise RuntimeError('`output` must be a tensor.') + raise TypeError('`output` must be a tensor.') if output.ndim != 0: raise RuntimeError('`output` must be a scalar tensor.') ctx.save_for_backward(*flat_diff_params, *tensors, output) @@ -226,9 +227,9 @@ def add_perturbation(tensor: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: for _ in range(num_samples): noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params] - flat_noisy_params = [ - add_perturbation(t, n) for t, n in zip(flat_diff_params, noises) # type: ignore[arg-type] - ] + flat_noisy_params = list( + itertools.starmap(add_perturbation, zip(flat_diff_params, noises)), + ) noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment] diff_params_treespec, flat_noisy_params, @@ -254,7 +255,7 @@ def add_perturbation(tensor: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: return apply -def _zero_order_antithetic( # pylint: disable=too-many-statements +def _zero_order_antithetic( # noqa: C901 # pylint: disable=too-many-statements fn: Callable[..., torch.Tensor], distribution: Samplable, argnums: tuple[int, ...], @@ -262,7 +263,7 @@ def _zero_order_antithetic( # pylint: disable=too-many-statements sigma: float, ) -> Callable[..., torch.Tensor]: @functools.wraps(fn) - def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements + def apply(*args: Any) -> torch.Tensor: # noqa: C901 # pylint: disable=too-many-statements diff_params = [args[argnum] for argnum in argnums] flat_diff_params: list[Any] flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] @@ -292,7 +293,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: output = fn(*origin_args) if not isinstance(output, torch.Tensor): - raise RuntimeError('`output` must be a tensor.') + raise TypeError('`output` must be a tensor.') if output.ndim != 0: raise RuntimeError('`output` must be a scalar tensor.') ctx.save_for_backward(*flat_diff_params, *tensors) diff --git a/torchopt/diff/zero_order/nn/__init__.py b/torchopt/diff/zero_order/nn/__init__.py index 1bf64efe..f2753b27 100644 --- a/torchopt/diff/zero_order/nn/__init__.py +++ b/torchopt/diff/zero_order/nn/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
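The `itertools.starmap` hunks in zero_order/decorator.py above are a FURB/PERF-style rewrite: a list comprehension over `zip(...)` that immediately unpacks each pair becomes `starmap`, which does the unpacking itself. An equivalent, runnable sketch with scalar stand-ins for the tensor leaves:

import itertools

sigma = 0.1  # illustrative perturbation scale


def add_perturbation(tensor: float, noise: float) -> float:
    return tensor + sigma * noise


flat_diff_params = [1.0, 2.0]
noises = [0.5, -0.5]

# Before: [add_perturbation(t, n) for t, n in zip(flat_diff_params, noises)]
flat_noisy_params = list(itertools.starmap(add_perturbation, zip(flat_diff_params, noises)))
print(flat_noisy_params)  # [1.05, 1.95]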
diff --git a/torchopt/diff/zero_order/nn/module.py b/torchopt/diff/zero_order/nn/module.py index 75da28f9..eeddabeb 100644 --- a/torchopt/diff/zero_order/nn/module.py +++ b/torchopt/diff/zero_order/nn/module.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,14 +20,17 @@ import abc import functools -from typing import Any, Sequence +from typing import TYPE_CHECKING, Any, Sequence import torch import torch.nn as nn from torchopt.diff.zero_order.decorator import Method, Samplable, zero_order from torchopt.nn.stateless import reparametrize -from torchopt.typing import Numeric, TupleOfTensors + + +if TYPE_CHECKING: + from torchopt.typing import Numeric, TupleOfTensors __all__ = ['ZeroOrderGradientModule'] diff --git a/torchopt/distributed/__init__.py b/torchopt/distributed/__init__.py index 534b2dea..31f1283b 100644 --- a/torchopt/distributed/__init__.py +++ b/torchopt/distributed/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/distributed/api.py b/torchopt/distributed/api.py index 3a6f0526..97be682f 100644 --- a/torchopt/distributed/api.py +++ b/torchopt/distributed/api.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -42,15 +42,15 @@ __all__ = [ 'TensorDimensionPartitioner', - 'dim_partitioner', 'batch_partitioner', + 'dim_partitioner', 'mean_reducer', - 'sum_reducer', - 'remote_async_call', - 'remote_sync_call', 'parallelize', 'parallelize_async', 'parallelize_sync', + 'remote_async_call', + 'remote_sync_call', + 'sum_reducer', ] @@ -107,7 +107,7 @@ def __init__( self.workers = workers # pylint: disable-next=too-many-branches,too-many-locals - def __call__( + def __call__( # noqa: C901 self, *args: Any, **kwargs: Any, @@ -271,6 +271,7 @@ def sum_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor: return torch.sum(torch.stack(tuple(results), dim=0), dim=0) +# pylint: disable-next=too-many-arguments def remote_async_call( func: Callable[..., T], *, @@ -309,7 +310,7 @@ def remote_async_call( elif callable(partitioner): partitions = partitioner(*args, **kwargs) # type: ignore[assignment] else: - raise ValueError(f'Invalid partitioner: {partitioner!r}.') + raise TypeError(f'Invalid partitioner: {partitioner!r}.') futures = [] for rank, worker_args, worker_kwargs in partitions: @@ -323,11 +324,12 @@ def remote_async_call( if reducer is not None: return cast( Future[U], - future.then(lambda fut: cast(Callable[[Iterable[T]], U], reducer)(fut.wait())), + future.then(lambda fut: reducer(fut.wait())), ) return future +# pylint: disable-next=too-many-arguments def remote_sync_call( func: Callable[..., T], *, diff --git a/torchopt/distributed/autograd.py b/torchopt/distributed/autograd.py index 4e10d24e..71afdb86 100644 --- a/torchopt/distributed/autograd.py +++ b/torchopt/distributed/autograd.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. 
All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,15 +17,18 @@ from __future__ import annotations from threading import Lock +from typing import TYPE_CHECKING import torch import torch.distributed.autograd as autograd from torch.distributed.autograd import context -from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors +if TYPE_CHECKING: + from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors -__all__ = ['is_available', 'context'] + +__all__ = ['context', 'is_available'] LOCK = Lock() @@ -121,7 +124,7 @@ def grad( for p in inputs: try: grads.append(all_local_grads[p]) - except KeyError as ex: + except KeyError as ex: # noqa: PERF203 if not allow_unused: raise RuntimeError( 'One of the differentiated Tensors appears to not have been used in the ' @@ -131,4 +134,4 @@ def grad( return tuple(grads) - __all__ += ['DistAutogradContext', 'get_gradients', 'backward', 'grad'] + __all__ += ['DistAutogradContext', 'backward', 'get_gradients', 'grad'] diff --git a/torchopt/distributed/world.py b/torchopt/distributed/world.py index a9821ee0..610e52a0 100644 --- a/torchopt/distributed/world.py +++ b/torchopt/distributed/world.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,19 +26,19 @@ __all__ = [ - 'get_world_info', - 'get_world_rank', - 'get_rank', - 'get_world_size', + 'auto_init_rpc', + 'barrier', 'get_local_rank', 'get_local_world_size', + 'get_rank', 'get_worker_id', - 'barrier', - 'auto_init_rpc', - 'on_rank', + 'get_world_info', + 'get_world_rank', + 'get_world_size', 'not_on_rank', - 'rank_zero_only', + 'on_rank', 'rank_non_zero_only', + 'rank_zero_only', ] diff --git a/torchopt/hook.py b/torchopt/hook.py index 13ed6abf..c11b92f6 100644 --- a/torchopt/hook.py +++ b/torchopt/hook.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,16 +16,19 @@ from __future__ import annotations -from typing import Callable - -import torch +from typing import TYPE_CHECKING, Callable from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation -from torchopt.typing import OptState, Params, Updates -__all__ = ['zero_nan_hook', 'nan_to_num_hook', 'register_hook'] +if TYPE_CHECKING: + import torch + + from torchopt.typing import OptState, Params, Updates + + +__all__ = ['nan_to_num_hook', 'register_hook', 'zero_nan_hook'] def zero_nan_hook(g: torch.Tensor) -> torch.Tensor: diff --git a/torchopt/linalg/__init__.py b/torchopt/linalg/__init__.py index 20dc16aa..fc499d67 100644 --- a/torchopt/linalg/__init__.py +++ b/torchopt/linalg/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/linalg/cg.py b/torchopt/linalg/cg.py index 9cd57cd8..1096a5af 100644 --- a/torchopt/linalg/cg.py +++ b/torchopt/linalg/cg.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. 
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,14 +36,17 @@ from __future__ import annotations from functools import partial -from typing import Callable +from typing import TYPE_CHECKING, Callable import torch from torchopt import pytree from torchopt.linalg.utils import cat_shapes, normalize_matvec from torchopt.pytree import tree_vdot_real -from torchopt.typing import TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree __all__ = ['cg'] @@ -53,7 +56,7 @@ def _identity(x: TensorTree) -> TensorTree: return x -# pylint: disable-next=too-many-locals +# pylint: disable-next=too-many-arguments,too-many-locals def _cg_solve( A: Callable[[TensorTree], TensorTree], b: TensorTree, @@ -102,6 +105,7 @@ def body_fn( return x_final +# pylint: disable-next=too-many-arguments def _isolve( _isolve_solve: Callable, A: TensorTree | Callable[[TensorTree], TensorTree], @@ -134,6 +138,7 @@ def _isolve( return isolve_solve(A, b) +# pylint: disable-next=too-many-arguments def cg( A: TensorTree | Callable[[TensorTree], TensorTree], b: TensorTree, diff --git a/torchopt/linalg/ns.py b/torchopt/linalg/ns.py index 747ad3cf..5fc8d478 100644 --- a/torchopt/linalg/ns.py +++ b/torchopt/linalg/ns.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,13 +19,16 @@ from __future__ import annotations import functools -from typing import Callable +from typing import TYPE_CHECKING, Callable import torch from torchopt import pytree from torchopt.linalg.utils import normalize_matvec -from torchopt.typing import TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree __all__ = ['ns', 'ns_inv'] @@ -123,12 +126,14 @@ def _ns_inv(A: torch.Tensor, maxiter: int, alpha: float | None = None) -> torch. # A^{-1} = a [I - (I - a A)]^{-1} = a [I + (I - a A) + (I - a A)^2 + (I - a A)^3 + ...] M = I - alpha * A for rank in range(maxiter): + # pylint: disable-next=not-callable inv_A_hat = inv_A_hat + torch.linalg.matrix_power(M, rank) inv_A_hat = alpha * inv_A_hat else: # A^{-1} = [I - (I - A)]^{-1} = I + (I - A) + (I - A)^2 + (I - A)^3 + ... M = I - A for rank in range(maxiter): + # pylint: disable-next=not-callable inv_A_hat = inv_A_hat + torch.linalg.matrix_power(M, rank) return inv_A_hat diff --git a/torchopt/linalg/utils.py b/torchopt/linalg/utils.py index e3cd197e..bbcc80aa 100644 --- a/torchopt/linalg/utils.py +++ b/torchopt/linalg/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
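As the annotation on `cg` above shows, `linalg.cg` accepts either a tensor pytree or a mat-vec callable for `A`. A hedged usage sketch solving a small symmetric positive-definite system (the `maxiter` keyword matches the solver signatures above; the example values are illustrative):

import torch

from torchopt import linalg

A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])  # symmetric positive-definite
b = torch.tensor([9.0, 8.0])

x = linalg.cg(lambda v: A @ v, b, maxiter=10)  # matvec form of A
print(x)  # approximately tensor([2., 3.])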
@@ -17,12 +17,15 @@ from __future__ import annotations import itertools -from typing import Callable +from typing import TYPE_CHECKING, Callable import torch from torchopt import pytree -from torchopt.typing import TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree def cat_shapes(tree: TensorTree) -> tuple[int, ...]: diff --git a/torchopt/linear_solve/__init__.py b/torchopt/linear_solve/__init__.py index 8d9115d3..43ca1da0 100644 --- a/torchopt/linear_solve/__init__.py +++ b/torchopt/linear_solve/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,4 +36,4 @@ from torchopt.linear_solve.normal_cg import solve_normal_cg -__all__ = ['solve_cg', 'solve_normal_cg', 'solve_inv'] +__all__ = ['solve_cg', 'solve_inv', 'solve_normal_cg'] diff --git a/torchopt/linear_solve/cg.py b/torchopt/linear_solve/cg.py index e8f9fb77..23814cc2 100644 --- a/torchopt/linear_solve/cg.py +++ b/torchopt/linear_solve/cg.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,11 +36,14 @@ from __future__ import annotations import functools -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable from torchopt import linalg from torchopt.linear_solve.utils import make_ridge_matvec -from torchopt.typing import LinearSolver, TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import LinearSolver, TensorTree __all__ = ['solve_cg'] diff --git a/torchopt/linear_solve/inv.py b/torchopt/linear_solve/inv.py index e2a377d5..4dbe1542 100644 --- a/torchopt/linear_solve/inv.py +++ b/torchopt/linear_solve/inv.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,13 +36,16 @@ from __future__ import annotations import functools -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable import torch from torchopt import linalg, pytree from torchopt.linear_solve.utils import make_ridge_matvec, materialize_matvec -from torchopt.typing import LinearSolver, TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import LinearSolver, TensorTree __all__ = ['solve_inv'] diff --git a/torchopt/linear_solve/normal_cg.py b/torchopt/linear_solve/normal_cg.py index 78813ecb..a5af49b2 100644 --- a/torchopt/linear_solve/normal_cg.py +++ b/torchopt/linear_solve/normal_cg.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -36,11 +36,14 @@ from __future__ import annotations import functools -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable from torchopt import linalg from torchopt.linear_solve.utils import make_normal_matvec, make_ridge_matvec, make_rmatvec -from torchopt.typing import LinearSolver, TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import LinearSolver, TensorTree __all__ = ['solve_normal_cg'] diff --git a/torchopt/linear_solve/utils.py b/torchopt/linear_solve/utils.py index 22dcec6f..9d1b8779 100644 --- a/torchopt/linear_solve/utils.py +++ b/torchopt/linear_solve/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,12 +33,15 @@ from __future__ import annotations -from typing import Callable +from typing import TYPE_CHECKING, Callable import functorch from torchopt import pytree -from torchopt.typing import TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree def make_rmatvec( diff --git a/torchopt/nn/__init__.py b/torchopt/nn/__init__.py index 8271ad7d..b55e49d7 100644 --- a/torchopt/nn/__init__.py +++ b/torchopt/nn/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,10 +21,10 @@ __all__ = [ - 'MetaGradientModule', 'ImplicitMetaGradientModule', + 'MetaGradientModule', 'ZeroOrderGradientModule', - 'reparametrize', 'reparameterize', + 'reparametrize', 'swap_state', ] diff --git a/torchopt/nn/module.py b/torchopt/nn/module.py index 64623146..8c40f58a 100644 --- a/torchopt/nn/module.py +++ b/torchopt/nn/module.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -17,14 +17,17 @@ from __future__ import annotations from collections import OrderedDict -from typing import Any, Iterator, NamedTuple +from typing import TYPE_CHECKING, Any, Iterator, NamedTuple from typing_extensions import Self # Python 3.11+ import torch import torch.nn as nn from torchopt import pytree -from torchopt.typing import TensorContainer + + +if TYPE_CHECKING: + from torchopt.typing import TensorContainer class MetaInputsContainer(NamedTuple): @@ -61,7 +64,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=unused """Initialize a new module instance.""" super().__init__() - def __getattr__(self, name: str) -> torch.Tensor | nn.Module: + def __getattr__(self, name: str) -> torch.Tensor | nn.Module: # noqa: C901 """Get an attribute of the module.""" if '_parameters' in self.__dict__: _parameters = self.__dict__['_parameters'] @@ -86,7 +89,7 @@ def __getattr__(self, name: str) -> torch.Tensor | nn.Module: raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") # pylint: disable-next=too-many-branches,too-many-statements - def __setattr__(self, name: str, value: torch.Tensor | nn.Module) -> None: + def __setattr__(self, name: str, value: torch.Tensor | nn.Module) -> None: # noqa: C901 """Set an attribute of the module.""" def remove_from(*dicts_or_sets: dict[str, Any] | set[str]) -> None: diff --git a/torchopt/nn/stateless.py b/torchopt/nn/stateless.py index e547b5cb..c7f92b86 100644 --- a/torchopt/nn/stateless.py +++ b/torchopt/nn/stateless.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,13 +17,15 @@ from __future__ import annotations import contextlib -from typing import Generator, Iterable +from typing import TYPE_CHECKING, Generator, Iterable -import torch -import torch.nn as nn +if TYPE_CHECKING: + import torch + import torch.nn as nn -__all__ = ['swap_state', 'reparametrize', 'reparameterize'] + +__all__ = ['reparameterize', 'reparametrize', 'swap_state'] MISSING: torch.Tensor = object() # type: ignore[assignment] @@ -66,8 +68,8 @@ def recursive_setattr(path: str, value: torch.Tensor) -> torch.Tensor: mod._parameters[attr] = value # type: ignore[assignment] elif hasattr(mod, '_buffers') and attr in mod._buffers: mod._buffers[attr] = value - elif hasattr(mod, '_meta_parameters') and attr in mod._meta_parameters: # type: ignore[operator] - mod._meta_parameters[attr] = value # type: ignore[operator,index] + elif hasattr(mod, '_meta_parameters') and attr in mod._meta_parameters: + mod._meta_parameters[attr] = value else: setattr(mod, attr, value) # pylint: enable=protected-access diff --git a/torchopt/optim/__init__.py b/torchopt/optim/__init__.py index 20da5fca..f620608c 100644 --- a/torchopt/optim/__init__.py +++ b/torchopt/optim/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/adadelta.py b/torchopt/optim/adadelta.py index 7c73cb58..600b69c5 100644 --- a/torchopt/optim/adadelta.py +++ b/torchopt/optim/adadelta.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,13 +16,16 @@ from __future__ import annotations -from typing import Iterable - -import torch +from typing import TYPE_CHECKING, Iterable from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import ScalarOrSchedule __all__ = ['AdaDelta', 'Adadelta'] diff --git a/torchopt/optim/adagrad.py b/torchopt/optim/adagrad.py index a7e8c72b..06091281 100644 --- a/torchopt/optim/adagrad.py +++ b/torchopt/optim/adagrad.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,13 +16,16 @@ from __future__ import annotations -from typing import Iterable - -import torch +from typing import TYPE_CHECKING, Iterable from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import ScalarOrSchedule __all__ = ['AdaGrad', 'Adagrad'] diff --git a/torchopt/optim/adam.py b/torchopt/optim/adam.py index 5d85cbdc..555af22e 100644 --- a/torchopt/optim/adam.py +++ b/torchopt/optim/adam.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,13 +16,16 @@ from __future__ import annotations -from typing import Iterable - -import torch +from typing import TYPE_CHECKING, Iterable from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import ScalarOrSchedule __all__ = ['Adam'] diff --git a/torchopt/optim/adamax.py b/torchopt/optim/adamax.py index 904c05a0..e4996e85 100644 --- a/torchopt/optim/adamax.py +++ b/torchopt/optim/adamax.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,13 +16,16 @@ from __future__ import annotations -from typing import Iterable - -import torch +from typing import TYPE_CHECKING, Iterable from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import ScalarOrSchedule __all__ = ['AdaMax', 'Adamax'] diff --git a/torchopt/optim/adamw.py b/torchopt/optim/adamw.py index be8c6727..a60061ea 100644 --- a/torchopt/optim/adamw.py +++ b/torchopt/optim/adamw.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -16,13 +16,16 @@ from __future__ import annotations -from typing import Callable, Iterable - -import torch +from typing import TYPE_CHECKING, Callable, Iterable from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import OptState, Params, ScalarOrSchedule + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import OptState, Params, ScalarOrSchedule __all__ = ['AdamW'] diff --git a/torchopt/optim/base.py b/torchopt/optim/base.py index d0be2fd1..bdaa0d67 100644 --- a/torchopt/optim/base.py +++ b/torchopt/optim/base.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/func/__init__.py b/torchopt/optim/func/__init__.py index f14fc6ae..f136f808 100644 --- a/torchopt/optim/func/__init__.py +++ b/torchopt/optim/func/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/func/base.py b/torchopt/optim/func/base.py index 7a7839a3..fa287f04 100644 --- a/torchopt/optim/func/base.py +++ b/torchopt/optim/func/base.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,13 +16,18 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import torch from torchopt.base import GradientTransformation, UninitializedState -from torchopt.typing import OptState, Params from torchopt.update import apply_updates +if TYPE_CHECKING: + from torchopt.typing import OptState, Params + + __all__ = ['FuncOptimizer'] diff --git a/torchopt/optim/meta/__init__.py b/torchopt/optim/meta/__init__.py index 516f2b5f..9e30dfef 100644 --- a/torchopt/optim/meta/__init__.py +++ b/torchopt/optim/meta/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/meta/adadelta.py b/torchopt/optim/meta/adadelta.py index 36d8d9ad..eb386ae3 100644 --- a/torchopt/optim/meta/adadelta.py +++ b/torchopt/optim/meta/adadelta.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -16,11 +16,16 @@ from __future__ import annotations -import torch.nn as nn +from typing import TYPE_CHECKING from torchopt import alias from torchopt.optim.meta.base import MetaOptimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import ScalarOrSchedule __all__ = ['MetaAdaDelta', 'MetaAdadelta'] diff --git a/torchopt/optim/meta/adagrad.py b/torchopt/optim/meta/adagrad.py index 4e8ef0eb..129c1338 100644 --- a/torchopt/optim/meta/adagrad.py +++ b/torchopt/optim/meta/adagrad.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,11 +16,16 @@ from __future__ import annotations -import torch.nn as nn +from typing import TYPE_CHECKING from torchopt import alias from torchopt.optim.meta.base import MetaOptimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import ScalarOrSchedule __all__ = ['MetaAdaGrad', 'MetaAdagrad'] diff --git a/torchopt/optim/meta/adam.py b/torchopt/optim/meta/adam.py index bd9804b9..7a78ea7f 100644 --- a/torchopt/optim/meta/adam.py +++ b/torchopt/optim/meta/adam.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,11 +16,16 @@ from __future__ import annotations -import torch.nn as nn +from typing import TYPE_CHECKING from torchopt import alias from torchopt.optim.meta.base import MetaOptimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import ScalarOrSchedule __all__ = ['MetaAdam'] diff --git a/torchopt/optim/meta/adamax.py b/torchopt/optim/meta/adamax.py index 01082af2..d6b40427 100644 --- a/torchopt/optim/meta/adamax.py +++ b/torchopt/optim/meta/adamax.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,11 +16,16 @@ from __future__ import annotations -import torch.nn as nn +from typing import TYPE_CHECKING from torchopt import alias from torchopt.optim.meta.base import MetaOptimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import ScalarOrSchedule __all__ = ['MetaAdaMax', 'MetaAdamax'] diff --git a/torchopt/optim/meta/adamw.py b/torchopt/optim/meta/adamw.py index 204a5428..62864582 100644 --- a/torchopt/optim/meta/adamw.py +++ b/torchopt/optim/meta/adamw.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -16,13 +16,16 @@ from __future__ import annotations -from typing import Callable - -import torch.nn as nn +from typing import TYPE_CHECKING, Callable from torchopt import alias from torchopt.optim.meta.base import MetaOptimizer -from torchopt.typing import OptState, Params, ScalarOrSchedule + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import OptState, Params, ScalarOrSchedule __all__ = ['MetaAdamW'] diff --git a/torchopt/optim/meta/base.py b/torchopt/optim/meta/base.py index 54327f3b..73ecdde7 100644 --- a/torchopt/optim/meta/base.py +++ b/torchopt/optim/meta/base.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/meta/radam.py b/torchopt/optim/meta/radam.py index baf4cdd2..bb07b5ba 100644 --- a/torchopt/optim/meta/radam.py +++ b/torchopt/optim/meta/radam.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,11 +16,16 @@ from __future__ import annotations -import torch.nn as nn +from typing import TYPE_CHECKING from torchopt import alias from torchopt.optim.meta.base import MetaOptimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import ScalarOrSchedule __all__ = ['MetaRAdam'] diff --git a/torchopt/optim/meta/rmsprop.py b/torchopt/optim/meta/rmsprop.py index 3aff20e1..a8b4abfa 100644 --- a/torchopt/optim/meta/rmsprop.py +++ b/torchopt/optim/meta/rmsprop.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/meta/sgd.py b/torchopt/optim/meta/sgd.py index 476ed9d6..81e04413 100644 --- a/torchopt/optim/meta/sgd.py +++ b/torchopt/optim/meta/sgd.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/radam.py b/torchopt/optim/radam.py index c2f6a211..20e9dd22 100644 --- a/torchopt/optim/radam.py +++ b/torchopt/optim/radam.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,13 +16,16 @@ from __future__ import annotations -from typing import Iterable - -import torch +from typing import TYPE_CHECKING, Iterable from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import ScalarOrSchedule __all__ = ['RAdam'] diff --git a/torchopt/optim/rmsprop.py b/torchopt/optim/rmsprop.py index 5c4e536f..032e5864 100644 --- a/torchopt/optim/rmsprop.py +++ b/torchopt/optim/rmsprop.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. 
All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/sgd.py b/torchopt/optim/sgd.py index 3da9595a..27cd53c1 100644 --- a/torchopt/optim/sgd.py +++ b/torchopt/optim/sgd.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/pytree.py b/torchopt/pytree.py index 6d41d0fa..53abc2d2 100644 --- a/torchopt/pytree.py +++ b/torchopt/pytree.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,7 +18,7 @@ import functools import operator -from typing import Callable +from typing import TYPE_CHECKING, Callable import optree import optree.typing as typing # pylint: disable=unused-import @@ -26,7 +26,9 @@ import torch.distributed.rpc as rpc from optree import * # pylint: disable=wildcard-import,unused-wildcard-import -from torchopt.typing import Future, RRef, Scalar, T, TensorTree + +if TYPE_CHECKING: + from torchopt.typing import Future, RRef, Scalar, T, TensorTree __all__ = [ diff --git a/torchopt/schedule/__init__.py b/torchopt/schedule/__init__.py index b9916783..d3d3eff5 100644 --- a/torchopt/schedule/__init__.py +++ b/torchopt/schedule/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -35,4 +35,4 @@ from torchopt.schedule.polynomial import linear_schedule, polynomial_schedule -__all__ = ['exponential_decay', 'polynomial_schedule', 'linear_schedule'] +__all__ = ['exponential_decay', 'linear_schedule', 'polynomial_schedule'] diff --git a/torchopt/schedule/exponential_decay.py b/torchopt/schedule/exponential_decay.py index 8811b353..c19c54b9 100644 --- a/torchopt/schedule/exponential_decay.py +++ b/torchopt/schedule/exponential_decay.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,11 +31,15 @@ # ============================================================================== """Exponential learning rate decay.""" +from __future__ import annotations + import logging import math -from typing import Optional +from typing import TYPE_CHECKING + -from torchopt.typing import Numeric, Scalar, Schedule +if TYPE_CHECKING: + from torchopt.typing import Numeric, Scalar, Schedule __all__ = ['exponential_decay'] @@ -48,7 +52,7 @@ def exponential_decay( transition_begin: int = 0, transition_steps: int = 1, staircase: bool = False, - end_value: Optional[float] = None, + end_value: float | None = None, ) -> Schedule: """Construct a schedule with either continuous or discrete exponential decay. 
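The hunks above repeat a single refactoring across modules: imports needed only for type annotations move behind a typing.TYPE_CHECKING guard, and "from __future__ import annotations" makes every annotation lazily evaluated (PEP 563), so the guarded names never have to exist at runtime. A minimal sketch of the pattern, assuming the torchopt.typing.TensorTree alias from the diff; the function body is a hypothetical placeholder:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by static type checkers (mypy, pyright); skipped at
        # runtime, which avoids import cycles and trims module import time.
        from torchopt.typing import TensorTree


    def cat_shapes(tree: TensorTree) -> tuple[int, ...]:
        # Hypothetical body for illustration only.
        ...

The same future import is what lets the exponential_decay signature swap Optional[float] for the PEP 604 spelling "float | None" while the package still supports Python 3.8.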
diff --git a/torchopt/schedule/polynomial.py b/torchopt/schedule/polynomial.py index 39629c38..d2a5160c 100644 --- a/torchopt/schedule/polynomial.py +++ b/torchopt/schedule/polynomial.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,15 +31,20 @@ # ============================================================================== """Polynomial learning rate schedules.""" +from __future__ import annotations + import logging +from typing import TYPE_CHECKING import numpy as np import torch -from torchopt.typing import Numeric, Scalar, Schedule + +if TYPE_CHECKING: + from torchopt.typing import Numeric, Scalar, Schedule -__all__ = ['polynomial_schedule', 'linear_schedule'] +__all__ = ['linear_schedule', 'polynomial_schedule'] def polynomial_schedule( diff --git a/torchopt/transform/__init__.py b/torchopt/transform/__init__.py index c75fcb5d..fa59a43b 100644 --- a/torchopt/transform/__init__.py +++ b/torchopt/transform/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -46,18 +46,18 @@ __all__ = [ - 'trace', - 'scale', - 'scale_by_schedule', 'add_decayed_weights', 'masked', + 'nan_to_num', + 'scale', + 'scale_by_accelerated_adam', + 'scale_by_adadelta', 'scale_by_adam', 'scale_by_adamax', - 'scale_by_adadelta', 'scale_by_radam', - 'scale_by_accelerated_adam', - 'scale_by_rss', 'scale_by_rms', + 'scale_by_rss', + 'scale_by_schedule', 'scale_by_stddev', - 'nan_to_num', + 'trace', ] diff --git a/torchopt/transform/add_decayed_weights.py b/torchopt/transform/add_decayed_weights.py index 04d564d7..0cb67837 100644 --- a/torchopt/transform/add_decayed_weights.py +++ b/torchopt/transform/add_decayed_weights.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -34,17 +34,20 @@ from __future__ import annotations -from typing import Any, Callable, NamedTuple - -import torch +from typing import TYPE_CHECKING, Any, Callable, NamedTuple from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation, identity from torchopt.transform.utils import tree_map_flat, tree_map_flat_ -from torchopt.typing import OptState, Params, Updates -__all__ = ['masked', 'add_decayed_weights'] +if TYPE_CHECKING: + import torch + + from torchopt.typing import OptState, Params, Updates + + +__all__ = ['add_decayed_weights', 'masked'] class MaskedState(NamedTuple): @@ -189,7 +192,7 @@ def _add_decayed_weights_flat( ) -def _add_decayed_weights( +def _add_decayed_weights( # noqa: C901 weight_decay: float = 0.0, mask: OptState | Callable[[Params], OptState] | None = None, *, @@ -226,19 +229,21 @@ def update_fn( if inplace: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if g.requires_grad: return g.add_(p, alpha=weight_decay) return g.add_(p.data, alpha=weight_decay) - updates = tree_map_(f, updates, params) + tree_map_(f, params, updates) else: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: - return g.add(p, alpha=weight_decay) + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.add(p, alpha=weight_decay) if g is not None else g - updates = tree_map(f, updates, params) + updates = tree_map(f, params, updates) return updates, state diff --git a/torchopt/transform/nan_to_num.py b/torchopt/transform/nan_to_num.py index 27d87499..740df1b0 100644 --- a/torchopt/transform/nan_to_num.py +++ b/torchopt/transform/nan_to_num.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,11 +16,16 @@ from __future__ import annotations -import torch +from typing import TYPE_CHECKING from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import OptState, Params, Updates def nan_to_num( diff --git a/torchopt/transform/scale.py b/torchopt/transform/scale.py index c731003c..2b492bdf 100644 --- a/torchopt/transform/scale.py +++ b/torchopt/transform/scale.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,12 +33,17 @@ from __future__ import annotations -import torch +from typing import TYPE_CHECKING from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation from torchopt.transform.utils import tree_map_flat, tree_map_flat_ -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import OptState, Params, Updates __all__ = ['scale'] diff --git a/torchopt/transform/scale_by_adadelta.py b/torchopt/transform/scale_by_adadelta.py index fb5431a3..6d05e5dd 100644 --- a/torchopt/transform/scale_by_adadelta.py +++ b/torchopt/transform/scale_by_adadelta.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. 
+# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,14 +19,17 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, update_moment -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['scale_by_adadelta'] @@ -129,23 +132,15 @@ def update_fn( if inplace: - def f( - g: torch.Tensor, # pylint: disable=unused-argument - m: torch.Tensor, - v: torch.Tensor, - ) -> torch.Tensor: - return g.mul_(v.add(eps).div_(m.add(eps)).sqrt_()) + def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.mul_(v.add(eps).div_(m.add(eps)).sqrt_()) if g is not None else g else: - def f( - g: torch.Tensor, # pylint: disable=unused-argument - m: torch.Tensor, - v: torch.Tensor, - ) -> torch.Tensor: - return g.mul(v.add(eps).div_(m.add(eps)).sqrt_()) + def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.mul(v.add(eps).div_(m.add(eps)).sqrt_()) if g is not None else g - updates = tree_map(f, updates, mu, state.nu) + updates = tree_map(f, mu, state.nu, updates) nu = update_moment.impl( # type: ignore[attr-defined] updates, diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py index c3c6254e..d45d1eb2 100644 --- a/torchopt/transform/scale_by_adam.py +++ b/torchopt/transform/scale_by_adam.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -35,7 +35,7 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch @@ -43,10 +43,13 @@ from torchopt.accelerated_op import AdamOp from torchopt.base import GradientTransformation from torchopt.transform.utils import inc_count, tree_map_flat, update_moment -from torchopt.typing import OptState, Params, Updates -__all__ = ['scale_by_adam', 'scale_by_accelerated_adam'] +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates + + +__all__ = ['scale_by_accelerated_adam', 'scale_by_adam'] TRIPLE_PYTREE_SPEC = pytree.tree_structure((0, 1, 2), none_is_leaf=True) # type: ignore[arg-type] @@ -132,6 +135,7 @@ def _scale_by_adam_flat( ) +# pylint: disable-next=too-many-arguments def _scale_by_adam( b1: float = 0.9, b2: float = 0.999, @@ -200,23 +204,15 @@ def update_fn( if inplace: - def f( - g: torch.Tensor, # pylint: disable=unused-argument - m: torch.Tensor, - v: torch.Tensor, - ) -> torch.Tensor: - return m.div_(v.add_(eps_root).sqrt_().add(eps)) + def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return m.div_(v.add_(eps_root).sqrt_().add(eps)) if g is not None else g else: - def f( - g: torch.Tensor, # pylint: disable=unused-argument - m: torch.Tensor, - v: torch.Tensor, - ) -> torch.Tensor: - return m.div(v.add(eps_root).sqrt_().add(eps)) + def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return m.div(v.add(eps_root).sqrt_().add(eps)) if g is not None else g - updates = tree_map(f, updates, mu_hat, nu_hat) + updates = tree_map(f, mu_hat, nu_hat, updates) return updates, ScaleByAdamState(mu=mu, nu=nu, count=count_inc) return GradientTransformation(init_fn, update_fn) @@ -283,7 +279,8 @@ def _scale_by_accelerated_adam_flat( ) -def _scale_by_accelerated_adam( +# pylint: disable-next=too-many-arguments +def _scale_by_accelerated_adam( # noqa: C901 b1: float = 0.9, b2: float = 0.999, eps: float = 1e-8, diff --git a/torchopt/transform/scale_by_adamax.py b/torchopt/transform/scale_by_adamax.py index 504e82cd..cfacbf35 100644 --- a/torchopt/transform/scale_by_adamax.py +++ b/torchopt/transform/scale_by_adamax.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -19,14 +19,17 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, update_moment -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['scale_by_adamax'] @@ -137,23 +140,17 @@ def update_fn( already_flattened=already_flattened, ) - def update_nu( - g: torch.Tensor, - n: torch.Tensor, - ) -> torch.Tensor: - return torch.max(n.mul(b2), g.abs().add_(eps)) + def update_nu(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return torch.max(n.mul(b2), g.abs().add_(eps)) if g is not None else g - nu = tree_map(update_nu, updates, state.nu) + nu = tree_map(update_nu, state.nu, updates) one_minus_b1_pow_t = 1 - b1**state.t - def f( - n: torch.Tensor, - m: torch.Tensor, - ) -> torch.Tensor: - return m.div(n).div_(one_minus_b1_pow_t) + def f(m: torch.Tensor, n: torch.Tensor | None) -> torch.Tensor: + return m.div(n).div_(one_minus_b1_pow_t) if n is not None else m - updates = tree_map(f, nu, mu) + updates = tree_map(f, mu, nu) return updates, ScaleByAdamaxState(mu=mu, nu=nu, t=state.t + 1) diff --git a/torchopt/transform/scale_by_radam.py b/torchopt/transform/scale_by_radam.py index acb85a82..95f26149 100644 --- a/torchopt/transform/scale_by_radam.py +++ b/torchopt/transform/scale_by_radam.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,14 +20,17 @@ from __future__ import annotations import math -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, update_moment -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['scale_by_radam'] @@ -89,7 +92,7 @@ def _scale_by_radam_flat( ) -def _scale_by_radam( +def _scale_by_radam( # noqa: C901 b1: float = 0.9, b2: float = 0.999, eps: float = 1e-6, diff --git a/torchopt/transform/scale_by_rms.py b/torchopt/transform/scale_by_rms.py index ac2fef16..f2141388 100644 --- a/torchopt/transform/scale_by_rms.py +++ b/torchopt/transform/scale_by_rms.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -33,14 +33,17 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, tree_map_flat_, update_moment -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['scale_by_rms'] @@ -135,18 +138,18 @@ def update_fn( ) if inplace: + # pylint: disable-next=invalid-name + def f(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.div_(n.sqrt().add_(eps)) if g is not None else g - def f(g: torch.Tensor, n: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name - return g.div_(n.sqrt().add_(eps)) - - updates = tree_map_(f, updates, nu) + tree_map_(f, nu, updates) else: + # pylint: disable-next=invalid-name + def f(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.div(n.sqrt().add(eps)) if g is not None else g - def f(g: torch.Tensor, n: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name - return g.div(n.sqrt().add(eps)) - - updates = tree_map(f, updates, nu) + updates = tree_map(f, nu, updates) return updates, ScaleByRmsState(nu=nu) diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index 68021e5e..642b2e5c 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,14 +33,17 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, update_moment -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['scale_by_rss'] @@ -128,23 +131,21 @@ def update_fn( if inplace: - def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor: - return torch.where( - sos > 0.0, - g.div_(sos.sqrt().add_(eps)), - 0.0, + def f(sos: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return ( + torch.where(sos > 0.0, g.div_(sos.sqrt().add_(eps)), 0.0) + if g is not None + else g ) else: - def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor: - return torch.where( - sos > 0.0, - g.div(sos.sqrt().add(eps)), - 0.0, + def f(sos: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return ( + torch.where(sos > 0.0, g.div(sos.sqrt().add(eps)), 0.0) if g is not None else g ) - updates = tree_map(f, updates, sum_of_squares) + updates = tree_map(f, sum_of_squares, updates) return updates, ScaleByRssState(sum_of_squares=sum_of_squares) return GradientTransformation(init_fn, update_fn) diff --git a/torchopt/transform/scale_by_schedule.py b/torchopt/transform/scale_by_schedule.py index f27fb7e8..499e2adb 100644 --- a/torchopt/transform/scale_by_schedule.py +++ b/torchopt/transform/scale_by_schedule.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,14 +33,17 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import inc_count, tree_map_flat, tree_map_flat_ -from torchopt.typing import Numeric, OptState, Params, Schedule, SequenceOfTensors, Updates + + +if TYPE_CHECKING: + from torchopt.typing import Numeric, OptState, Params, Schedule, SequenceOfTensors, Updates __all__ = ['scale_by_schedule'] @@ -96,20 +99,24 @@ def update_fn( inplace: bool = True, ) -> tuple[Updates, OptState]: if inplace: - - def f(g: torch.Tensor, c: Numeric) -> torch.Tensor: # pylint: disable=invalid-name + # pylint: disable-next=invalid-name + def f(c: Numeric, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g step_size = step_size_fn(c) return g.mul_(step_size) - updates = tree_map_(f, updates, state.count) + tree_map_(f, state.count, updates) else: - - def f(g: torch.Tensor, c: Numeric) -> torch.Tensor: # pylint: disable=invalid-name + # pylint: disable-next=invalid-name + def f(c: Numeric, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g step_size = step_size_fn(c) return g.mul(step_size) - updates = tree_map(f, updates, state.count) + updates = tree_map(f, state.count, updates) return ( updates, diff --git a/torchopt/transform/scale_by_stddev.py b/torchopt/transform/scale_by_stddev.py index bbbfb384..5a3e6655 100644 --- a/torchopt/transform/scale_by_stddev.py +++ b/torchopt/transform/scale_by_stddev.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -35,14 +35,17 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, tree_map_flat_, update_moment -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['scale_by_stddev'] @@ -148,17 +151,17 @@ def update_fn( if inplace: - def f(g: torch.Tensor, m: torch.Tensor, n: torch.Tensor) -> torch.Tensor: - return g.div_(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) + def f(m: torch.Tensor, n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.div_(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) if g is not None else g - updates = tree_map_(f, updates, mu, nu) + tree_map_(f, mu, nu, updates) else: - def f(g: torch.Tensor, m: torch.Tensor, n: torch.Tensor) -> torch.Tensor: - return g.div(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) + def f(m: torch.Tensor, n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.div(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) if g is not None else g - updates = tree_map(f, updates, mu, nu) + updates = tree_map(f, mu, nu, updates) return updates, ScaleByRStdDevState(mu=mu, nu=nu) diff --git a/torchopt/transform/trace.py b/torchopt/transform/trace.py index 7a1e1971..219cbbec 100644 --- a/torchopt/transform/trace.py +++ b/torchopt/transform/trace.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -35,14 +35,17 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation, identity from torchopt.transform.utils import tree_map_flat, tree_map_flat_ -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['trace'] @@ -101,7 +104,7 @@ def _trace_flat( ) -def _trace( +def _trace( # noqa: C901 momentum: float = 0.9, dampening: float = 0.0, nesterov: bool = False, @@ -136,7 +139,7 @@ def init_fn(params: Params) -> OptState: first_call = True - def update_fn( + def update_fn( # noqa: C901 updates: Updates, state: OptState, *, @@ -148,52 +151,60 @@ def update_fn( if nesterov: if inplace: - def f1(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + def f1(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if first_call: return t.add_(g) return t.mul_(momentum).add_(g) - def f2(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: - return g.add_(t, alpha=momentum) + def f2(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.add_(t, alpha=momentum) if g is not None else g - new_trace = tree_map(f1, updates, state.trace) - updates = tree_map_(f2, updates, new_trace) + new_trace = tree_map(f1, state.trace, updates) + tree_map_(f2, new_trace, updates) else: - def f1(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + def f1(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if first_call: return t.add(g) return t.mul(momentum).add_(g) - def f2(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: - return g.add(t, alpha=momentum) + def f2(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.add(t, alpha=momentum) if g is not None else g - new_trace = tree_map(f1, updates, state.trace) - updates = tree_map(f2, updates, new_trace) + new_trace = tree_map(f1, state.trace, updates) + updates = tree_map(f2, new_trace, updates) else: if inplace: - def f(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + def f(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if first_call: return t.add_(g) return t.mul_(momentum).add_(g, alpha=1.0 - dampening) - def copy_(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: - return g.copy_(t) + def copy_to_(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.copy_(t) if g is not None else g - new_trace = tree_map(f, updates, state.trace) - updates = tree_map_(copy_, updates, new_trace) + new_trace = tree_map(f, state.trace, updates) + tree_map_(copy_to_, new_trace, updates) else: - def f(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + def f(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if first_call: return t.add(g) return t.mul(momentum).add_(g, alpha=1.0 - dampening) - new_trace = tree_map(f, updates, state.trace) + new_trace = tree_map(f, state.trace, updates) updates = tree_map(torch.clone, new_trace) first_call = False diff --git a/torchopt/transform/utils.py b/torchopt/transform/utils.py index 8c67fd7e..9b38d561 100644 --- a/torchopt/transform/utils.py +++ b/torchopt/transform/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -34,15 +34,18 @@ from __future__ import annotations from collections import deque -from typing import Any, Callable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Sequence import torch from torchopt import pytree -from torchopt.typing import TensorTree, Updates -__all__ = ['tree_map_flat', 'tree_map_flat_', 'inc_count', 'update_moment'] +if TYPE_CHECKING: + from torchopt.typing import TensorTree, Updates + + +__all__ = ['inc_count', 'tree_map_flat', 'tree_map_flat_', 'update_moment'] INT64_MAX = torch.iinfo(torch.int64).max @@ -160,7 +163,8 @@ def _update_moment_flat( ) -def _update_moment( +# pylint: disable-next=too-many-arguments +def _update_moment( # noqa: C901 updates: Updates, moments: TensorTree, decay: float, diff --git a/torchopt/typing.py b/torchopt/typing.py index c5c76984..fcd888fb 100644 --- a/torchopt/typing.py +++ b/torchopt/typing.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +14,8 @@ # ============================================================================== """Typing utilities.""" +from __future__ import annotations + import abc from typing import ( Callable, @@ -45,39 +47,39 @@ __all__ = [ - 'GradientTransformation', 'ChainedGradientTransformation', + 'Device', + 'Distribution', 'EmptyState', - 'UninitializedState', - 'Params', - 'Updates', + 'Future', + 'GradientTransformation', + 'LinearSolver', + 'ListOfOptionalTensors', + 'ListOfTensors', + 'ModuleTensorContainers', + 'Numeric', 'OptState', + 'OptionalTensor', + 'OptionalTensorOrOptionalTensors', + 'OptionalTensorTree', + 'Params', + 'PyTree', + 'Samplable', + 'SampleFunc', 'Scalar', - 'Numeric', - 'Schedule', 'ScalarOrSchedule', - 'PyTree', - 'Tensor', - 'OptionalTensor', - 'ListOfTensors', - 'TupleOfTensors', + 'Schedule', + 'SequenceOfOptionalTensors', 'SequenceOfTensors', + 'Size', + 'Tensor', + 'TensorContainer', 'TensorOrTensors', 'TensorTree', - 'ListOfOptionalTensors', 'TupleOfOptionalTensors', - 'SequenceOfOptionalTensors', - 'OptionalTensorOrOptionalTensors', - 'OptionalTensorTree', - 'TensorContainer', - 'ModuleTensorContainers', - 'Future', - 'LinearSolver', - 'Device', - 'Size', - 'Distribution', - 'SampleFunc', - 'Samplable', + 'TupleOfTensors', + 'UninitializedState', + 'Updates', ] T = TypeVar('T') @@ -138,7 +140,7 @@ class Samplable(Protocol): # pylint: disable=too-few-public-methods def sample( self, sample_shape: Size = Size(), # noqa: B008 # pylint: disable=unused-argument - ) -> Union[Tensor, Sequence[Numeric]]: + ) -> Tensor | Sequence[Numeric]: # pylint: disable-next=line-too-long """Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.""" raise NotImplementedError diff --git a/torchopt/update.py b/torchopt/update.py index 3a2a6984..3f2d71fe 100644 --- a/torchopt/update.py +++ b/torchopt/update.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -33,10 +33,15 @@ from __future__ import annotations -import torch +from typing import TYPE_CHECKING from torchopt import pytree -from torchopt.typing import Params, Updates + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import Params, Updates __all__ = ['apply_updates'] diff --git a/torchopt/utils.py b/torchopt/utils.py index 5414db80..5f9202a3 100644 --- a/torchopt/utils.py +++ b/torchopt/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -34,11 +34,11 @@ __all__ = [ 'ModuleState', - 'stop_gradient', 'extract_state_dict', - 'recover_state_dict', 'module_clone', 'module_detach_', + 'recover_state_dict', + 'stop_gradient', ] @@ -91,7 +91,7 @@ def fn_(obj: Any) -> None: @overload -def extract_state_dict( +def extract_state_dict( # pylint: disable=too-many-arguments target: nn.Module, *, by: CopyMode = 'reference', @@ -114,8 +114,8 @@ def extract_state_dict( ... -# pylint: disable-next=too-many-branches,too-many-locals -def extract_state_dict( +# pylint: disable-next=too-many-arguments,too-many-branches,too-many-locals +def extract_state_dict( # noqa: C901 target: nn.Module | MetaOptimizer, *, by: CopyMode = 'reference', @@ -272,7 +272,7 @@ def get_variable(t: torch.Tensor | None) -> torch.Tensor | None: return pytree.tree_map(get_variable, state) # type: ignore[arg-type,return-value] - raise RuntimeError(f'Unexpected class of {target}') + raise TypeError(f'Unexpected class of {target}') def extract_module_containers( @@ -346,7 +346,7 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor: state = cast(Sequence[OptState], state) target.load_state_dict(state) else: - raise RuntimeError(f'Unexpected class of {target}') + raise TypeError(f'Unexpected class of {target}') @overload @@ -383,7 +383,7 @@ def module_clone( # pylint: disable-next=too-many-locals -def module_clone( +def module_clone( # noqa: C901 target: nn.Module | MetaOptimizer | TensorTree, *, by: CopyMode = 'reference', diff --git a/torchopt/version.py b/torchopt/version.py index 87c4fe49..9fdcac9b 100644 --- a/torchopt/version.py +++ b/torchopt/version.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -14,7 +14,7 @@ # ============================================================================== """TorchOpt: a high-performance optimizer library built upon PyTorch.""" -__version__ = '0.7.2' +__version__ = '0.7.3' __license__ = 'Apache License, Version 2.0' __author__ = 'TorchOpt Contributors' __release__ = False @@ -25,8 +25,8 @@ try: prefix, sep, suffix = ( - subprocess.check_output( - ['git', 'describe', '--abbrev=7'], # noqa: S603,S607 + subprocess.check_output( # noqa: S603 + ['git', 'describe', '--abbrev=7'], # noqa: S607 cwd=os.path.dirname(os.path.abspath(__file__)), stderr=subprocess.DEVNULL, text=True, @@ -40,7 +40,7 @@ if sep: version_prefix, dot, version_tail = prefix.rpartition('.') prefix = f'{version_prefix}{dot}{int(version_tail) + 1}' - __version__ = sep.join((prefix, suffix)) + __version__ = f'{prefix}{sep}{suffix}' del version_prefix, dot, version_tail else: __version__ = prefix diff --git a/torchopt/visual.py b/torchopt/visual.py index 47a7f5d5..7638d7ec 100644 --- a/torchopt/visual.py +++ b/torchopt/visual.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,16 +19,19 @@ from __future__ import annotations -from typing import Any, Generator, Iterable, Mapping, cast +from typing import TYPE_CHECKING, Any, Generator, Iterable, Mapping, cast import torch from graphviz import Digraph from torchopt import pytree -from torchopt.typing import TensorTree from torchopt.utils import ModuleState +if TYPE_CHECKING: + from torchopt.typing import TensorTree + + __all__ = ['make_dot', 'resize_graph'] @@ -69,7 +72,7 @@ def truncate(s: str) -> str: # pylint: disable=invalid-name # pylint: disable-next=too-many-branches,too-many-statements,too-many-locals -def make_dot( +def make_dot( # noqa: C901 var: TensorTree, params: ( Mapping[str, torch.Tensor] | @@ -145,7 +148,7 @@ def size_to_str(size: tuple[int, ...]) -> str: def get_var_name(var: torch.Tensor, name: str | None = None) -> str: if not name: - name = param_map[var] if var in param_map else '' + name = param_map.get(var, '') return f'{name}\n{size_to_str(var.size())}' def get_var_name_with_flag(var: torch.Tensor) -> str | None: @@ -153,7 +156,7 @@ return f'{param_map[var][0]}\n{size_to_str(param_map[var][1].size())}' return None - def add_nodes(fn: Any) -> None: # pylint: disable=too-many-branches + def add_nodes(fn: Any) -> None: # noqa: C901 # pylint: disable=too-many-branches assert not isinstance(fn, torch.Tensor) if fn in seen: return diff --git a/tutorials/requirements.txt b/tutorials/requirements.txt index ff5a5c42..e8a3be95 100644 --- a/tutorials/requirements.txt +++ b/tutorials/requirements.txt @@ -1,11 +1,11 @@ ---extra-index-url https://download.pytorch.org/whl/cu117 +--extra-index-url https://download.pytorch.org/whl/cu121 # Sync with project.dependencies -torch >= 1.13 +torch >= 2.0 torchvision --requirement ../requirements.txt ipykernel -jax[cpu] >= 0.3 +jax[cpu] >= 0.4 jaxopt optax
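The version.py hunk above is behavior-preserving: for a non-release build it bumps the last numeric component of the nearest tag and re-attaches the dev suffix, and the new f-string is equivalent to the str.join it replaces. A short sketch with invented values (the real prefix, sep, and suffix come from parsing the output of git describe --abbrev=7):

    # Hypothetical split of a dev build's version string; the actual values
    # are derived from git describe in torchopt/version.py.
    prefix, sep, suffix = '0.7.2', '.dev', '3+g1a2b3c4'

    version_prefix, dot, version_tail = prefix.rpartition('.')  # '0.7', '.', '2'
    prefix = f'{version_prefix}{dot}{int(version_tail) + 1}'    # '0.7.3'

    # str.join over a 2-tuple inserts the separator exactly once, so
    # sep.join((prefix, suffix)) and the f-string produce the same result.
    __version__ = f'{prefix}{sep}{suffix}'                      # '0.7.3.dev3+g1a2b3c4'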
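Beyond the copyright bumps and lint pragmas, the transform hunks (add_decayed_weights, scale_by_adam, scale_by_rms, trace, and the others) share a second pattern: each tree_map callback now takes the state tree (params, mu, nu, trace) first and the updates tree last, returning the gradient unchanged when it is None. The first tree passed to tree_map drives the traversal, so the fully populated state tree has to lead, while None leaves in updates simply flow through. A minimal sketch of that convention, assuming torchopt.pytree.tree_map behaves like optree.tree_map with None treated as a leaf; the trees here are invented for illustration:

    from __future__ import annotations

    import optree
    import torch

    # Hypothetical trees: params always has the full structure, while updates
    # may carry None for parameters that received no gradient.
    params = {'b': torch.ones(2), 'w': torch.ones(2)}
    updates = {'b': None, 'w': torch.full((2,), 0.1)}
    weight_decay = 0.01

    def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
        # None gradients pass through untouched, as in the updated callbacks.
        return g.add(p, alpha=weight_decay) if g is not None else g

    # params comes first so its structure drives the traversal;
    # none_is_leaf=True lets the None in updates reach the callback.
    new_updates = optree.tree_map(f, params, updates, none_is_leaf=True)
    print(new_updates)  # {'b': None, 'w': tensor([0.1100, 0.1100])}

The same reordering explains why the in-place branches now call tree_map_(f, ..., updates) purely for its side effect instead of reassigning updates.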