diff --git a/.flake8 b/.flake8
new file mode 100644
index 00000000..82919783
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,41 @@
+[flake8]
+max-line-length = 120
+max-doc-length = 100
+select = B,C,E,F,W,Y,SIM
+ignore =
+ # E203: whitespace before ':'
+ # W503: line break before binary operator
+ # W504: line break after binary operator
+ # format by black
+ E203,W503,W504,
+ # E501: line too long
+ # W505: doc line too long
+ # too long docstring due to long example blocks
+ E501,W505,
+per-file-ignores =
+ # F401: module imported but unused
+ # intentionally unused imports
+ __init__.py: F401
+ # F401: module imported but unused
+ # F403: unable to detect undefined names
+ # F405: member mey be undefined, or defined from star imports
+ # members populated from optree
+ torchopt/pytree.py: F401,F403,F405
+ # E301: expected 1 blank line
+ # E302: expected 2 blank lines
+ # E305: expected 2 blank lines after class or function definition
+ # E701: multiple statements on one line (colon)
+ # E704: multiple statements on one line (def)
+ # format by black
+ *.pyi: E301,E302,E305,E701,E704
+exclude =
+ .git,
+ .vscode,
+ venv,
+ third-party,
+ __pycache__,
+ docs/source/conf.py,
+ build,
+ dist,
+ examples,
+ tests
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 71ba3dd1..eb6753cc 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -56,7 +56,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
- python-version: "3.7 - 3.10" # sync with requires-python in pyproject.toml
+ python-version: "3.8 - 3.11" # sync with requires-python in pyproject.toml
update-environment: true
- name: Set __release__
@@ -96,16 +96,16 @@ jobs:
run: |
make pytest
- build-wheels-py37:
+ build-wheels-py38:
name: Build wheels for Python ${{ matrix.python-version }} on ubuntu-latest
runs-on: ubuntu-latest
needs: [build]
if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
strategy:
matrix:
- python-version: ["3.7"] # sync with requires-python in pyproject.toml
+ python-version: ["3.8"] # sync with requires-python in pyproject.toml
fail-fast: false
- timeout-minutes: 30
+ timeout-minutes: 60
steps:
- name: Checkout
uses: actions/checkout@v3
@@ -132,7 +132,7 @@ jobs:
run: python .github/workflows/set_cibw_build.py
- name: Build wheels
- uses: pypa/cibuildwheel@v2.12.0
+ uses: pypa/cibuildwheel@v2.12.3
env:
CIBW_BUILD: ${{ env.CIBW_BUILD }}
with:
@@ -142,20 +142,20 @@ jobs:
- uses: actions/upload-artifact@v3
with:
- name: wheels-py37
+ name: wheels-py38
path: wheelhouse/*.whl
if-no-files-found: error
build-wheels:
name: Build wheels for Python ${{ matrix.python-version }} on ubuntu-latest
runs-on: ubuntu-latest
- needs: [build, build-wheels-py37]
+ needs: [build, build-wheels-py38]
if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
strategy:
matrix:
- python-version: ["3.8", "3.9", "3.10"] # sync with requires-python in pyproject.toml
+ python-version: ["3.9", "3.10", "3.11"] # sync with requires-python in pyproject.toml
fail-fast: false
- timeout-minutes: 30
+ timeout-minutes: 60
steps:
- name: Checkout
uses: actions/checkout@v3
@@ -182,7 +182,7 @@ jobs:
run: python .github/workflows/set_cibw_build.py
- name: Build wheels
- uses: pypa/cibuildwheel@v2.12.0
+ uses: pypa/cibuildwheel@v2.12.3
env:
CIBW_BUILD: ${{ env.CIBW_BUILD }}
with:
@@ -198,7 +198,7 @@ jobs:
publish:
runs-on: ubuntu-latest
- needs: [build, build-wheels-py37, build-wheels]
+ needs: [build, build-wheels-py38, build-wheels]
if: |
github.repository == 'metaopt/torchopt' && github.event_name != 'pull_request' &&
(github.event_name != 'workflow_dispatch' || github.event.inputs.task == 'build-and-publish') &&
@@ -215,7 +215,7 @@ jobs:
uses: actions/setup-python@v4
if: startsWith(github.ref, 'refs/tags/')
with:
- python-version: "3.7 - 3.11" # sync with requires-python in pyproject.toml
+ python-version: "3.8 - 3.11" # sync with requires-python in pyproject.toml
update-environment: true
- name: Set __release__
@@ -249,7 +249,7 @@ jobs:
with:
# unpacks default artifact into dist/
# if `name: artifact` is omitted, the action will create extra parent dir
- name: wheels-py37
+ name: wheels-py38
path: dist
- name: Download built wheels
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 55dee661..19c0cf5b 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -71,6 +71,10 @@ jobs:
run: |
make pre-commit
+ - name: ruff
+ run: |
+ make ruff
+
- name: flake8
run: |
make flake8
@@ -89,6 +93,12 @@ jobs:
- name: clang-format
run: |
+ (
+ source /etc/os-release
+ wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc
+ sudo add-apt-repository "deb http://apt.llvm.org/${UBUNTU_CODENAME} llvm-toolchain-${UBUNTU_CODENAME} main" --yes
+ )
+ sudo apt-get update && sudo apt-get install clang-format --yes
make clang-format
- name: clang-tidy
diff --git a/.github/workflows/set_cibw_build.py b/.github/workflows/set_cibw_build.py
index 03838b4a..ec4383f4 100755
--- a/.github/workflows/set_cibw_build.py
+++ b/.github/workflows/set_cibw_build.py
@@ -10,5 +10,5 @@
CIBW_BUILD = 'CIBW_BUILD=*cp%d%d-*manylinux*' % sys.version_info[:2]
print(CIBW_BUILD)
-with open(os.getenv('GITHUB_ENV'), mode='a', encoding='UTF-8') as file:
+with open(os.getenv('GITHUB_ENV'), mode='a', encoding='utf-8') as file:
print(CIBW_BUILD, file=file)
diff --git a/.github/workflows/set_release.py b/.github/workflows/set_release.py
index 568a38e2..6c437f19 100755
--- a/.github/workflows/set_release.py
+++ b/.github/workflows/set_release.py
@@ -10,7 +10,7 @@
VERSION_FILE = ROOT / 'torchopt' / 'version.py'
-VERSION_CONTENT = VERSION_FILE.read_text(encoding='UTF-8')
+VERSION_CONTENT = VERSION_FILE.read_text(encoding='utf-8')
VERSION_FILE.write_text(
data=re.sub(
@@ -18,5 +18,5 @@
'__release__ = True',
string=VERSION_CONTENT,
),
- encoding='UTF-8',
+ encoding='utf-8',
)
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 8bee5b9d..4f6fad50 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -41,10 +41,10 @@ jobs:
submodules: "recursive"
fetch-depth: 1
- - name: Set up Python 3.7
+ - name: Set up Python 3.8
uses: actions/setup-python@v4
with:
- python-version: "3.7" # the lowest version we support (sync with requires-python in pyproject.toml)
+ python-version: "3.8" # the lowest version we support (sync with requires-python in pyproject.toml)
update-environment: true
- name: Setup CUDA Toolkit
@@ -102,7 +102,7 @@ jobs:
timeout-minutes: 60
strategy:
matrix:
- os: [ubuntu-latest, macos-latest] # jaxlib is not available on Windows
+ os: [ubuntu-latest, windows-latest, macos-latest]
fail-fast: false
steps:
- name: Checkout
@@ -111,10 +111,10 @@ jobs:
submodules: "recursive"
fetch-depth: 1
- - name: Set up Python 3.7
+ - name: Set up Python 3.8
uses: actions/setup-python@v4
with:
- python-version: "3.7" # the lowest version we support (sync with requires-python in pyproject.toml)
+ python-version: "3.8" # the lowest version we support (sync with requires-python in pyproject.toml)
update-environment: true
- name: Upgrade pip
diff --git a/.gitignore b/.gitignore
index 450d7b0c..350ddfb2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -146,6 +146,9 @@ venv.bak/
# mkdocs documentation
/site
+# ruff
+.ruff_cache/
+
# mypy
.mypy_cache/
.dmypy.json
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 66f0bdf0..a16ff100 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -3,8 +3,9 @@
ci:
skip: [pylint]
autofix_prs: true
- autofix_commit_msg: 'fix: [pre-commit.ci] auto fixes [...]'
- autoupdate_commit_msg: 'chore(pre-commit): [pre-commit.ci] autoupdate'
+ autofix_commit_msg: "fix: [pre-commit.ci] auto fixes [...]"
+ autoupdate_commit_msg: "chore(pre-commit): [pre-commit.ci] autoupdate"
+default_stages: [commit, push, manual]
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
@@ -24,30 +25,57 @@ repos:
- id: debug-statements
- id: double-quote-string-fixer
- repo: https://github.com/pre-commit/mirrors-clang-format
- rev: v15.0.7
+ rev: v16.0.3
hooks:
- - id: clang-format
- stages: [commit, push, manual]
+ - id: clang-format
+ - repo: https://github.com/charliermarsh/ruff-pre-commit
+ rev: v0.0.265
+ hooks:
+ - id: ruff
+ args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
- stages: [commit, push, manual]
- repo: https://github.com/psf/black
- rev: 23.1.0
+ rev: 23.3.0
hooks:
- id: black-jupyter
- stages: [commit, push, manual]
- repo: https://github.com/asottile/pyupgrade
- rev: v3.3.1
+ rev: v3.4.0
hooks:
- id: pyupgrade
- args: [--py37-plus] # sync with requires-python
- stages: [commit, push, manual]
+ args: [--py38-plus] # sync with requires-python
exclude: |
(?x)(
^examples/
)
+ - repo: https://github.com/pycqa/flake8
+ rev: 6.0.0
+ hooks:
+ - id: flake8
+ additional_dependencies:
+ - flake8-bugbear
+ - flake8-comprehensions
+ - flake8-docstrings
+ - flake8-pyi
+ - flake8-simplify
+ exclude: |
+ (?x)(
+ ^examples/|
+ ^tests/|
+ ^docs/source/conf.py$
+ )
+ - repo: https://github.com/codespell-project/codespell
+ rev: v2.2.4
+ hooks:
+ - id: codespell
+ additional_dependencies: [".[toml]"]
+ exclude: |
+ (?x)(
+ ^docs/source/spelling_wordlist.txt$|
+ ^docs/source/references.bib$
+ )
- repo: local
hooks:
- id: pylint
@@ -56,7 +84,6 @@ repos:
language: system
types: [python]
require_serial: true
- stages: [commit, push, manual]
exclude: |
(?x)(
^docs/|
@@ -68,7 +95,7 @@ repos:
rev: 6.3.0
hooks:
- id: pydocstyle
- additional_dependencies: ['.[toml]']
+ additional_dependencies: [".[toml]"]
exclude: |
(?x)(
^.github/|
diff --git a/.pylintrc b/.pylintrc
index accc71d5..a21967ee 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -84,7 +84,7 @@ persistent=yes
# Minimum Python version to use for version dependent checks. Will default to
# the version used to run pylint.
-py-version=3.7 # the lowest version we support (sync with requires-python in pyproject.toml)
+py-version=3.8 # the lowest version we support (sync with requires-python in pyproject.toml)
# Discover python modules and packages in the file system subtree.
recursive=no
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index 73e1e60f..6a9c387e 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -19,10 +19,6 @@ build:
conda:
environment: docs/conda-recipe.yaml
-# If using Sphinx, optionally build your docs in additional formats such as PDF
-formats:
- - pdf
-
# Build documentation in the docs/ directory with Sphinx
sphinx:
builder: html
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 927cb1db..cb158207 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -29,6 +29,27 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
------
+## [0.7.1] - 2023-05-12
+
+### Added
+
+- Enable CI workflow to build CXX/CUDA extension for Python 3.11 by [@XuehaiPan](https://github.com/XuehaiPan) in [#152](https://github.com/metaopt/torchopt/pull/152).
+- Implement AdaGrad optimizer and exponential learning rate decay schedule by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#80](https://github.com/metaopt/torchopt/pull/80).
+- Enable tests on Windows by [@XuehaiPan](https://github.com/XuehaiPan) in [#140](https://github.com/metaopt/torchopt/pull/140).
+- Add `ruff` and `flake8` plugins integration by [@XuehaiPan](https://github.com/XuehaiPan) in [#138](https://github.com/metaopt/torchopt/pull/138) and [#139](https://github.com/metaopt/torchopt/pull/139).
+- Add more documentation on implicit differentiation by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#143](https://github.com/metaopt/torchopt/pull/143).
+
+### Fixed
+
+- Fix overloaded annotations of `extract_state_dict` by [@StefanoWoerner](https://github.com/StefanoWoerner) in [#162](https://github.com/metaopt/torchopt/pull/162).
+- Fix transpose empty iterable with `zip(*nested)` in transformations by [@XuehaiPan](https://github.com/XuehaiPan) in [#145](https://github.com/metaopt/torchopt/pull/145).
+
+### Removed
+
+- Drop Python 3.7 support by [@XuehaiPan](https://github.com/XuehaiPan) in [#136](https://github.com/metaopt/torchopt/pull/136).
+
+------
+
## [0.7.0] - 2023-02-16
### Added
@@ -166,7 +187,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
------
-[Unreleased]: https://github.com/metaopt/torchopt/compare/v0.7.0...HEAD
+[Unreleased]: https://github.com/metaopt/torchopt/compare/v0.7.1...HEAD
+[0.7.1]: https://github.com/metaopt/torchopt/compare/v0.7.0...v0.7.1
[0.7.0]: https://github.com/metaopt/torchopt/compare/v0.6.0...v0.7.0
[0.6.0]: https://github.com/metaopt/torchopt/compare/v0.5.0...v0.6.0
[0.5.0]: https://github.com/metaopt/torchopt/compare/v0.4.3...v0.5.0
diff --git a/CITATION.cff b/CITATION.cff
index 83a207e6..e7cf54cb 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -32,7 +32,7 @@ authors:
family-names: Yang
affiliation: Peking University
email: yaodong.yang@pku.edu.cn
-version: 0.7.0
-date-released: "2023-02-16"
+version: 0.7.1
+date-released: "2023-05-12"
license: Apache-2.0
repository-code: "https://github.com/metaopt/torchopt"
diff --git a/Makefile b/Makefile
index f856cb21..906d8a64 100644
--- a/Makefile
+++ b/Makefile
@@ -10,7 +10,7 @@ CUDA_FILES = $(shell find $(SOURCE_FOLDERS) -type f -name "*.cuh" -o -name "
COMMIT_HASH = $(shell git log -1 --format=%h)
PATH := $(HOME)/go/bin:$(PATH)
PYTHON ?= $(shell command -v python3 || command -v python)
-CLANG_FORMAT ?= $(shell command -v clang-format-14 || command -v clang-format)
+CLANG_FORMAT ?= $(shell command -v clang-format-17 || command -v clang-format)
PYTESTOPTS ?=
.PHONY: default
@@ -46,12 +46,19 @@ pylint-install:
flake8-install:
$(call check_pip_install,flake8)
- $(call check_pip_install_extra,flake8-bugbear,flake8-bugbear)
+ $(call check_pip_install,flake8-bugbear)
+ $(call check_pip_install,flake8-comprehensions)
+ $(call check_pip_install,flake8-docstrings)
+ $(call check_pip_install,flake8-pyi)
+ $(call check_pip_install,flake8-simplify)
py-format-install:
$(call check_pip_install,isort)
$(call check_pip_install_extra,black,black[jupyter])
+ruff-install:
+ $(call check_pip_install,ruff)
+
mypy-install:
$(call check_pip_install,mypy)
@@ -61,11 +68,7 @@ pre-commit-install:
docs-install:
$(call check_pip_install_extra,pydocstyle,pydocstyle[toml])
- $(call check_pip_install_extra,doc8,"doc8<1.0.0a0")
- if ! $(PYTHON) -c "import sys; exit(sys.version_info < (3, 8))"; then \
- $(PYTHON) -m pip uninstall --yes importlib-metadata; \
- $(call check_pip_install_extra,importlib-metadata,"importlib-metadata<5.0.0a0"); \
- fi
+ $(call check_pip_install,doc8)
$(call check_pip_install,sphinx)
$(call check_pip_install,sphinx-rtd-theme)
$(call check_pip_install,sphinx-autoapi)
@@ -75,7 +78,7 @@ docs-install:
$(call check_pip_install,sphinxcontrib-bibtex)
$(call check_pip_install,sphinx-autodoc-typehints)
$(call check_pip_install,myst-nb)
- $(call check_pip_install_extra,sphinxcontrib.spelling,sphinxcontrib.spelling pyenchant)
+ $(call check_pip_install_extra,sphinxcontrib-spelling,sphinxcontrib-spelling pyenchant)
pytest-install:
$(call check_pip_install,pytest)
@@ -92,8 +95,8 @@ cpplint-install:
$(call check_pip_install,cpplint)
clang-format-install:
- command -v clang-format-14 || command -v clang-format || \
- sudo apt-get install -y clang-format-14 || \
+ command -v clang-format-17 || command -v clang-format || \
+ sudo apt-get install -y clang-format-17 || \
sudo apt-get install -y clang-format
clang-tidy-install:
@@ -122,14 +125,20 @@ pylint: pylint-install
$(PYTHON) -m pylint $(PROJECT_PATH)
flake8: flake8-install
- $(PYTHON) -m flake8 $(PYTHON_FILES) --count --select=E9,F63,F7,F82,E225,E251 --show-source --statistics
+ $(PYTHON) -m flake8 --count --show-source --statistics
py-format: py-format-install
$(PYTHON) -m isort --project $(PROJECT_NAME) --check $(PYTHON_FILES) && \
$(PYTHON) -m black --check $(PYTHON_FILES) tutorials
+ruff: ruff-install
+ $(PYTHON) -m ruff check .
+
+ruff-fix: ruff-install
+ $(PYTHON) -m ruff check . --fix --exit-non-zero-on-fix
+
mypy: mypy-install
- $(PYTHON) -m mypy $(PROJECT_PATH)
+ $(PYTHON) -m mypy $(PROJECT_PATH) --install-types --non-interactive
pre-commit: pre-commit-install
$(PYTHON) -m pre_commit run --all-files
@@ -177,17 +186,19 @@ clean-docs:
# Utility functions
-lint: flake8 py-format mypy pylint clang-format clang-tidy cpplint addlicense docstyle spelling
+lint: ruff flake8 py-format mypy pylint clang-format clang-tidy cpplint addlicense docstyle spelling
-format: py-format-install clang-format-install addlicense-install
+format: py-format-install ruff-install clang-format-install addlicense-install
$(PYTHON) -m isort --project $(PROJECT_NAME) $(PYTHON_FILES)
$(PYTHON) -m black $(PYTHON_FILES) tutorials
+ $(PYTHON) -m ruff check . --fix --exit-zero
$(CLANG_FORMAT) -style=file -i $(CXX_FILES) $(CUDA_FILES)
addlicense -c $(COPYRIGHT) -ignore tests/coverage.xml -l apache -y 2022-$(shell date +"%Y") $(SOURCE_FOLDERS)
clean-py:
find . -type f -name '*.py[co]' -delete
find . -depth -type d -name "__pycache__" -exec rm -r "{}" +
+ find . -depth -type d -name ".ruff_cache" -exec rm -r "{}" +
find . -depth -type d -name ".mypy_cache" -exec rm -r "{}" +
find . -depth -type d -name ".pytest_cache" -exec rm -r "{}" +
rm tests/.coverage
@@ -211,5 +222,8 @@ docker-devel:
docker: docker-base docker-devel
+docker-run-base: docker-base
+ docker run --network=host --gpus=all -v /:/host -h ubuntu -it $(PROJECT_NAME):$(COMMIT_HASH)
+
docker-run-devel: docker-devel
docker run --network=host --gpus=all -v /:/host -h ubuntu -it $(PROJECT_NAME)-devel:$(COMMIT_HASH)
diff --git a/README.md b/README.md
index 321f39e3..5bc474fc 100644
--- a/README.md
+++ b/README.md
@@ -8,10 +8,10 @@
-

+



-

+




@@ -34,7 +34,7 @@ TorchOpt is:
- **Efficient**: TorchOpt provides (1) CPU/GPU acceleration differentiable optimizer (2) RPC-based distributed training framework (3) Fast Tree Operations, to largely increase the training efficiency for bi-level optimization problems.
Beyond differentiable optimization, TorchOpt can also be regarded as a functional optimizer that enables [JAX-like](https://github.com/google/jax) composable functional optimizer for PyTorch.
-With TorchOpt, users can easily conduct neural network optimization in PyTorch with functional style optimizer, similar to [Optax](https://github.com/deepmind/optax) in JAX.
+With TorchOpt, users can easily conduct neural network optimization in PyTorch with a functional style optimizer, similar to [Optax](https://github.com/deepmind/optax) in JAX.
--------------------------------------------------------------------------------
@@ -132,11 +132,11 @@ optimizer.step() # step updates
### Differentiable
-On top of the same optimization function as `torch.optim`, an important benefit of functional optimizer is that one can implement differentiable optimization easily.
-This is particularly helpful when the algorithm requires to differentiate through optimization updates (such as meta-learning practices).
+On top of the same optimization function as `torch.optim`, an important benefit of the functional optimizer is that one can implement differentiable optimization easily.
+This is particularly helpful when the algorithm requires differentiation through optimization updates (such as meta-learning practices).
We take as the inputs the gradients and optimizer states, and use non-in-place operators to compute and output the updates.
The processes can be automatically implemented, with the only need from users being to pass the argument `inplace=False` to the functions.
-Check out section [Explicit Gradient (EG)](#explicit-gradient-eg) functional API for example.
+Check out the section [Explicit Gradient](#explicit-gradient-eg) (EG)](#explicit-gradient-eg) functional API for example.
--------------------------------------------------------------------------------
@@ -155,7 +155,7 @@ From the BR-based perspective, existing gradient methods can be categorized into
### Explicit Gradient (EG)
-The idea of explicit gradient is to treat the gradient step as a differentiable function and try to backpropagate through the unrolled optimization path.
+The idea of the explicit gradient is to treat the gradient step as a differentiable function and try to backpropagate through the unrolled optimization path.
This differentiation mode is suitable for algorithms when the inner-level optimization solution is obtained by a few gradient steps, such as [MAML](https://arxiv.org/abs/1703.03400) and [MGRL](https://arxiv.org/abs/1805.09801).
TorchOpt offers both functional and object-oriented API for EG to fit different user applications.
@@ -163,7 +163,7 @@ TorchOpt offers both functional and object-oriented API for EG to fit different
The functional API is to conduct optimization in a functional programming style.
Note that we pass the argument `inplace=False` to the functions to make the optimization differentiable.
-Refer to the tutorial notebook [Functional Optimizer](tutorials/1_Functional_Optimizer.ipynb) for more guidances.
+Refer to the tutorial notebook [Functional Optimizer](tutorials/1_Functional_Optimizer.ipynb) for more guidance.
```python
# Define functional optimizer
@@ -187,8 +187,8 @@ meta_grads = torch.autograd.grad(loss, meta_params)
#### OOP API
-TorchOpt also provides OOP API compatible with PyTorch programming style.
-Refer to the example and the tutorial notebook [Meta-Optimizer](tutorials/3_Meta_Optimizer.ipynb), [Stop Gradient](tutorials/4_Stop_Gradient.ipynb) for more guidances.
+TorchOpt also provides OOP API compatible with the PyTorch programming style.
+Refer to the example and the tutorial notebook [Meta-Optimizer](tutorials/3_Meta_Optimizer.ipynb), [Stop Gradient](tutorials/4_Stop_Gradient.ipynb) for more guidance.
```python
# Define meta and inner parameters
@@ -215,7 +215,7 @@ Refer to the example [iMAML](https://github.com/waterhorse1/torchopt/tree/readme
#### Functional API
-For implicit gradient, users need to define the stationary condition and TorchOpt provides the decorator to wrap the solve function for enabling implicit gradient computation.
+For the implicit gradient, users need to define the stationary condition and TorchOpt provides the decorator to wrap the solve function for enabling implicit gradient computation.
```python
# The stationary condition for the inner-loop
@@ -240,7 +240,7 @@ meta_grads = torch.autograd.grad(loss, meta_params)
#### OOP API
-TorchOpt also offer an OOP API, users need to inherit from the class `torchopt.nn.ImplicitMetaGradientModule` to construct the inner-loop network.
+TorchOpt also offers an OOP API, which users need to inherit from the class `torchopt.nn.ImplicitMetaGradientModule` to construct the inner-loop network.
Users need to define the stationary condition/objective function and the inner-loop solve function to enable implicit gradient computation.
```python
@@ -288,7 +288,7 @@ When the inner-loop process is non-differentiable or one wants to eliminate the
ZD typically gets gradients based on zero-order estimation, such as finite-difference, or [Evolutionary Strategy](https://arxiv.org/abs/1703.03864).
Instead of optimizing the objective $F$, ES optimizes a smoothed objective.
TorchOpt provides both functional and OOP APIs for the ES method.
-Refer to the tutorial notebook [Zero-order Differentiation](tutorials/6_Zero_Order_Differentiation.ipynb) for more guidances.
+Refer to the tutorial notebook [Zero-order Differentiation](tutorials/6_Zero_Order_Differentiation.ipynb) for more guidance.
#### Functional API
@@ -315,7 +315,7 @@ def forward(params, batch, labels):
#### OOP API
-TorchOpt also offer an OOP API, users need to inherit from the class `torchopt.nn.ZeroOrderGradientModule` to construct the network as an `nn.Module` following a classical PyTorch style.
+TorchOpt also offers an OOP API, which users need to inherit from the class `torchopt.nn.ZeroOrderGradientModule` to construct the network as an `nn.Module` following a classical PyTorch style.
Users need to define the forward process zero-order gradient procedures `forward()` and a noise sampling function `sample()`.
```python
@@ -356,8 +356,8 @@ We take the optimizer as a whole instead of separating it into several basic ope
Therefore, by manually writing the forward and backward functions, we can perform the symbolic reduction.
In addition, we can store some intermediate data that can be reused during the backpropagation.
We write the accelerated functions in C++ OpenMP and CUDA, bind them by [`pybind11`](https://github.com/pybind/pybind11) to allow they can be called by Python, and then define the forward and backward behavior using `torch.autograd.Function`.
-Users can use by simply setting the `use_accelerated_op` flag as `True`.
-Refer to the corresponding sections in tutorials [Functional Optimizer](tutorials/1_Functional_Optimizer.ipynb) and [Meta-Optimizer](tutorials/3_Meta_Optimizer.ipynb)
+Users can use it by simply setting the `use_accelerated_op` flag as `True`.
+Refer to the corresponding sections in the tutorials [Functional Optimizer](tutorials/1_Functional_Optimizer.ipynb)](tutorials/1_Functional_Optimizer.ipynb) and [Meta-Optimizer](tutorials/3_Meta_Optimizer.ipynb)
```python
optimizer = torchopt.MetaAdam(model, lr, use_accelerated_op=True)
@@ -366,7 +366,7 @@ optimizer = torchopt.MetaAdam(model, lr, use_accelerated_op=True)
### Distributed Training
`TorchOpt` provides distributed training features based on the PyTorch RPC module for better training speed and multi-node multi-GPU support.
-Different from the MPI-like parallelization paradigm, which uses multiple homogenous workers and requires carefully designed communication hooks, the RPC APIs allow users to build their optimization pipeline more flexibly.
+Different from the MPI-like parallelization paradigm, which uses multiple homogeneous workers and requires carefully designed communication hooks, the RPC APIs allow users to build their optimization pipeline more flexibly.
Experimental results show that we achieve an approximately linear relationship between the speed-up ratio and the number of workers.
Check out the [Distributed Training Documentation](https://torchopt.readthedocs.io/en/latest/distributed/distributed.html) and [distributed MAML example](https://github.com/metaopt/torchopt/tree/main/examples/distributed/few-shot) for more specific guidance.
@@ -381,12 +381,12 @@ For more guidance and comparison results, please refer to our open-source projec
## Visualization
-Complex gradient flow in meta-learning brings in a great challenge for managing the gradient flow and verifying the correctness of it.
+Complex gradient flow in meta-learning brings in a great challenge for managing the gradient flow and verifying its correctness of it.
TorchOpt provides a visualization tool that draws variable (e.g., network parameters or meta-parameters) names on the gradient graph for better analysis.
The visualization tool is modified from [`torchviz`](https://github.com/szagoruyko/pytorchviz).
Refer to the example [visualization code](examples/visualize.py) and the tutorial notebook [Visualization](tutorials/2_Visualization.ipynb) for more details.
-The figure below show the visualization result.
+The figure below shows the visualization result.
Compared with [`torchviz`](https://github.com/szagoruyko/pytorchviz), TorchOpt fuses the operations within the `Adam` together (orange) to reduce the complexity and provide simpler visualization.
@@ -397,7 +397,7 @@ Compared with [`torchviz`](https://github.com/szagoruyko/pytorchviz), TorchOpt f
## Examples
-In the [`examples`](examples) directory, we offer several examples of functional optimizers and light-weight meta-learning examples with TorchOpt.
+In the [`examples`](examples) directory, we offer several examples of functional optimizers and lightweight meta-learning examples with TorchOpt.
- [Model-Agnostic Meta-Learning (MAML) - Supervised Learning](https://arxiv.org/abs/1703.03400) (ICML 2017)
- [Learning to Reweight Examples for Robust Deep Learning](https://arxiv.org/abs/1803.09050) (ICML 2018)
@@ -406,7 +406,7 @@ In the [`examples`](examples) directory, we offer several examples of functional
- [Learning through opponent learning process (LOLA)](https://arxiv.org/abs/1709.04326) (AAMAS 2018)
- [Meta-Learning with Implicit Gradients](https://arxiv.org/abs/1909.04630) (NeurIPS 2019)
-Also check [`examples`](examples) for more distributed/visualization/functorch-compatible examples.
+Also, check [`examples`](examples) for more distributed/visualization/functorch-compatible examples.
--------------------------------------------------------------------------------
diff --git a/codecov.yml b/codecov.yml
index 65b70e6e..e1d3aab2 100644
--- a/codecov.yml
+++ b/codecov.yml
@@ -1,9 +1,12 @@
coverage:
+ precision: 2
round: nearest
status:
project:
default:
+ target: auto
threshold: 0.05%
patch:
default:
+ target: 100%
informational: true
diff --git a/conda-recipe.yaml b/conda-recipe.yaml
index faee0a7c..997f11c5 100644
--- a/conda-recipe.yaml
+++ b/conda-recipe.yaml
@@ -38,7 +38,7 @@ dependencies:
- torchviz
- sphinxcontrib-katex # for documentation
- jax # for tutorials
- - jaxlib >= 0.3=*cuda* # for tutorials
+ - jaxlib # for tutorials
- optax # for tutorials
- jaxopt # for tests
- tensorboard # for examples
@@ -90,9 +90,15 @@ dependencies:
- mypy >= 0.990
- flake8
- flake8-bugbear
- - doc8 < 1.0.0a0
+ - flake8-comprehensions
+ - flake8-docstrings
+ - flake8-pyi
+ - flake8-simplify
+ - ruff
+ - doc8
- pydocstyle
- clang-format >= 14
- clang-tools >= 14 # clang-tidy
- cpplint
- - pre-commit
+ - conda-forge::pre-commit
+ - conda-forge::identify
diff --git a/docs/source/api/api.rst b/docs/source/api/api.rst
index b2866407..d00e2333 100644
--- a/docs/source/api/api.rst
+++ b/docs/source/api/api.rst
@@ -30,10 +30,11 @@ Functional Optimizers
.. autosummary::
FuncOptimizer
+ adagrad
adam
- sgd
- rmsprop
adamw
+ rmsprop
+ sgd
Wrapper for Function Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -41,6 +42,11 @@ Wrapper for Function Optimizer
.. autoclass:: FuncOptimizer
:members:
+Functional AdaGrad Optimizer
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: adagrad
+
Functional Adam Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -51,16 +57,16 @@ Functional AdamW Optimizer
.. autofunction:: adamw
-Functional SGD Optimizer
-~~~~~~~~~~~~~~~~~~~~~~~~
-
-.. autofunction:: sgd
-
Functional RMSProp Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: rmsprop
+Functional SGD Optimizer
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: sgd
+
------
Classic Optimizers
@@ -70,10 +76,16 @@ Classic Optimizers
.. autosummary::
+ AdaGrad
Adam
- SGD
- RMSProp
AdamW
+ RMSProp
+ SGD
+
+Classic AdaGrad Optimizer
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: AdaGrad
Classic Adam Optimizer
~~~~~~~~~~~~~~~~~~~~~~
@@ -85,16 +97,16 @@ Classic AdamW Optimizer
.. autoclass:: AdamW
-Classic SGD Optimizer
-~~~~~~~~~~~~~~~~~~~~~
-
-.. autoclass:: SGD
-
Classic RMSProp Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: RMSProp
+Classic SGD Optimizer
+~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: SGD
+
------
Differentiable Meta-Optimizers
@@ -104,10 +116,16 @@ Differentiable Meta-Optimizers
.. autosummary::
+ MetaAdaGrad
MetaAdam
- MetaSGD
- MetaRMSProp
MetaAdamW
+ MetaRMSProp
+ MetaSGD
+
+Differentiable Meta-AdaGrad Optimizer
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: MetaAdaGrad
Differentiable Meta-Adam Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -119,16 +137,16 @@ Differentiable Meta-AdamW Optimizer
.. autoclass:: MetaAdamW
-Differentiable Meta-SGD Optimizer
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-.. autoclass:: MetaSGD
-
Differentiable Meta-RMSProp Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: MetaRMSProp
+Differentiable Meta-SGD Optimizer
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: MetaSGD
+
------
Implicit Differentiation
diff --git a/docs/source/conf.py b/docs/source/conf.py
index d8233da7..f5d206c7 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -30,6 +30,7 @@
import pathlib
import sys
+import sphinx
import sphinxcontrib.katex as katex
@@ -39,7 +40,7 @@
def get_version() -> str:
sys.path.insert(0, str(PROJECT_ROOT / 'torchopt'))
- import version # noqa
+ import version
return version.__version__
@@ -51,7 +52,7 @@ def get_version() -> str:
else:
class RecursiveForwardRefFilter(logging.Filter):
- def filter(self, record):
+ def filter(self, record: logging.LogRecord) -> bool:
if (
"name 'TensorTree' is not defined" in record.getMessage()
or "name 'OptionalTensorTree' is not defined" in record.getMessage()
@@ -191,7 +192,7 @@ def filter(self, record):
html_logo = '_static/images/logo.png'
-def setup(app):
+def setup(app: sphinx.application.Sphinx) -> None:
app.add_js_file('https://cdn.jsdelivr.net/npm/vega@5.20.2')
app.add_js_file('https://cdn.jsdelivr.net/npm/vega-lite@5.1.0')
app.add_js_file('https://cdn.jsdelivr.net/npm/vega-embed@6.17.0')
diff --git a/docs/source/developer/contributing.rst b/docs/source/developer/contributing.rst
index ee66f560..4e7dd355 100644
--- a/docs/source/developer/contributing.rst
+++ b/docs/source/developer/contributing.rst
@@ -51,7 +51,7 @@ Lint Check
We use several tools to secure code quality, including:
- * PEP8 code style: ``black``, ``isort``, ``pylint``, ``flake8``
+ * Python code style: ``black``, ``isort``, ``pylint``, ``flake8``, ``ruff``
* Type hint check: ``mypy``
* C++ Google-style: ``cpplint``, ``clang-format``, ``clang-tidy``
* License: ``addlicense``
@@ -102,9 +102,9 @@ For example, the following command will build a wheel for Python 3.7:
.. code-block:: bash
- CIBW_BUILD="cp37*manylinux*" python3 -m cibuildwheel --platform=linux --output-dir=wheelhouse --config-file=pyproject.toml
+ CIBW_BUILD="cp38*manylinux*" python3 -m cibuildwheel --platform=linux --output-dir=wheelhouse --config-file=pyproject.toml
-You can change ``cp37*`` to ``cp310*`` to build for Python 3.10. See https://cibuildwheel.readthedocs.io/en/stable/options for more options.
+You can change ``cp38*`` to ``cp310*`` to build for Python 3.10. See https://cibuildwheel.readthedocs.io/en/stable/options for more options.
.. |cibuildwheel| replace:: ``cibuildwheel``
.. _cibuildwheel: https://github.com/pypa/cibuildwheel
diff --git a/docs/source/distributed/distributed.rst b/docs/source/distributed/distributed.rst
index b6f00951..0b1bf536 100644
--- a/docs/source/distributed/distributed.rst
+++ b/docs/source/distributed/distributed.rst
@@ -627,7 +627,7 @@ TorchOpt wraps the distributed autograd context and provides a more convenient i
loss = ... # e.g. remote calls
# Backward pass
- grads = todist.autograd.grads(context_id, loss, model.parameters())
+ grads = todist.autograd.grad(context_id, loss, model.parameters())
or
diff --git a/docs/source/explicit_diff/explicit_diff.rst b/docs/source/explicit_diff/explicit_diff.rst
index 89c38df6..f6b82826 100644
--- a/docs/source/explicit_diff/explicit_diff.rst
+++ b/docs/source/explicit_diff/explicit_diff.rst
@@ -53,10 +53,11 @@ For PyTorch-like API (e.g., ``step()``), we designed a base class :class:`torcho
.. autosummary::
torchopt.MetaOptimizer
+ torchopt.MetaAdaGrad
torchopt.MetaAdam
- torchopt.MetaSGD
- torchopt.MetaRMSProp
torchopt.MetaAdamW
+ torchopt.MetaRMSProp
+ torchopt.MetaSGD
By combining low-level API :class:`torchopt.MetaOptimizer` with the previous functional optimizer, we can achieve high-level API:
diff --git a/docs/source/implicit_diff/implicit_diff.rst b/docs/source/implicit_diff/implicit_diff.rst
index df0927c9..5544c25f 100644
--- a/docs/source/implicit_diff/implicit_diff.rst
+++ b/docs/source/implicit_diff/implicit_diff.rst
@@ -10,7 +10,8 @@ Implicit Differentiation
:width: 80%
:align: center
-Implicit differentiation is the task of differentiating the solution of a minimization problem with respect to its inputs.
+Implicit differentiation is the task of differentiating through the solution of an optimization problem satisfying a mapping function :math:`T` capturing the optimality conditions of the problem.
+The simplest example is to differentiate through the solution of a minimization problem with respect to its inputs.
Namely, given
.. math::
@@ -18,7 +19,50 @@ Namely, given
\boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}) \triangleq \underset{\boldsymbol{\theta}}{\mathop{\operatorname{argmin}}} ~ \mathcal{L}^{\text{in}} (\boldsymbol{\phi},\boldsymbol{\theta}).
By treating the solution :math:`\boldsymbol{\theta}^{\prime}` as an implicit function of :math:`\boldsymbol{\phi}`, the idea of implicit differentiation is to directly get analytical best-response derivatives :math:`\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})` by the implicit function theorem.
-This is suitable for algorithms when the inner-level optimal solution is achieved :math:`\left. \frac{\partial \mathcal{L}^{\text{in}} (\boldsymbol{\phi}, \boldsymbol{\theta})}{\partial \boldsymbol{\theta}} \right\rvert_{\boldsymbol{\theta} = \boldsymbol{\theta}^{\prime}} = 0` (e.g., the function :math:`F` in the figure means the solution is obtained by unrolled gradient steps) or reaches some stationary conditions :math:`F (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}) = 0`, such as `IMAML
`_ and `DEQ `_.
+
+Root Finding
+~~~~~~~~~~~~
+
+This is suitable for algorithms when the inner-level optimality conditions :math:`T` is defined by a root of a function, such as:
+
+.. math::
+
+ T (\boldsymbol{\phi}, \boldsymbol{\theta}) = \frac{ \partial \mathcal{L}^{\text{in}} (\boldsymbol{\phi}, \boldsymbol{\theta})}{\partial \boldsymbol{\theta}}, \qquad T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})) = \left. \frac{ \partial \mathcal{L}^{\text{in}} (\boldsymbol{\phi}, \boldsymbol{\theta})}{\partial \boldsymbol{\theta}} \right\rvert_{\boldsymbol{\theta} = \boldsymbol{\theta}^{\prime}} = \boldsymbol{0}.
+
+In `IMAML `_, the function :math:`F` in the figure means the inner-level optimal solution is obtained by unrolled gradient update:
+
+.. math::
+
+ \boldsymbol{\theta}_{k + 1} = F (\boldsymbol{\phi}, \boldsymbol{\theta}_k) = \boldsymbol{\theta}_k - \alpha \nabla_{\boldsymbol{\theta}_k} \mathcal{L}^{\text{in}} (\boldsymbol{\phi}, \boldsymbol{\theta}_k).
+
+Fixed-point Iteration
+~~~~~~~~~~~~~~~~~~~~~
+
+Sometimes the inner-level optimal solution can also be achieved by fixed point where the optimality :math:`T` takes the form:
+
+.. math::
+
+ \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}) = F (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})) \quad \Longleftrightarrow \quad T (\boldsymbol{\phi}, \boldsymbol{\theta}) = F (\boldsymbol{\phi}, \boldsymbol{\theta}) - \boldsymbol{\theta}, \quad T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})) = \boldsymbol{0}.
+
+In `DEQ `_, the function :math:`F` in the figure means the inner-level optimal solution is obtained by fixed point update:
+
+.. math::
+
+ \boldsymbol{\theta}_{k + 1} = F (\boldsymbol{\phi}, \boldsymbol{\theta}_k).
+
+This can be seen as a particular case of root of function by defining the optimality function as :math:`T (\boldsymbol{\phi}, \boldsymbol{\theta}) = F (\boldsymbol{\phi}, \boldsymbol{\theta}) - \boldsymbol{\theta}`.
+This can be implemented with:
+
+.. code-block:: python
+
+ def fixed_point_function(phi: TensorTree, theta: TensorTree) -> TensorTree:
+ ...
+ return new_theta
+
+ # A root function can be derived from the fixed point function
+ def root_function(phi: TensorTree, theta: TensorTree) -> TensorTree:
+ new_theta = fixed_point_function(phi, theta)
+ return torchopt.pytree.tree_sub(new_theta, theta)
Custom Solvers
--------------
@@ -27,8 +71,29 @@ Custom Solvers
torchopt.diff.implicit.custom_root
-TorchOpt provides the :func:`custom_root` decorators, for easily adding implicit differentiation on top of any existing solver (also called forward optimization).
-:func:`custom_root` requires users to define the stationary conditions for the problem solution (e.g., KKT conditions) and will automatically calculate the gradient for backward gradient computation.
+Let :math:`T (\boldsymbol{\phi}, \boldsymbol{\theta}): \mathbb{R}^n \times \mathbb{R}^d \to \mathbb{R}^d` be a user-provided mapping function, that captures the optimality conditions of a problem.
+An optimal solution, denoted :math:`\boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})`, should be a root of :math:`T`:
+
+.. math::
+
+ T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi})) = \boldsymbol{0}.
+
+We can see :math:`\boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})` as an implicitly defined function of :math:`\boldsymbol{\phi} \in \mathbb{R}^n`, i.e., :math:`\boldsymbol{\theta}^{\prime}: \mathbb{R}^n \rightarrow \mathbb{R}^d`.
+More precisely, from the `implicit function theorem `_, we know that for :math:`(\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0)` satisfying :math:`T (\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0) = \boldsymbol{0}` with a continuously differentiable :math:`T`, if the Jacobian :math:`\nabla_{\boldsymbol{\theta}^{\prime}} T` evaluated at :math:`(\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0)` is a square invertible matrix, then there exists a function :math:`\boldsymbol{\theta}^{\prime} (\cdot)` defined on a neighborhood of :math:`\boldsymbol{\phi}_0` such that :math:`\boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}_0) = \boldsymbol{\theta}^{\prime}_0`.
+Furthermore, for all :math:`\boldsymbol{\phi}` in this neighborhood, we have that :math:`T (\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0) = \boldsymbol{0}` and :math:`\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})` exists. Using the chain rule, the Jacobian :math:`\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi})` satisfies:
+
+.. math::
+
+ \frac{d T}{d \boldsymbol{\phi}} = \underbrace{\nabla_{\boldsymbol{\theta}^{\prime}} T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi}))}_{\frac{\partial T}{\partial \boldsymbol{\theta}^{\prime}}} \underbrace{\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})}_{\frac{d \boldsymbol{\theta}^{\prime}}{d \boldsymbol{\phi}}} + \underbrace{\nabla_{\boldsymbol{\phi}} T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}))}_{\frac{\partial T}{\partial \boldsymbol{\phi}}} = \boldsymbol{0}. \qquad ( T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}) = \boldsymbol{0} = \text{const})
+
+Computing :math:`\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi})` therefore boils down to the resolution of the linear system of equations
+
+.. math::
+
+ \underbrace{\nabla_{\boldsymbol{\theta}^{\prime}} T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi}))}_{A \in \mathbb{R}^{d \times d}} \underbrace{\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})}_{J \in \mathbb{R}^{d \times n}} = \underbrace{- \nabla_{\boldsymbol{\phi}} T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}))}_{B \in \mathbb{R}^{d \times n}}.
+
+TorchOpt provides a decorator function :func:`custom_root`, for easily adding implicit differentiation on top of any existing inner optimization solver (also called forward optimization).
+The :func:`custom_root` decorator requires users to define the stationary conditions for the problem solution (e.g., `KKT conditions `_) and will automatically calculate the gradient for backward gradient computation.
Here is an example of the :func:`custom_root` decorators, which is also the **functional API** for implicit gradient.
@@ -137,20 +202,23 @@ Here is an example of the linear solver.
.. code-block:: python
+ import torch
from torchopt import linear_solve
- torch.random.seed(42)
- A = torch.random.randn(3, 3)
- b = torch.random.randn(3)
+ torch.manual_seed(42)
+ A = torch.randn(3, 3)
+ b = torch.randn(3)
- def matvec_A(x):
- return torch.dot(A, x)
+ def matvec(x):
+ return torch.matmul(A, x)
- sol = linear_solve.solve_normal_cg(matvec_A, b, tol=1e-5)
- print(sol)
+ solve_fn = linear_solve.solve_normal_cg(atol=1e-5)
+ solution = solve_fn(matvec, b)
+ print(solution)
- sol = linear_solve.solve_cg(matvec_A, b, tol=1e-5)
- print(sol)
+ solve_fn = linear_solve.solve_cg(atol=1e-5)
+ solution = solve_fn(matvec, b)
+ print(solution)
Users can also select the corresponding solver in functional and OOP APIs.
diff --git a/docs/source/optimizer/optim.rst b/docs/source/optimizer/optim.rst
index 850bc8c7..54c8ef71 100644
--- a/docs/source/optimizer/optim.rst
+++ b/docs/source/optimizer/optim.rst
@@ -18,10 +18,11 @@ Currently, TorchOpt supports 4 functional optimizers: :func:`sgd`, :func:`adam`,
.. autosummary::
torchopt.FuncOptimizer
+ torchopt.adagrad
torchopt.adam
- torchopt.sgd
- torchopt.rmsprop
torchopt.adamw
+ torchopt.rmsprop
+ torchopt.sgd
Apply Parameter Updates
-----------------------
@@ -84,10 +85,12 @@ We offer original PyTorch APIs (e.g., ``zero_grad()`` or ``step()``) for traditi
.. autosummary::
torchopt.Optimizer
+ torchopt.AdaGrad
torchopt.Adam
- torchopt.SGD
- torchopt.RMSProp
torchopt.AdamW
+ torchopt.RMSProp
+ torchopt.SGD
+
By combining low-level API :class:`torchopt.Optimizer` with the previous functional optimizer, we can achieve high-level API:
diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt
index aac17046..49fdbb69 100644
--- a/docs/source/spelling_wordlist.txt
+++ b/docs/source/spelling_wordlist.txt
@@ -172,3 +172,6 @@ abc
ABCMeta
subclasscheck
ctx
+Duchi
+invertible
+AdaGrad
diff --git a/docs/source/zero_order_diff/zero_order_diff.rst b/docs/source/zero_order_diff/zero_order_diff.rst
index 11232c85..4cc7a034 100644
--- a/docs/source/zero_order_diff/zero_order_diff.rst
+++ b/docs/source/zero_order_diff/zero_order_diff.rst
@@ -10,7 +10,7 @@ Evolutionary Strategy
:width: 80%
:align: center
-When the inner-loop process is non-differentiable or one wants to eliminate the heavy computation burdens in the previous two modes (brought by Hessian), one can choose Zeroth-order differentiation。
+When the inner-loop process is non-differentiable or one wants to eliminate the heavy computation burdens in the previous two modes (brought by Hessian), one can choose Zeroth-order differentiation.
Zero-order differentiation typically gets gradients based on zero-order estimation, such as finite-difference, or `Evolutionary Strategy `_ (ES).
`ES-MAML `_ and `NAC `_ successfully solve the non-differentiable optimization problem based on ES.
diff --git a/examples/FuncTorch/maml_omniglot_vmap.py b/examples/FuncTorch/maml_omniglot_vmap.py
index 0933b44d..e1cfe95e 100644
--- a/examples/FuncTorch/maml_omniglot_vmap.py
+++ b/examples/FuncTorch/maml_omniglot_vmap.py
@@ -79,7 +79,10 @@ def main():
argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15)
argparser.add_argument('--device', type=str, help='device', default='cuda')
argparser.add_argument(
- '--task_num', type=int, help='meta batch size, namely task num', default=32
+ '--task_num',
+ type=int,
+ help='meta batch size, namely task num',
+ default=32,
)
argparser.add_argument('--seed', type=int, help='random seed', default=1)
args = argparser.parse_args()
@@ -199,7 +202,7 @@ def train(db, net, device, meta_opt, epoch, log):
if batch_idx % 4 == 0:
print(
- f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
+ f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}',
)
log.append(
{
@@ -208,7 +211,7 @@ def train(db, net, device, meta_opt, epoch, log):
'acc': qry_accs,
'mode': 'train',
'time': time.time(),
- }
+ },
)
@@ -224,7 +227,7 @@ def test(db, net, device, epoch, log):
qry_losses = []
qry_accs = []
- for batch_idx in range(n_test_iter):
+ for _ in range(n_test_iter):
x_spt, y_spt, x_qry, y_qry = db.next('test')
task_num, setsz, c_, h, w = x_spt.size()
@@ -257,7 +260,7 @@ def test(db, net, device, epoch, log):
'acc': qry_accs,
'mode': 'test',
'time': time.time(),
- }
+ },
)
diff --git a/examples/FuncTorch/parallel_train_torchopt.py b/examples/FuncTorch/parallel_train_torchopt.py
index 640763cb..f28bded7 100644
--- a/examples/FuncTorch/parallel_train_torchopt.py
+++ b/examples/FuncTorch/parallel_train_torchopt.py
@@ -15,8 +15,6 @@
import argparse
import math
-from collections import namedtuple
-from typing import Any, NamedTuple
import functorch
import torch
@@ -137,7 +135,9 @@ def test_parallel_train_step_fn(self, num_models):
weights, opt_state = parallel_init_fn(torch.ones(num_models, 1))
for i in range(2000):
loss, (weights, opt_states) = parallel_train_step_fn(
- (weights, opt_state), points, labels
+ (weights, opt_state),
+ points,
+ labels,
)
if i % 200 == 0:
print(loss)
@@ -188,7 +188,9 @@ def test_parallel_train_step_fn(self, num_models):
optimizer = torchopt.adam(lr=0.2)
opt_state = optimizer.init(weights)
functorch_original = ParallelTrainFunctorchTorchOpt(
- loss_fn=loss_fn, optimizer=optimizer, device=DEVICE
+ loss_fn=loss_fn,
+ optimizer=optimizer,
+ device=DEVICE,
)
# Step 4: Let's verify this actually trains.
# We should see the loss decrease.
@@ -201,7 +203,7 @@ def test_parallel_train_step_fn(self, num_models):
# Step 7: Now, the flaw with step 6 is that we were training on the same exact
# data. This can lead to all of the models in the ensemble overfitting in the
# same way. The solution that http://willwhitney.com/parallel-training-jax.html
- # applies is to randomly subset the data in a way that the models do not recieve
+ # applies is to randomly subset the data in a way that the models do not receive
# exactly the same data in each training step!
# Because the goal of this doc is to show that we can use eager-mode vmap to
# achieve similar things as JAX, the rest of this is left as an exercise to the reader.
diff --git a/examples/L2R/helpers/model.py b/examples/L2R/helpers/model.py
index 80fae8ac..dbde0e8d 100644
--- a/examples/L2R/helpers/model.py
+++ b/examples/L2R/helpers/model.py
@@ -35,7 +35,7 @@
class LeNet5(nn.Module):
def __init__(self, args):
- super(LeNet5, self).__init__()
+ super().__init__()
self.model = nn.Sequential(
nn.Conv2d(1, 16, 5),
nn.ReLU(),
@@ -51,7 +51,7 @@ def __init__(self, args):
)
self.args = args
self.meta_weights = torch.zeros(self.args.batch_size, requires_grad=True).to(
- self.args.device
+ self.args.device,
)
self.criterion = nn.BCELoss()
diff --git a/examples/L2R/helpers/utils.py b/examples/L2R/helpers/utils.py
index fe923860..7e95ca6f 100644
--- a/examples/L2R/helpers/utils.py
+++ b/examples/L2R/helpers/utils.py
@@ -89,16 +89,10 @@ def get_imbalance_dataset(
y_val_subset = np.concatenate([np.zeros([x_val_0.shape[0]]), np.ones([x_val_1.shape[0]])])
y_test_subset = np.concatenate([np.zeros([x_test_0.shape[0]]), np.ones([x_test_1.shape[0]])])
- y_train_pos_subset = np.ones([x_train_1.shape[0]])
- y_train_neg_subset = np.zeros([x_train_0.shape[0]])
-
x_train_subset = np.concatenate([x_train_0, x_train_1], axis=0)[:, None, :, :]
x_val_subset = np.concatenate([x_val_0, x_val_1], axis=0)[:, None, :, :]
x_test_subset = np.concatenate([x_test_0, x_test_1], axis=0)[:, None, :, :]
- x_train_pos_subset = x_train_1[:, None, :, :]
- x_train_neg_subset = x_train_0[:, None, :, :]
-
# Final shuffle.
idx = np.arange(x_train_subset.shape[0])
np.random.shuffle(idx)
@@ -146,7 +140,7 @@ def set_seed(seed, cudnn=True):
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
# note: the below slows down the code but makes it reproducible
- # Sets the seed for generating random numbers on all GPUs. It’s safe to
+ # Sets the seed for generating random numbers on all GPUs. It's safe to
# call this function if CUDA is not available; in that case, it is
# silently ignored.
torch.cuda.manual_seed_all(seed)
@@ -157,7 +151,6 @@ def set_seed(seed, cudnn=True):
def plot(baseline, l2r):
import matplotlib.pyplot as plt
- import numpy as np
import seaborn as sns
sns.set(style='darkgrid')
diff --git a/examples/L2R/l2r.py b/examples/L2R/l2r.py
index 5ce4839d..64990976 100644
--- a/examples/L2R/l2r.py
+++ b/examples/L2R/l2r.py
@@ -51,14 +51,13 @@ def run_baseline(args, mnist_train, mnist_test):
ntest = args.ntest
epoch = args.epoch
- folder = './result/baseline/'
writer = SummaryWriter('./result/baseline')
with open('./result/baseline/config.json', 'w') as f:
json.dump(args.__dict__, f)
args.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
- train_set, val_set, test_set = get_imbalance_dataset(
+ train_set, _, test_set = get_imbalance_dataset(
mnist_train,
mnist_test,
pos_ratio=pos_ratio,
@@ -67,7 +66,6 @@ def run_baseline(args, mnist_train, mnist_test):
ntest=ntest,
)
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4)
- valid_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=True, num_workers=1)
test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=True, num_workers=1)
model = LeNet5(args).to(args.device)
@@ -91,7 +89,7 @@ def run_baseline(args, mnist_train, mnist_test):
if step % 10 == 0 and step > 0:
running_train_mean = np.mean(np.array(running_train_loss))
- print('EPOCH: {}, BATCH: {}, LOSS: {}'.format(_epoch, idx, running_train_mean))
+ print(f'EPOCH: {_epoch}, BATCH: {idx}, LOSS: {running_train_mean}')
writer.add_scalar('running_train_loss', running_train_mean, step)
running_train_loss = []
@@ -106,7 +104,7 @@ def run_baseline(args, mnist_train, mnist_test):
writer.add_scalar('train_acc', train_acc, _epoch)
writer.add_scalar('test_acc', test_acc, _epoch)
test_acc_result.append(test_acc)
- print('EPOCH: {}, TRAIN_ACC: {}, TEST_ACC: {}'.format(_epoch, train_acc, test_acc))
+ print(f'EPOCH: {_epoch}, TRAIN_ACC: {train_acc}, TEST_ACC: {test_acc}')
return test_acc_result
@@ -120,7 +118,6 @@ def run_L2R(args, mnist_train, mnist_test):
ntest = args.ntest
epoch = args.epoch
- folder = './result/l2r/'
writer = SummaryWriter('./result/l2r/log')
with open('./result/l2r/config.json', 'w') as f:
json.dump(args.__dict__, f)
@@ -143,7 +140,6 @@ def run_L2R(args, mnist_train, mnist_test):
real_model_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
step = 0
- time_bp = 0
running_valid_loss = []
valid = iter(valid_loader)
running_train_loss = []
@@ -173,7 +169,7 @@ def run_L2R(args, mnist_train, mnist_test):
inner_loss = model.inner_loss(train_x, train_label)
model_optimizer.step(inner_loss)
- # caclulate outer_loss, deirve meta-gradient and normalise
+ # calculate outer_loss, derive meta-gradient and normalize
outer_loss = model.outer_loss(valid_x, valid_label)
model.meta_weights = -torch.autograd.grad(outer_loss, model.meta_weights)[0]
model.meta_weights = torch.nn.ReLU()(model.meta_weights)
@@ -203,8 +199,11 @@ def run_L2R(args, mnist_train, mnist_test):
running_train_mean = np.mean(np.array(running_train_loss))
print(
'EPOCH: {}, BATCH: {}, WEIGHTED_TRAIN_LOSS: {}, VALID_LOSS: {}'.format(
- _epoch, idx, running_train_mean, running_valid_mean
- )
+ _epoch,
+ idx,
+ running_train_mean,
+ running_valid_mean,
+ ),
)
running_valid_loss = []
running_train_loss = []
@@ -222,7 +221,7 @@ def run_L2R(args, mnist_train, mnist_test):
writer.add_scalar('train_acc', train_acc, _epoch)
writer.add_scalar('test_acc', test_acc, _epoch)
test_acc_result.append(test_acc)
- print('EPOCH: {}, TRAIN_ACC: {}, TEST_ACC: {}'.format(_epoch, train_acc, test_acc))
+ print(f'EPOCH: {_epoch}, TRAIN_ACC: {train_acc}, TEST_ACC: {test_acc}')
return test_acc_result
diff --git a/examples/LOLA/helpers/env.py b/examples/LOLA/helpers/env.py
index f1ef6e6f..f496276e 100644
--- a/examples/LOLA/helpers/env.py
+++ b/examples/LOLA/helpers/env.py
@@ -54,7 +54,7 @@ def __eq__(self, other):
class IPD(gym.Env):
"""
A two-agent vectorized environment.
- Possible actions for each agent are (C)ooperate and (D)efect.
+ Possible actions for each agent are Cooperate (C) and Defect (D).
"""
# Possible actions
diff --git a/examples/LOLA/helpers/utils.py b/examples/LOLA/helpers/utils.py
index afa9e609..20f67be5 100644
--- a/examples/LOLA/helpers/utils.py
+++ b/examples/LOLA/helpers/utils.py
@@ -27,7 +27,7 @@ def step(ipd, theta1, theta2, values1, values2, args):
(s1, s2), _ = ipd.reset()
score1 = 0
score2 = 0
- for t in range(args.len_rollout):
+ for _ in range(args.len_rollout):
a1, lp1, v1 = act(s1, theta1, values1)
a2, lp2, v2 = act(s2, theta2, values2)
(s1, s2), (r1, r2), _, _ = ipd.step((a1, a2))
@@ -82,7 +82,7 @@ def dice_objective(self, use_baseline=True):
if use_baseline:
# variance_reduction:
baseline_term = torch.mean(
- torch.sum((1 - magic_box(stochastic_nodes)) * discounted_values, dim=1)
+ torch.sum((1 - magic_box(stochastic_nodes)) * discounted_values, dim=1),
)
dice_objective = dice_objective + baseline_term
@@ -109,7 +109,7 @@ def sample(ipd, policy, value, args):
(s1, s2), _ = ipd.reset()
memory_agent1 = Memory(args)
memory_agent2 = Memory(args)
- for t in range(args.len_rollout):
+ for _ in range(args.len_rollout):
a1, lp1, v1 = act(s1, theta1, value1)
a2, lp2, v2 = act(s2, theta2, value2)
(s1, s2), (r1, r2), _, _ = ipd.step((a1, a2))
diff --git a/examples/LOLA/lola_dice.py b/examples/LOLA/lola_dice.py
index 20c0ff0e..6dbaaf24 100644
--- a/examples/LOLA/lola_dice.py
+++ b/examples/LOLA/lola_dice.py
@@ -96,17 +96,16 @@ def main(args):
score = step(ipd, agent1.theta, agent2.theta, agent1.values, agent2.values, args)
joint_scores.append(0.5 * (score[0] + score[1]))
- # print
if update % 10 == 0:
p1 = [p.item() for p in torch.sigmoid(agent1.theta)]
p2 = [p.item() for p in torch.sigmoid(agent2.theta)]
print(
'update',
update,
- 'score (%.3f,%.3f)' % (score[0], score[1]),
+ f'score ({score[0]:.3f},{score[1]:.3f})',
'policy (agent1) = {%.3f, %.3f, %.3f, %.3f, %.3f}'
% (p1[0], p1[1], p1[2], p1[3], p1[4]),
- ' (agent2) = {%.3f, %.3f, %.3f, %.3f, %.3f}' % (p2[0], p2[1], p2[2], p2[3], p2[4]),
+ f' (agent2) = {{{p2[0]:.3f}, {p2[1]:.3f}, {p2[2]:.3f}, {p2[3]:.3f}, {p2[4]:.3f}}}',
)
return joint_scores
@@ -114,7 +113,7 @@ def main(args):
if __name__ == '__main__':
args = parse_args()
- joint_score = dict()
+ joint_score = {}
for nla in range(3):
args.n_lookaheads = nla
joint_score[nla] = main(args)
diff --git a/examples/MAML-RL/func_maml.py b/examples/MAML-RL/func_maml.py
index 2534caeb..f3a00642 100644
--- a/examples/MAML-RL/func_maml.py
+++ b/examples/MAML-RL/func_maml.py
@@ -103,9 +103,10 @@ def evaluate(env, seed, task_num, fpolicy, params):
inner_opt = torchopt.MetaSGD(lr=0.5)
env = gym.make(
'TabularMDP-v0',
- **dict(
- num_states=STATE_DIM, num_actions=ACTION_DIM, max_episode_steps=TRAJ_LEN, seed=args.seed
- ),
+ num_states=STATE_DIM,
+ num_actions=ACTION_DIM,
+ max_episode_steps=TRAJ_LEN,
+ seed=args.seed,
)
tasks = env.sample_tasks(num_tasks=task_num)
@@ -131,9 +132,10 @@ def main(args):
# Env
env = gym.make(
'TabularMDP-v0',
- **dict(
- num_states=STATE_DIM, num_actions=ACTION_DIM, max_episode_steps=TRAJ_LEN, seed=args.seed
- ),
+ num_states=STATE_DIM,
+ num_actions=ACTION_DIM,
+ max_episode_steps=TRAJ_LEN,
+ seed=args.seed,
)
# Policy
policy = CategoricalMLPPolicy(input_size=STATE_DIM, output_size=ACTION_DIM)
@@ -171,7 +173,11 @@ def main(args):
outer_opt.step()
test_pre_reward_ls, test_post_reward_ls = evaluate(
- env, args.seed, TASK_NUM, fpolicy, params
+ env,
+ args.seed,
+ TASK_NUM,
+ fpolicy,
+ params,
)
train_pre_reward.append(sum(train_pre_reward_ls) / TASK_NUM)
@@ -188,7 +194,7 @@ def main(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser(
- description='Reinforcement learning with Model-Agnostic Meta-Learning (MAML) - Train'
+ description='Reinforcement learning with Model-Agnostic Meta-Learning (MAML) - Train',
)
parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
args = parser.parse_args()
diff --git a/examples/MAML-RL/helpers/policy_torchrl.py b/examples/MAML-RL/helpers/policy_torchrl.py
index 103a4ec5..91bdb269 100644
--- a/examples/MAML-RL/helpers/policy_torchrl.py
+++ b/examples/MAML-RL/helpers/policy_torchrl.py
@@ -13,9 +13,7 @@
# limitations under the License.
# ==============================================================================
-import torch
import torch.nn as nn
-from torch.distributions import Categorical
from torchrl.modules import (
ActorValueOperator,
OneHotCategorical,
diff --git a/examples/MAML-RL/helpers/tabular_mdp.py b/examples/MAML-RL/helpers/tabular_mdp.py
index 3a6bee60..f8feb7b7 100644
--- a/examples/MAML-RL/helpers/tabular_mdp.py
+++ b/examples/MAML-RL/helpers/tabular_mdp.py
@@ -49,7 +49,10 @@ def __init__(self, num_states, num_actions, max_episode_steps, seed, task=None):
self.action_space = spaces.Discrete(num_actions)
self.observation_space = spaces.Box(
- low=0.0, high=1.0, shape=(num_states,), dtype=np.float32
+ low=0.0,
+ high=1.0,
+ shape=(num_states,),
+ dtype=np.float32,
)
self._task = task
@@ -62,7 +65,8 @@ def __init__(self, num_states, num_actions, max_episode_steps, seed, task=None):
),
)
self._rewards_mean = task.get(
- 'rewards_mean', np.zeros((num_states, num_actions), dtype=np.float32)
+ 'rewards_mean',
+ np.zeros((num_states, num_actions), dtype=np.float32),
)
self._state = 0
self._elapsed_steps = None
@@ -79,7 +83,9 @@ def sample_tasks(self, num_tasks):
size=(num_tasks, self.num_states, self.num_actions),
)
rewards_mean = self.np_random.normal(
- 1.0, 1.0, size=(num_tasks, self.num_states, self.num_actions)
+ 1.0,
+ 1.0,
+ size=(num_tasks, self.num_states, self.num_actions),
)
tasks = [
{'transitions': transition, 'rewards_mean': reward_mean}
@@ -93,7 +99,6 @@ def reset_task(self, task):
self._rewards_mean = task['rewards_mean']
def reset(self):
- # From [1]: "an episode always starts on the first state"
self._state = 0
observation = np.zeros(self.num_states, dtype=np.float32)
observation[self._state] = 1.0
@@ -107,13 +112,11 @@ def step(self, action):
reward = self.np_random.normal(mean, 1.0)
self._state = self.np_random.choice(
- self.num_states, p=self._transitions[self._state, action]
+ self.num_states,
+ p=self._transitions[self._state, action],
)
observation = np.zeros(self.num_states, dtype=np.float32)
observation[self._state] = 1.0
self._elapsed_steps += 1
- if self._elapsed_steps >= self.max_episode_steps:
- done = True
- else:
- done = False
+ done = self._elapsed_steps >= self.max_episode_steps
return observation, reward, done, {'task': self._task}
diff --git a/examples/MAML-RL/maml.py b/examples/MAML-RL/maml.py
index d4aa8c3c..42fddbac 100644
--- a/examples/MAML-RL/maml.py
+++ b/examples/MAML-RL/maml.py
@@ -108,12 +108,10 @@ def evaluate(env, seed, task_num, policy):
inner_opt = torchopt.MetaSGD(policy, lr=0.1)
env = gym.make(
'TabularMDP-v0',
- **dict(
- num_states=STATE_DIM,
- num_actions=ACTION_DIM,
- max_episode_steps=TRAJ_LEN,
- seed=args.seed,
- ),
+ num_states=STATE_DIM,
+ num_actions=ACTION_DIM,
+ max_episode_steps=TRAJ_LEN,
+ seed=args.seed,
)
tasks = env.sample_tasks(num_tasks=task_num)
policy_state_dict = torchopt.extract_state_dict(policy)
@@ -141,12 +139,10 @@ def main(args):
# Env
env = gym.make(
'TabularMDP-v0',
- **dict(
- num_states=STATE_DIM,
- num_actions=ACTION_DIM,
- max_episode_steps=TRAJ_LEN,
- seed=args.seed,
- ),
+ num_states=STATE_DIM,
+ num_actions=ACTION_DIM,
+ max_episode_steps=TRAJ_LEN,
+ seed=args.seed,
)
# Policy
policy = CategoricalMLPPolicy(input_size=STATE_DIM, output_size=ACTION_DIM)
@@ -197,7 +193,7 @@ def main(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser(
- description='Reinforcement learning with Model-Agnostic Meta-Learning (MAML) - Train'
+ description='Reinforcement learning with Model-Agnostic Meta-Learning (MAML) - Train',
)
parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
args = parser.parse_args()
diff --git a/examples/MAML-RL/maml_torchrl.py b/examples/MAML-RL/maml_torchrl.py
index 3cb72b49..225f73bc 100644
--- a/examples/MAML-RL/maml_torchrl.py
+++ b/examples/MAML-RL/maml_torchrl.py
@@ -14,9 +14,7 @@
# ==============================================================================
import argparse
-import time
-import numpy as np
import torch
import torch.optim as optim
import tqdm
@@ -60,8 +58,6 @@ def a2c_loss(traj, policy_module, value_module, value_coef):
next_traj = step_tensordict(traj)
next_value = value_module(next_traj).get('state_value').detach()
- # tderror = TDEstimate(GAMMA, value_module, gradient_mode=True)
- # tderror = TDLambdaEstimate(GAMMA, LAMBDA, value_module, gradient_mode=True)
advantage = td_lambda_advantage_estimate(GAMMA, LAMBDA, value, next_value, reward, done)
action_loss = -(advantage.detach() * log_probs.view_as(advantage)).mean()
value_error = advantage
@@ -131,14 +127,17 @@ def main(args):
# init training
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
+
# Env
- lambda_env = lambda: GymEnv(
- 'TabularMDP-v0',
- num_states=STATE_DIM,
- num_actions=ACTION_DIM,
- max_episode_steps=TRAJ_LEN,
- device=device,
- )
+ def lambda_env():
+ return GymEnv(
+ 'TabularMDP-v0',
+ num_states=STATE_DIM,
+ num_actions=ACTION_DIM,
+ max_episode_steps=TRAJ_LEN,
+ device=device,
+ )
+
if args.parallel:
env = ParallelEnv(
NUM_ENVS,
@@ -171,8 +170,7 @@ def main(args):
dummy_env.set_seed(args.seed)
pbar = tqdm.tqdm(range(outer_iters))
- for i in pbar:
- # print("i: ", i)
+ for _ in pbar:
tasks = dummy_env.sample_tasks(num_tasks=TASK_NUM)
train_pre_reward_ls = []
train_post_reward_ls = []
@@ -184,7 +182,7 @@ def main(args):
env.reset_task(tasks[idx])
policy_module = actor_critic_module.get_policy_operator()
value_module = actor_critic_module.get_value_operator()
- for k in range(inner_iters):
+ for __ in range(inner_iters):
with set_exploration_mode('random'), torch.no_grad():
pre_traj_td = (
env.rollout(
@@ -236,7 +234,7 @@ def main(args):
f'train_pre_reward: {train_pre_reward[-1]: 4.4f}, '
f'train_post_reward: {train_post_reward[-1]: 4.4f}, '
f'test_pre_reward: {test_pre_reward[-1]: 4.4f}, '
- f'test_post_reward: {test_post_reward[-1]: 4.4f}, '
+ f'test_post_reward: {test_post_reward[-1]: 4.4f}, ',
)
env.close()
@@ -244,7 +242,7 @@ def main(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser(
- description='Reinforcement learning with Model-Agnostic Meta-Learning (MAML) - Train'
+ description='Reinforcement learning with Model-Agnostic Meta-Learning (MAML) - Train',
)
parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
parser.add_argument('--parallel', action='store_true', help='run envs in parallel')
diff --git a/examples/MGRL/mgrl.py b/examples/MGRL/mgrl.py
index 152e4177..49eb79c4 100644
--- a/examples/MGRL/mgrl.py
+++ b/examples/MGRL/mgrl.py
@@ -55,7 +55,7 @@ def forward(self, x):
meta_optimizer = torchopt.SGD([gamma], lr=5e-1)
net_state = torchopt.extract_state_dict(net)
for i in range(outer_iters):
- for j in range(inner_iters):
+ for _ in range(inner_iters):
trajectory, state = Rollout.get()
backup = Rollout.rollout(trajectory, torch.sigmoid(gamma))
pred_value = net(state.float())
diff --git a/examples/distributed/few-shot/helpers/omniglot_loaders.py b/examples/distributed/few-shot/helpers/omniglot_loaders.py
index e8f02042..52fab28a 100644
--- a/examples/distributed/few-shot/helpers/omniglot_loaders.py
+++ b/examples/distributed/few-shot/helpers/omniglot_loaders.py
@@ -80,7 +80,7 @@ def __len__(self):
def _check_exists(self):
return os.path.exists(
- os.path.join(self.root, self.processed_folder, 'images_evaluation')
+ os.path.join(self.root, self.processed_folder, 'images_evaluation'),
) and os.path.exists(os.path.join(self.root, self.processed_folder, 'images_background'))
def download(self):
@@ -118,7 +118,7 @@ def download(self):
def find_classes(root_dir):
retour = []
- for root, dirs, files in os.walk(root_dir):
+ for root, _, files in os.walk(root_dir):
for f in files:
if f.endswith('png'):
r = root.split('/')
@@ -164,14 +164,14 @@ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=Non
lambda x: np.reshape(x, (imgsz, imgsz, 1)),
lambda x: np.transpose(x, [2, 0, 1]),
lambda x: x / 255.0,
- ]
+ ],
),
)
# {label: [img1, img2..., img20], label2: [img1, img2, ...], ... 1623 labels in total}
temp = {}
for img, label in self.x:
- if label in temp.keys():
+ if label in temp:
temp[label].append(img)
else:
temp[label] = [img]
@@ -196,12 +196,9 @@ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=Non
self.x = np.load(os.path.join(root, 'omniglot.npy'))
print('load from omniglot.npy.')
- # [1623, 20, 84, 84, 1]
- # TODO: can not shuffle here, we must keep training and test set distinct!
+ # NOTE: do not shuffle here, we must keep training and test set distinct!
self.x_train, self.x_test = self.x[:1200], self.x[1200:]
- # self.normalization()
-
self.batchsz = batchsz
self.n_cls = self.x.shape[0] # 1623
self.n_way = n_way # n way
@@ -230,7 +227,6 @@ def normalization(self):
self.std = np.std(self.x_train)
self.max = np.max(self.x_train)
self.min = np.min(self.x_train)
- # print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
self.x_train = (self.x_train - self.mean) / self.std
self.x_test = (self.x_test - self.mean) / self.std
@@ -239,8 +235,6 @@ def normalization(self):
self.max = np.max(self.x_train)
self.min = np.min(self.x_train)
- # print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
-
def load_data_cache(self, data_pack):
"""
Collects several batches data for N-shot learning
@@ -253,10 +247,9 @@ def load_data_cache(self, data_pack):
querysz = self.k_query * self.n_way
data_cache = []
- # print('preload next 50 caches of batchsz of batch.')
- for sample in range(10): # num of episodes
+ for _sample in range(10): # num of episodes
x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
- for i in range(self.batchsz): # one batch means one set
+ for _ in range(self.batchsz): # one batch means one set
x_spt, y_spt, x_qry, y_qry = [], [], [], []
selected_cls = self.rng.choice(data_pack.shape[0], self.n_way, False)
@@ -272,12 +265,18 @@ def load_data_cache(self, data_pack):
# shuffle inside a batch
perm = self.rng.permutation(self.n_way * self.k_shot)
x_spt = np.array(x_spt).reshape(
- self.n_way * self.k_shot, 1, self.resize, self.resize
+ self.n_way * self.k_shot,
+ 1,
+ self.resize,
+ self.resize,
)[perm]
y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm]
perm = self.rng.permutation(self.n_way * self.k_query)
x_qry = np.array(x_qry).reshape(
- self.n_way * self.k_query, 1, self.resize, self.resize
+ self.n_way * self.k_query,
+ 1,
+ self.resize,
+ self.resize,
)[perm]
y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm]
@@ -287,20 +286,29 @@ def load_data_cache(self, data_pack):
x_qrys.append(x_qry)
y_qrys.append(y_qry)
- # [b, setsz, 1, 84, 84]
x_spts = np.array(x_spts, dtype=np.float32).reshape(
- self.batchsz, setsz, 1, self.resize, self.resize
- )
- y_spts = np.array(y_spts, dtype=np.int).reshape(self.batchsz, setsz)
- # [b, qrysz, 1, 84, 84]
+ self.batchsz,
+ setsz,
+ 1,
+ self.resize,
+ self.resize,
+ ) # [b, setsz, 1, 84, 84]
+ y_spts = np.array(y_spts, dtype=np.int).reshape(
+ self.batchsz,
+ setsz,
+ ) # [b, qrysz, 1, 84, 84]
x_qrys = np.array(x_qrys, dtype=np.float32).reshape(
- self.batchsz, querysz, 1, self.resize, self.resize
+ self.batchsz,
+ querysz,
+ 1,
+ self.resize,
+ self.resize,
)
y_qrys = np.array(y_qrys, dtype=np.int).reshape(self.batchsz, querysz)
- x_spts, y_spts, x_qrys, y_qrys = [
+ x_spts, y_spts, x_qrys, y_qrys = (
torch.from_numpy(z).to(self.device) for z in [x_spts, y_spts, x_qrys, y_qrys]
- ]
+ )
data_cache.append([x_spts, y_spts, x_qrys, y_qrys])
diff --git a/examples/distributed/few-shot/maml_omniglot.py b/examples/distributed/few-shot/maml_omniglot.py
index 867caf43..24601dfa 100644
--- a/examples/distributed/few-shot/maml_omniglot.py
+++ b/examples/distributed/few-shot/maml_omniglot.py
@@ -127,7 +127,10 @@ def main():
argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=5)
argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15)
argparser.add_argument(
- '--task_num', type=int, help='meta batch size, namely task num', default=32
+ '--task_num',
+ type=int,
+ help='meta batch size, namely task num',
+ default=32,
)
argparser.add_argument('--seed', type=int, help='random seed', default=1)
args = argparser.parse_args()
@@ -231,7 +234,7 @@ def train(db: OmniglotNShot, net: nn.Module, meta_opt: optim.Adam, epoch: int, l
iter_time = time.time() - start_time
print(
- f'[Epoch {i:.2f}] Train Loss: {qry_loss:.2f} | Acc: {qry_acc:.2f} | Time: {iter_time:.2f}'
+ f'[Epoch {i:.2f}] Train Loss: {qry_loss:.2f} | Acc: {qry_acc:.2f} | Time: {iter_time:.2f}',
)
log.append(
{
@@ -240,7 +243,7 @@ def train(db: OmniglotNShot, net: nn.Module, meta_opt: optim.Adam, epoch: int, l
'acc': qry_acc,
'mode': 'train',
'time': time.time(),
- }
+ },
)
@@ -280,7 +283,7 @@ def test(db, net, epoch, log):
'acc': qry_accs,
'mode': 'test',
'time': time.time(),
- }
+ },
)
diff --git a/examples/distributed/few-shot/maml_omniglot_local_loader.py b/examples/distributed/few-shot/maml_omniglot_local_loader.py
index 7f042854..d7413770 100644
--- a/examples/distributed/few-shot/maml_omniglot_local_loader.py
+++ b/examples/distributed/few-shot/maml_omniglot_local_loader.py
@@ -163,7 +163,10 @@ def main():
argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=5)
argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15)
argparser.add_argument(
- '--task_num', type=int, help='meta batch size, namely task num', default=32
+ '--task_num',
+ type=int,
+ help='meta batch size, namely task num',
+ default=32,
)
argparser.add_argument('--seed', type=int, help='random seed', default=1)
args = argparser.parse_args()
@@ -274,7 +277,7 @@ def train(net: nn.Module, meta_opt: optim.Adam, epoch: int, log: list):
iter_time = time.time() - start_time
print(
- f'[Epoch {i:.2f}] Train Loss: {qry_loss:.2f} | Acc: {qry_acc:.2f} | Time: {iter_time:.2f}'
+ f'[Epoch {i:.2f}] Train Loss: {qry_loss:.2f} | Acc: {qry_acc:.2f} | Time: {iter_time:.2f}',
)
log.append(
{
@@ -283,7 +286,7 @@ def train(net: nn.Module, meta_opt: optim.Adam, epoch: int, log: list):
'acc': qry_acc,
'mode': 'train',
'time': time.time(),
- }
+ },
)
@@ -324,7 +327,7 @@ def test(net, epoch, log):
'acc': qry_accs,
'mode': 'test',
'time': time.time(),
- }
+ },
)
diff --git a/examples/few-shot/helpers/omniglot_loaders.py b/examples/few-shot/helpers/omniglot_loaders.py
index e8f02042..52fab28a 100644
--- a/examples/few-shot/helpers/omniglot_loaders.py
+++ b/examples/few-shot/helpers/omniglot_loaders.py
@@ -80,7 +80,7 @@ def __len__(self):
def _check_exists(self):
return os.path.exists(
- os.path.join(self.root, self.processed_folder, 'images_evaluation')
+ os.path.join(self.root, self.processed_folder, 'images_evaluation'),
) and os.path.exists(os.path.join(self.root, self.processed_folder, 'images_background'))
def download(self):
@@ -118,7 +118,7 @@ def download(self):
def find_classes(root_dir):
retour = []
- for root, dirs, files in os.walk(root_dir):
+ for root, _, files in os.walk(root_dir):
for f in files:
if f.endswith('png'):
r = root.split('/')
@@ -164,14 +164,14 @@ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=Non
lambda x: np.reshape(x, (imgsz, imgsz, 1)),
lambda x: np.transpose(x, [2, 0, 1]),
lambda x: x / 255.0,
- ]
+ ],
),
)
# {label: [img1, img2..., img20], label2: [img1, img2, ...], ... 1623 labels in total}
temp = {}
for img, label in self.x:
- if label in temp.keys():
+ if label in temp:
temp[label].append(img)
else:
temp[label] = [img]
@@ -196,12 +196,9 @@ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=Non
self.x = np.load(os.path.join(root, 'omniglot.npy'))
print('load from omniglot.npy.')
- # [1623, 20, 84, 84, 1]
- # TODO: can not shuffle here, we must keep training and test set distinct!
+ # NOTE: do not shuffle here, we must keep training and test set distinct!
self.x_train, self.x_test = self.x[:1200], self.x[1200:]
- # self.normalization()
-
self.batchsz = batchsz
self.n_cls = self.x.shape[0] # 1623
self.n_way = n_way # n way
@@ -230,7 +227,6 @@ def normalization(self):
self.std = np.std(self.x_train)
self.max = np.max(self.x_train)
self.min = np.min(self.x_train)
- # print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
self.x_train = (self.x_train - self.mean) / self.std
self.x_test = (self.x_test - self.mean) / self.std
@@ -239,8 +235,6 @@ def normalization(self):
self.max = np.max(self.x_train)
self.min = np.min(self.x_train)
- # print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
-
def load_data_cache(self, data_pack):
"""
Collects several batches data for N-shot learning
@@ -253,10 +247,9 @@ def load_data_cache(self, data_pack):
querysz = self.k_query * self.n_way
data_cache = []
- # print('preload next 50 caches of batchsz of batch.')
- for sample in range(10): # num of episodes
+ for _sample in range(10): # num of episodes
x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
- for i in range(self.batchsz): # one batch means one set
+ for _ in range(self.batchsz): # one batch means one set
x_spt, y_spt, x_qry, y_qry = [], [], [], []
selected_cls = self.rng.choice(data_pack.shape[0], self.n_way, False)
@@ -272,12 +265,18 @@ def load_data_cache(self, data_pack):
# shuffle inside a batch
perm = self.rng.permutation(self.n_way * self.k_shot)
x_spt = np.array(x_spt).reshape(
- self.n_way * self.k_shot, 1, self.resize, self.resize
+ self.n_way * self.k_shot,
+ 1,
+ self.resize,
+ self.resize,
)[perm]
y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm]
perm = self.rng.permutation(self.n_way * self.k_query)
x_qry = np.array(x_qry).reshape(
- self.n_way * self.k_query, 1, self.resize, self.resize
+ self.n_way * self.k_query,
+ 1,
+ self.resize,
+ self.resize,
)[perm]
y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm]
@@ -287,20 +286,29 @@ def load_data_cache(self, data_pack):
x_qrys.append(x_qry)
y_qrys.append(y_qry)
- # [b, setsz, 1, 84, 84]
x_spts = np.array(x_spts, dtype=np.float32).reshape(
- self.batchsz, setsz, 1, self.resize, self.resize
- )
- y_spts = np.array(y_spts, dtype=np.int).reshape(self.batchsz, setsz)
- # [b, qrysz, 1, 84, 84]
+ self.batchsz,
+ setsz,
+ 1,
+ self.resize,
+ self.resize,
+ ) # [b, setsz, 1, 84, 84]
+ y_spts = np.array(y_spts, dtype=np.int).reshape(
+ self.batchsz,
+ setsz,
+ ) # [b, qrysz, 1, 84, 84]
x_qrys = np.array(x_qrys, dtype=np.float32).reshape(
- self.batchsz, querysz, 1, self.resize, self.resize
+ self.batchsz,
+ querysz,
+ 1,
+ self.resize,
+ self.resize,
)
y_qrys = np.array(y_qrys, dtype=np.int).reshape(self.batchsz, querysz)
- x_spts, y_spts, x_qrys, y_qrys = [
+ x_spts, y_spts, x_qrys, y_qrys = (
torch.from_numpy(z).to(self.device) for z in [x_spts, y_spts, x_qrys, y_qrys]
- ]
+ )
data_cache.append([x_spts, y_spts, x_qrys, y_qrys])
diff --git a/examples/few-shot/maml_omniglot.py b/examples/few-shot/maml_omniglot.py
index 17172bdd..d798aa1d 100644
--- a/examples/few-shot/maml_omniglot.py
+++ b/examples/few-shot/maml_omniglot.py
@@ -65,7 +65,10 @@ def main():
argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=5)
argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15)
argparser.add_argument(
- '--task_num', type=int, help='meta batch size, namely task num', default=32
+ '--task_num',
+ type=int,
+ help='meta batch size, namely task num',
+ default=32,
)
argparser.add_argument('--seed', type=int, help='random seed', default=1)
args = argparser.parse_args()
@@ -178,7 +181,7 @@ def train(db, net, meta_opt, epoch, log):
iter_time = time.time() - start_time
print(
- f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
+ f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}',
)
log.append(
{
@@ -187,7 +190,7 @@ def train(db, net, meta_opt, epoch, log):
'acc': qry_accs,
'mode': 'train',
'time': time.time(),
- }
+ },
)
@@ -204,7 +207,7 @@ def test(db, net, epoch, log):
qry_losses = []
qry_accs = []
- for batch_idx in range(n_test_iter):
+ for _ in range(n_test_iter):
x_spt, y_spt, x_qry, y_qry = db.next('test')
task_num = x_spt.size(0)
@@ -245,7 +248,7 @@ def test(db, net, epoch, log):
'acc': qry_accs,
'mode': 'test',
'time': time.time(),
- }
+ },
)
diff --git a/examples/iMAML/helpers/omniglot_loaders.py b/examples/iMAML/helpers/omniglot_loaders.py
index e8f02042..52fab28a 100644
--- a/examples/iMAML/helpers/omniglot_loaders.py
+++ b/examples/iMAML/helpers/omniglot_loaders.py
@@ -80,7 +80,7 @@ def __len__(self):
def _check_exists(self):
return os.path.exists(
- os.path.join(self.root, self.processed_folder, 'images_evaluation')
+ os.path.join(self.root, self.processed_folder, 'images_evaluation'),
) and os.path.exists(os.path.join(self.root, self.processed_folder, 'images_background'))
def download(self):
@@ -118,7 +118,7 @@ def download(self):
def find_classes(root_dir):
retour = []
- for root, dirs, files in os.walk(root_dir):
+ for root, _, files in os.walk(root_dir):
for f in files:
if f.endswith('png'):
r = root.split('/')
@@ -164,14 +164,14 @@ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=Non
lambda x: np.reshape(x, (imgsz, imgsz, 1)),
lambda x: np.transpose(x, [2, 0, 1]),
lambda x: x / 255.0,
- ]
+ ],
),
)
# {label: [img1, img2..., img20], label2: [img1, img2, ...], ... 1623 labels in total}
temp = {}
for img, label in self.x:
- if label in temp.keys():
+ if label in temp:
temp[label].append(img)
else:
temp[label] = [img]
@@ -196,12 +196,9 @@ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=Non
self.x = np.load(os.path.join(root, 'omniglot.npy'))
print('load from omniglot.npy.')
- # [1623, 20, 84, 84, 1]
- # TODO: can not shuffle here, we must keep training and test set distinct!
+ # NOTE: do not shuffle here, we must keep training and test set distinct!
self.x_train, self.x_test = self.x[:1200], self.x[1200:]
- # self.normalization()
-
self.batchsz = batchsz
self.n_cls = self.x.shape[0] # 1623
self.n_way = n_way # n way
@@ -230,7 +227,6 @@ def normalization(self):
self.std = np.std(self.x_train)
self.max = np.max(self.x_train)
self.min = np.min(self.x_train)
- # print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
self.x_train = (self.x_train - self.mean) / self.std
self.x_test = (self.x_test - self.mean) / self.std
@@ -239,8 +235,6 @@ def normalization(self):
self.max = np.max(self.x_train)
self.min = np.min(self.x_train)
- # print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
-
def load_data_cache(self, data_pack):
"""
Collects several batches data for N-shot learning
@@ -253,10 +247,9 @@ def load_data_cache(self, data_pack):
querysz = self.k_query * self.n_way
data_cache = []
- # print('preload next 50 caches of batchsz of batch.')
- for sample in range(10): # num of episodes
+ for _sample in range(10): # num of episodes
x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
- for i in range(self.batchsz): # one batch means one set
+ for _ in range(self.batchsz): # one batch means one set
x_spt, y_spt, x_qry, y_qry = [], [], [], []
selected_cls = self.rng.choice(data_pack.shape[0], self.n_way, False)
@@ -272,12 +265,18 @@ def load_data_cache(self, data_pack):
# shuffle inside a batch
perm = self.rng.permutation(self.n_way * self.k_shot)
x_spt = np.array(x_spt).reshape(
- self.n_way * self.k_shot, 1, self.resize, self.resize
+ self.n_way * self.k_shot,
+ 1,
+ self.resize,
+ self.resize,
)[perm]
y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm]
perm = self.rng.permutation(self.n_way * self.k_query)
x_qry = np.array(x_qry).reshape(
- self.n_way * self.k_query, 1, self.resize, self.resize
+ self.n_way * self.k_query,
+ 1,
+ self.resize,
+ self.resize,
)[perm]
y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm]
@@ -287,20 +286,29 @@ def load_data_cache(self, data_pack):
x_qrys.append(x_qry)
y_qrys.append(y_qry)
- # [b, setsz, 1, 84, 84]
x_spts = np.array(x_spts, dtype=np.float32).reshape(
- self.batchsz, setsz, 1, self.resize, self.resize
- )
- y_spts = np.array(y_spts, dtype=np.int).reshape(self.batchsz, setsz)
- # [b, qrysz, 1, 84, 84]
+ self.batchsz,
+ setsz,
+ 1,
+ self.resize,
+ self.resize,
+ ) # [b, setsz, 1, 84, 84]
+ y_spts = np.array(y_spts, dtype=np.int).reshape(
+ self.batchsz,
+ setsz,
+ ) # [b, qrysz, 1, 84, 84]
x_qrys = np.array(x_qrys, dtype=np.float32).reshape(
- self.batchsz, querysz, 1, self.resize, self.resize
+ self.batchsz,
+ querysz,
+ 1,
+ self.resize,
+ self.resize,
)
y_qrys = np.array(y_qrys, dtype=np.int).reshape(self.batchsz, querysz)
- x_spts, y_spts, x_qrys, y_qrys = [
+ x_spts, y_spts, x_qrys, y_qrys = (
torch.from_numpy(z).to(self.device) for z in [x_spts, y_spts, x_qrys, y_qrys]
- ]
+ )
data_cache.append([x_spts, y_spts, x_qrys, y_qrys])
diff --git a/examples/iMAML/imaml_omniglot.py b/examples/iMAML/imaml_omniglot.py
index 09344900..8a6960ba 100644
--- a/examples/iMAML/imaml_omniglot.py
+++ b/examples/iMAML/imaml_omniglot.py
@@ -90,10 +90,16 @@ def main():
argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=5)
argparser.add_argument('--inner_steps', type=int, help='number of inner steps', default=5)
argparser.add_argument(
- '--reg_params', type=float, help='regularization parameters', default=2.0
+ '--reg_params',
+ type=float,
+ help='regularization parameters',
+ default=2.0,
)
argparser.add_argument(
- '--task_num', type=int, help='meta batch size, namely task num', default=16
+ '--task_num',
+ type=int,
+ help='meta batch size, namely task num',
+ default=16,
)
argparser.add_argument('--seed', type=int, help='random seed', default=1)
args = argparser.parse_args()
@@ -193,7 +199,7 @@ def train(db, net, meta_opt, epoch, log, args):
iter_time = time.time() - start_time
print(
- f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
+ f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}',
)
log.append(
{
@@ -202,7 +208,7 @@ def train(db, net, meta_opt, epoch, log, args):
'acc': qry_accs,
'mode': 'train',
'time': time.time(),
- }
+ },
)
@@ -222,7 +228,7 @@ def test(db, net, epoch, log, args):
n_inner_iter = args.inner_steps
reg_param = args.reg_params
- for batch_idx in range(n_test_iter):
+ for _ in range(n_test_iter):
x_spt, y_spt, x_qry, y_qry = db.next('test')
task_num = x_spt.size(0)
@@ -254,7 +260,7 @@ def test(db, net, epoch, log, args):
'acc': qry_accs,
'mode': 'test',
'time': time.time(),
- }
+ },
)
diff --git a/examples/iMAML/imaml_omniglot_functional.py b/examples/iMAML/imaml_omniglot_functional.py
index 1c0a089a..60fd4108 100644
--- a/examples/iMAML/imaml_omniglot_functional.py
+++ b/examples/iMAML/imaml_omniglot_functional.py
@@ -49,10 +49,16 @@ def main():
argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=5)
argparser.add_argument('--inner_steps', type=int, help='number of inner steps', default=5)
argparser.add_argument(
- '--reg_params', type=float, help='regularization parameters', default=2.0
+ '--reg_params',
+ type=float,
+ help='regularization parameters',
+ default=2.0,
)
argparser.add_argument(
- '--task_num', type=int, help='meta batch size, namely task num', default=16
+ '--task_num',
+ type=int,
+ help='meta batch size, namely task num',
+ default=16,
)
argparser.add_argument('--seed', type=int, help='random seed', default=1)
args = argparser.parse_args()
@@ -167,7 +173,7 @@ def train(db, model, meta_opt_and_state, epoch, log, args):
iter_time = time.time() - start_time
print(
- f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
+ f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}',
)
log.append(
{
@@ -176,7 +182,7 @@ def train(db, model, meta_opt_and_state, epoch, log, args):
'acc': qry_accs,
'mode': 'train',
'time': time.time(),
- }
+ },
)
return (meta_opt, meta_opt_state)
@@ -196,7 +202,7 @@ def test(db, model, epoch, log, args):
qry_losses = []
qry_accs = []
- for batch_idx in range(n_test_iter):
+ for _ in range(n_test_iter):
x_spt, y_spt, x_qry, y_qry = db.next('test')
task_num = x_spt.size(0)
@@ -235,7 +241,7 @@ def test(db, model, epoch, log, args):
'acc': qry_accs,
'mode': 'test',
'time': time.time(),
- }
+ },
)
@@ -274,7 +280,9 @@ def train_imaml_inner_solver(params, meta_params, data, aux):
final_loss = loss + regularization_loss
grads = torch.autograd.grad(final_loss, params) # compute gradients
updates, inner_opt_state = inner_opt.update(
- grads, inner_opt_state, inplace=True
+ grads,
+ inner_opt_state,
+ inplace=True,
) # get updates
params = torchopt.apply_updates(params, updates, inplace=True)
return params
@@ -298,7 +306,9 @@ def test_imaml_inner_solver(params, meta_params, data, aux):
final_loss = loss + regularization_loss
grads = torch.autograd.grad(final_loss, params) # compute gradients
updates, inner_opt_state = inner_opt.update(
- grads, inner_opt_state, inplace=True
+ grads,
+ inner_opt_state,
+ inplace=True,
) # get updates
params = torchopt.apply_updates(params, updates, inplace=True)
return params
diff --git a/examples/visualize.py b/examples/visualize.py
index 56de2bd5..5e08267f 100644
--- a/examples/visualize.py
+++ b/examples/visualize.py
@@ -66,7 +66,8 @@ def draw_torchopt():
loss = F.mse_loss(pred, torch.ones_like(pred))
# draw computation graph
torchopt.visual.make_dot(loss, [net_state_0, net_state_1, {meta_param: 'meta_param'}]).render(
- 'torchopt_graph', format='svg'
+ 'torchopt_graph',
+ format='svg',
)
diff --git a/pyproject.toml b/pyproject.toml
index 12fd6fe3..47424855 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,7 +11,7 @@ description = "An efficient library for differentiable optimization for PyTorch.
readme = "README.md"
# Change this if wheels for `torch` is available
# Search "requires-python" and update all corresponding items
-requires-python = ">= 3.7"
+requires-python = ">= 3.8"
authors = [
{ name = "TorchOpt Contributors" },
{ name = "Jie Ren", email = "jieren9806@gmail.com" },
@@ -34,7 +34,6 @@ classifiers = [
"License :: OSI Approved :: Apache Software License",
# Sync with requires-python
"Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
@@ -74,24 +73,35 @@ lint = [
"mypy >= 0.990",
"flake8",
"flake8-bugbear",
- "doc8 < 1.0.0a0", # unpin this when we drop support for Python 3.7
+ "flake8-comprehensions",
+ "flake8-docstrings",
+ "flake8-pyi",
+ "flake8-simplify",
+ "ruff",
+ "doc8",
"pydocstyle[toml]",
"pyenchant",
"cpplint",
"pre-commit",
]
test = [
- 'pytest',
- 'pytest-cov',
- 'pytest-xdist',
- 'jax[cpu] >= 0.3',
- 'jaxopt',
- 'optax',
+ "pytest",
+ "pytest-cov",
+ "pytest-xdist",
+ "jax[cpu] >= 0.3; platform_system != 'Windows'",
+ "jaxopt; platform_system != 'Windows'",
+ "optax; platform_system != 'Windows'",
]
+[tool.setuptools]
+include-package-data = true
+
[tool.setuptools.packages.find]
include = ["torchopt", "torchopt.*"]
+[tool.setuptools.package-data]
+torchopt = ['*.so', '*.pyd']
+
# Wheel builder ################################################################
# Reference: https://cibuildwheel.readthedocs.io
[tool.cibuildwheel]
@@ -170,7 +180,7 @@ safe = true
line-length = 100
skip-string-normalization = true
# Sync with requires-python
-target-version = ["py37", "py38", "py39", "py310", "py311"]
+target-version = ["py38", "py39", "py310", "py311"]
[tool.isort]
atomic = true
@@ -184,7 +194,7 @@ multi_line_output = 3
[tool.mypy]
# Sync with requires-python
-python_version = 3.7
+python_version = 3.8
pretty = true
show_error_codes = true
show_error_context = true
@@ -209,8 +219,99 @@ convention = "google"
[tool.doc8]
max-line-length = 500
+[tool.codespell]
+ignore-words = "docs/source/spelling_wordlist.txt"
+
+[tool.ruff]
+# Sync with requires-python
+target-version = "py38"
+line-length = 100
+show-source = true
+src = ["torchopt", "tests"]
+extend-exclude = ["examples"]
+select = [
+ "E", "W", # pycodestyle
+ "F", # pyflakes
+ "UP", # pyupgrade
+ "ANN", # flake8-annotations
+ "S", # flake8-bandit
+ "BLE", # flake8-blind-except
+ "B", # flake8-bugbear
+ "COM", # flake8-commas
+ "C4", # flake8-comprehensions
+ "EXE", # flake8-executable
+ "ISC", # flake8-implicit-str-concat
+ "PIE", # flake8-pie
+ "PYI", # flake8-pyi
+ "Q", # flake8-quotes
+ "RSE", # flake8-raise
+ "RET", # flake8-return
+ "SIM", # flake8-simplify
+ "TID", # flake8-tidy-imports
+ "RUF", # ruff
+]
+ignore = [
+ # E501: line too long
+ # W505: doc line too long
+ # too long docstring due to long example blocks
+ "E501",
+ "W505",
+ # ANN101: missing type annotation for `self` in method
+ # ANN102: missing type annotation for `cls` in classmethod
+ "ANN101",
+ "ANN102",
+ # ANN401: dynamically typed expressions (typing.Any) are disallowed
+ "ANN401",
+ # S101: use of `assert` detected
+ # internal use and may never raise at runtime
+ "S101",
+ # PLR0402: use from {module} import {name} in lieu of alias
+ # use alias for import convention (e.g., `import torch.nn as nn`)
+ "PLR0402",
+]
+typing-modules = ["torchopt.typing"]
+
+[tool.ruff.per-file-ignores]
+"__init__.py" = [
+ "F401", # unused-import
+]
+"torchopt/pytree.py" = [
+ "F401", # unused-import
+ "F403", # import-star
+ "F405", # import-star-usage
+]
+"setup.py" = [
+ "ANN", # flake8-annotations
+]
+"tests/**/*.py" = [
+ "ANN", # flake8-annotations
+ "S", # flake8-bandit
+ "BLE", # flake8-blind-except
+]
+"tests/test_import.py" = [
+ "B018", # useless-expression
+ "F401", # unused-import
+ "F811", # redefined-while-unused
+]
+
+[tool.ruff.flake8-annotations]
+allow-star-arg-any = true
+
+[tool.ruff.flake8-quotes]
+docstring-quotes = "double"
+multiline-quotes = "double"
+inline-quotes = "single"
+
+[tool.ruff.flake8-tidy-imports]
+ban-relative-imports = "all"
+
+[tool.ruff.pylint]
+allow-magic-value-types = ["int", "str", "float"]
+
[tool.pytest.ini_options]
filterwarnings = [
"error",
'ignore:Explicitly requested dtype float64 requested in .* is not available, and will be truncated to dtype float32\.:UserWarning',
+ 'ignore:jax\.numpy\.DeviceArray is deprecated\. Use jax\.Array\.:DeprecationWarning',
+ 'ignore:.*functorch.*deprecate.*:UserWarning',
]
diff --git a/setup.py b/setup.py
index 0297d43e..cce04c65 100644
--- a/setup.py
+++ b/setup.py
@@ -85,9 +85,9 @@ def build_extension(self, ext):
try:
os.chdir(build_temp)
- self.spawn([cmake, ext.source_dir] + cmake_args)
+ self.spawn([cmake, ext.source_dir, *cmake_args])
if not self.dry_run:
- self.spawn([cmake, '--build', '.'] + build_args)
+ self.spawn([cmake, '--build', '.', *build_args])
finally:
os.chdir(HERE)
@@ -96,16 +96,16 @@ def build_extension(self, ext):
LINUX = platform.system() == 'Linux'
MACOS = platform.system() == 'Darwin'
WINDOWS = platform.system() == 'Windows'
-ext_kwargs = dict(
- cmdclass={'build_ext': cmake_build_ext},
- ext_modules=[
+ext_kwargs = {
+ 'cmdclass': {'build_ext': cmake_build_ext},
+ 'ext_modules': [
CMakeExtension(
'torchopt._C',
source_dir=HERE,
optional=not (LINUX and CIBUILDWHEEL),
- )
+ ),
],
-)
+}
TORCHOPT_NO_EXTENSIONS = (
bool(os.getenv('TORCHOPT_NO_EXTENSIONS', '')) or WINDOWS or (MACOS and CIBUILDWHEEL)
@@ -119,14 +119,14 @@ def build_extension(self, ext):
try:
if not version.__release__:
try:
- VERSION_CONTENT = VERSION_FILE.read_text(encoding='UTF-8')
+ VERSION_CONTENT = VERSION_FILE.read_text(encoding='utf-8')
VERSION_FILE.write_text(
data=re.sub(
r"""__version__\s*=\s*('[^']+'|"[^"]+")""",
- f"__version__ = '{version.__version__}'",
+ f'__version__ = {version.__version__!r}',
string=VERSION_CONTENT,
),
- encoding='UTF-8',
+ encoding='utf-8',
)
except OSError:
VERSION_CONTENT = None
@@ -134,11 +134,9 @@ def build_extension(self, ext):
setup(
name='torchopt',
version=version.__version__,
- package_data={'sharedlib': ['*.so', '*.pyd']},
- include_package_data=True,
**ext_kwargs,
)
finally:
if VERSION_CONTENT is not None:
- with VERSION_FILE.open(mode='wt', encoding='UTF-8', newline='') as file:
+ with VERSION_FILE.open(mode='wt', encoding='utf-8', newline='') as file:
file.write(VERSION_CONTENT)
diff --git a/src/adam_op/adam_op_impl_cpu.cpp b/src/adam_op/adam_op_impl_cpu.cpp
index b9c14e49..1135206d 100644
--- a/src/adam_op/adam_op_impl_cpu.cpp
+++ b/src/adam_op/adam_op_impl_cpu.cpp
@@ -40,8 +40,9 @@ void adamForwardInplaceCPUKernel(const other_t b1,
scalar_t *__restrict__ updates_ptr,
scalar_t *__restrict__ mu_ptr,
scalar_t *__restrict__ nu_ptr) {
-#pragma omp parallel for num_threads(std::min( \
- n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
+#pragma omp parallel for num_threads( \
+ std::min(n / MIN_NUMEL_USE_OMP, \
+ static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t updates = updates_ptr[tid];
const scalar_t mu = mu_ptr[tid];
@@ -95,8 +96,9 @@ void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr,
const other_t b1,
const size_t n,
scalar_t *__restrict__ mu_out_ptr) {
-#pragma omp parallel for num_threads(std::min( \
- n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
+#pragma omp parallel for num_threads( \
+ std::min(n / MIN_NUMEL_USE_OMP, \
+ static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t updates = updates_ptr[tid];
const scalar_t mu = mu_ptr[tid];
@@ -128,8 +130,9 @@ void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr,
const other_t b2,
const size_t n,
scalar_t *__restrict__ nu_out_ptr) {
-#pragma omp parallel for num_threads(std::min( \
- n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
+#pragma omp parallel for num_threads( \
+ std::min(n / MIN_NUMEL_USE_OMP, \
+ static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t updates = updates_ptr[tid];
const scalar_t nu = nu_ptr[tid];
@@ -165,8 +168,9 @@ void adamForwardUpdatesCPUKernel(const scalar_t *__restrict__ new_mu_ptr,
const other_t eps_root,
const size_t n,
scalar_t *__restrict__ updates_out_ptr) {
-#pragma omp parallel for num_threads(std::min( \
- n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
+#pragma omp parallel for num_threads( \
+ std::min(n / MIN_NUMEL_USE_OMP, \
+ static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t new_mu = new_mu_ptr[tid];
const scalar_t new_nu = new_nu_ptr[tid];
@@ -210,8 +214,9 @@ void adamBackwardMuCPUKernel(const scalar_t *__restrict__ dmu_ptr,
const size_t n,
scalar_t *__restrict__ dupdates_out_ptr,
scalar_t *__restrict__ dmu_out_ptr) {
-#pragma omp parallel for num_threads(std::min( \
- n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
+#pragma omp parallel for num_threads( \
+ std::min(n / MIN_NUMEL_USE_OMP, \
+ static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t dmu = dmu_ptr[tid];
@@ -246,8 +251,9 @@ void adamBackwardNuCPUKernel(const scalar_t *__restrict__ dnu_ptr,
const size_t n,
scalar_t *__restrict__ dupdates_out_ptr,
scalar_t *__restrict__ dnu_out_ptr) {
-#pragma omp parallel for num_threads(std::min( \
- n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
+#pragma omp parallel for num_threads( \
+ std::min(n / MIN_NUMEL_USE_OMP, \
+ static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t dnu = dnu_ptr[tid];
const scalar_t updates = updates_ptr[tid];
@@ -286,8 +292,9 @@ void adamBackwardUpdatesCPUKernel(const scalar_t *__restrict__ dupdates_ptr,
const size_t n,
scalar_t *__restrict__ dnew_mu_out_ptr,
scalar_t *__restrict__ dnew_nu_out_ptr) {
-#pragma omp parallel for num_threads(std::min( \
- n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
+#pragma omp parallel for num_threads( \
+ std::min(n / MIN_NUMEL_USE_OMP, \
+ static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t dupdates = dupdates_ptr[tid];
const scalar_t updates = updates_ptr[tid];
diff --git a/tests/.coveragerc b/tests/.coveragerc
index 462c4c3a..4238e71d 100644
--- a/tests/.coveragerc
+++ b/tests/.coveragerc
@@ -6,3 +6,12 @@ omit =
../docs/*
../examples/*
../tutorials/*
+
+[report]
+exclude_lines =
+ pragma: no cover
+ raise NotImplementedError
+ class .*\bProtocol\):
+ @(abc\.)?abstractmethod
+ if __name__ == ('__main__'|"__main__"):
+ if TYPE_CHECKING:
diff --git a/tests/helpers.py b/tests/helpers.py
index 23e178f0..bedf0fb6 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -15,6 +15,7 @@
from __future__ import annotations
+import contextlib
import copy
import itertools
import os
@@ -79,15 +80,18 @@ def seed_everything(seed: int) -> None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
- try:
+ with contextlib.suppress(AttributeError):
torch.use_deterministic_algorithms(True)
- except AttributeError:
- pass
class MyLinear(nn.Module):
def __init__(
- self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ device=None,
+ dtype=None,
) -> None:
super().__init__()
self.linear = nn.Linear(
@@ -138,7 +142,8 @@ def get_model():
@torch.no_grad()
def get_models(
- device: torch.types.Device = None, dtype: torch.dtype = torch.float32
+ device: torch.types.Device = None,
+ dtype: torch.dtype = torch.float32,
) -> tuple[nn.Module, nn.Module, nn.Module, data.DataLoader]:
seed_everything(seed=42)
diff --git a/tests/requirements.txt b/tests/requirements.txt
index 6706dca5..87c994e1 100644
--- a/tests/requirements.txt
+++ b/tests/requirements.txt
@@ -3,9 +3,9 @@ torch >= 1.13
--requirement ../requirements.txt
-jax[cpu] >= 0.3
-jaxopt
-optax
+jax[cpu] >= 0.3; platform_system != 'Windows'
+jaxopt; platform_system != 'Windows'
+optax; platform_system != 'Windows'
pytest
pytest-cov
@@ -16,8 +16,12 @@ pylint[spelling] >= 2.15.0
mypy >= 0.990
flake8
flake8-bugbear
-# https://github.com/PyCQA/doc8/issues/112
-doc8 < 1.0.0a0
+flake8-comprehensions
+flake8-docstrings
+flake8-pyi
+flake8-simplify
+ruff
+doc8
pydocstyle[toml]
pyenchant
cpplint
diff --git a/tests/test_accelerated_op.py b/tests/test_accelerated_op.py
index 4821a03d..6cb45ca0 100644
--- a/tests/test_accelerated_op.py
+++ b/tests/test_accelerated_op.py
@@ -103,7 +103,10 @@ def test_accelerated_op(
grads = torch.autograd.grad(loss_ref, params_ref, allow_unused=True)
updates, optim_state_ref = optim_ref.update(
- grads, optim_state_ref, params=params, inplace=inplace
+ grads,
+ optim_state_ref,
+ params=params,
+ inplace=inplace,
)
params_ref = torchopt.apply_updates(params_ref, updates, inplace=inplace)
@@ -154,10 +157,14 @@ def maml_inner_solver(params, data, use_accelerated_op):
pred = f(params, b, x)
inner_loss = F.cross_entropy(pred, y) # compute loss
grads = torch.autograd.grad(
- inner_loss, params, allow_unused=True
+ inner_loss,
+ params,
+ allow_unused=True,
) # compute gradients
updates, inner_opt_state = inner_optimizer.update(
- grads, inner_opt_state, inplace=False
+ grads,
+ inner_opt_state,
+ inplace=False,
) # get updates
params = torchopt.apply_updates(params, updates, inplace=False)
return (params, b)
@@ -169,7 +176,9 @@ def maml_inner_solver(params, data, use_accelerated_op):
params_prime, buffers_prime = maml_inner_solver(params, data, use_accelerated_op=True)
params_prime_ref, buffers_prime_ref = maml_inner_solver(
- params_ref, data_ref, use_accelerated_op=False
+ params_ref,
+ data_ref,
+ use_accelerated_op=False,
)
pred = fmodel(params_prime, buffers_prime, xs)
@@ -179,13 +188,19 @@ def maml_inner_solver(params, data, use_accelerated_op):
grads = torch.autograd.grad(outer_loss, params, allow_unused=True)
updates, outer_optim_state = outer_optim.update(
- grads, outer_optim_state, params=params, inplace=inplace
+ grads,
+ outer_optim_state,
+ params=params,
+ inplace=inplace,
)
params = torchopt.apply_updates(params, updates, inplace=inplace)
grads = torch.autograd.grad(outer_loss_ref, params_ref, allow_unused=True)
updates, outer_optim_state_ref = outer_optim_ref.update(
- grads, outer_optim_state_ref, params=params, inplace=inplace
+ grads,
+ outer_optim_state_ref,
+ params=params,
+ inplace=inplace,
)
params_ref = torchopt.apply_updates(params_ref, updates, inplace=inplace)
diff --git a/tests/test_alias.py b/tests/test_alias.py
index b609cf58..a0a78129 100644
--- a/tests/test_alias.py
+++ b/tests/test_alias.py
@@ -24,7 +24,56 @@
import helpers
import torchopt
+from torchopt import pytree
from torchopt.alias.utils import _set_use_chain_flat
+from torchopt.typing import TensorTree
+
+
+@helpers.parametrize(
+ optimizer=[
+ torchopt.sgd,
+ torchopt.adam,
+ torchopt.adamw,
+ torchopt.rmsprop,
+ ],
+ tensortree=[
+ {},
+ (),
+ [],
+ (None,),
+ {'a': (), 'b': {'c': []}, 'd': None},
+ ],
+ maximize=[False, True],
+ inplace=[True, False],
+ use_chain_flat=[True, False],
+)
+def test_empty(
+ optimizer: Callable,
+ tensortree: TensorTree,
+ maximize: bool,
+ inplace: bool,
+ use_chain_flat: bool,
+) -> None:
+ _set_use_chain_flat(use_chain_flat)
+
+ params = pytree.tree_map(lambda x: x, tensortree)
+ grads = pytree.tree_map(lambda x: x, tensortree)
+
+ optim = optimizer(1e-3, maximize=maximize)
+ optim_state = optim.init(params)
+ updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
+ _ = torchopt.apply_updates(params, updates)
+
+ try:
+ optim = optimizer(1e-3, maximize=maximize, use_accelerated_op=True)
+ except TypeError:
+ pass
+ else:
+ optim_state = optim.init(params)
+ updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
+ _ = torchopt.apply_updates(params, updates)
+
+ _set_use_chain_flat(True)
@helpers.parametrize(
@@ -222,10 +271,15 @@ def maml_inner_solver_torchopt(params, data, use_accelerated_op):
pred = f(params, b, x)
inner_loss = F.cross_entropy(pred, y) # compute loss
grads = torch.autograd.grad(
- inner_loss, params, allow_unused=True
+ inner_loss,
+ params,
+ allow_unused=True,
) # compute gradients
updates, inner_opt_state = inner_optimizer.update(
- grads, inner_opt_state, params=params, inplace=False
+ grads,
+ inner_opt_state,
+ params=params,
+ inplace=False,
) # get updates
params = torchopt.apply_updates(params, updates, inplace=False)
return (params, b)
@@ -235,14 +289,19 @@ def maml_inner_solver_torchopt(params, data, use_accelerated_op):
data = (xs, ys, fmodel, buffers)
params_prime, buffers_prime = maml_inner_solver_torchopt(
- params, data, use_accelerated_op=True
+ params,
+ data,
+ use_accelerated_op=True,
)
pred = fmodel(params_prime, buffers_prime, xs)
outer_loss = F.cross_entropy(pred, ys)
grads = torch.autograd.grad(outer_loss, params, allow_unused=True)
updates, outer_optim_state = outer_optim.update(
- grads, outer_optim_state, params=params, inplace=inplace
+ grads,
+ outer_optim_state,
+ params=params,
+ inplace=inplace,
)
params = torchopt.apply_updates(params, updates, inplace=inplace)
@@ -391,6 +450,70 @@ def test_adam_accelerated_cuda(
_set_use_chain_flat(True)
+@helpers.parametrize(
+ dtype=[torch.float64],
+ lr=[1e-2, 1e-3, 1e-4],
+ lr_decay=[0.0, 1e-2],
+ initial_accumulator_value=[0.0, 1e-1],
+ eps=[1e-8],
+ inplace=[True, False],
+ weight_decay=[0.0, 1e-2],
+ maximize=[False, True],
+ use_chain_flat=[True, False],
+)
+def test_adagrad(
+ dtype: torch.dtype,
+ lr: float,
+ lr_decay: float,
+ initial_accumulator_value: float,
+ eps: float,
+ inplace: bool,
+ weight_decay: float,
+ maximize: bool,
+ use_chain_flat: bool,
+) -> None:
+ _set_use_chain_flat(use_chain_flat)
+
+ model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)
+
+ fmodel, params, buffers = functorch.make_functional_with_buffers(model)
+ optim = torchopt.adagrad(
+ lr=lr,
+ lr_decay=lr_decay,
+ weight_decay=weight_decay,
+ initial_accumulator_value=initial_accumulator_value,
+ eps=eps,
+ maximize=maximize,
+ )
+ optim_state = optim.init(params)
+ optim_ref = torch.optim.Adagrad(
+ model_ref.parameters(),
+ lr=lr,
+ lr_decay=lr_decay,
+ weight_decay=weight_decay,
+ initial_accumulator_value=initial_accumulator_value,
+ eps=eps,
+ maximize=maximize,
+ )
+ for xs, ys in loader:
+ xs = xs.to(dtype=dtype)
+ pred = fmodel(params, buffers, xs)
+ pred_ref = model_ref(xs)
+ loss = F.cross_entropy(pred, ys)
+ loss_ref = F.cross_entropy(pred_ref, ys)
+
+ grads = torch.autograd.grad(loss, params, allow_unused=True)
+ updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
+ params = torchopt.apply_updates(params, updates, inplace=inplace)
+
+ optim_ref.zero_grad()
+ loss_ref.backward()
+ optim_ref.step()
+
+ helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype)
+ _set_use_chain_flat(True)
+
+
@helpers.parametrize(
dtype=[torch.float64],
lr=[1e-2, 1e-3, 1e-4],
diff --git a/tests/test_combine.py b/tests/test_combine.py
index ad018d21..39b3e37f 100644
--- a/tests/test_combine.py
+++ b/tests/test_combine.py
@@ -35,7 +35,8 @@ def test_chain() -> None:
adam = torchopt.adam()
assert isinstance(adam, torchopt.base.ChainedGradientTransformation)
assert isinstance(
- adam.chain(torchopt.base.identity()), torchopt.base.ChainedGradientTransformation
+ adam.chain(torchopt.base.identity()),
+ torchopt.base.ChainedGradientTransformation,
)
assert adam.chain(torchopt.base.identity()) == adam
assert torchopt.base.identity().chain(adam) == adam
@@ -44,7 +45,8 @@ def test_chain() -> None:
assert isinstance(adam, torchopt.base.GradientTransformation)
assert isinstance(
- adam.chain(torchopt.base.identity()), torchopt.base.ChainedGradientTransformation
+ adam.chain(torchopt.base.identity()),
+ torchopt.base.ChainedGradientTransformation,
)
assert adam.chain(torchopt.base.identity()) == adam
assert torchopt.base.identity().chain(adam) == adam
diff --git a/tests/test_implicit.py b/tests/test_implicit.py
index 9e3722d3..db19f829 100644
--- a/tests/test_implicit.py
+++ b/tests/test_implicit.py
@@ -16,15 +16,12 @@
from __future__ import annotations
import copy
+import re
from collections import OrderedDict
from types import FunctionType
import functorch
-import jax
-import jax.numpy as jnp
-import jaxopt
import numpy as np
-import optax
import pytest
import torch
import torch.nn as nn
@@ -38,6 +35,18 @@
from torchopt.diff.implicit import ImplicitMetaGradientModule
+try:
+ import jax
+ import jax.numpy as jnp
+ import jaxopt
+ import optax
+
+ HAS_JAX = True
+except ImportError:
+ jax = jnp = jaxopt = optax = None
+ HAS_JAX = False
+
+
BATCH_SIZE = 8
NUM_UPDATES = 3
@@ -66,14 +75,15 @@ def func(params, x):
[
('weight', jnp.ones((MODEL_NUM_INPUTS, MODEL_NUM_CLASSES), dtype=dtype)),
('bias', jnp.zeros((MODEL_NUM_CLASSES,), dtype=dtype)),
- ]
+ ],
)
return func, params
@torch.no_grad()
def get_model_torch(
- device: torch.types.Device = None, dtype: torch.dtype = torch.float32
+ device: torch.types.Device | None = None,
+ dtype: torch.dtype = torch.float32,
) -> tuple[nn.Module, data.DataLoader]:
helpers.seed_everything(seed=42)
@@ -103,11 +113,10 @@ def get_rr_dataset_torch() -> data.DataLoader:
torch.randn((BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS)),
torch.randn((BATCH_SIZE * NUM_UPDATES,)),
)
- loader = data.DataLoader(dataset, BATCH_SIZE, shuffle=False)
-
- return loader
+ return data.DataLoader(dataset, BATCH_SIZE, shuffle=False)
+@pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed')
@helpers.parametrize(
dtype=[torch.float64, torch.float32],
lr=[1e-3, 1e-4],
@@ -115,7 +124,10 @@ def get_rr_dataset_torch() -> data.DataLoader:
inner_update=[20, 50, 100],
)
def test_imaml_solve_normal_cg(
- dtype: torch.dtype, lr: float, inner_lr: float, inner_update: int
+ dtype: torch.dtype,
+ lr: float,
+ inner_lr: float,
+ inner_update: int,
) -> None:
np_dtype = helpers.dtype_torch2numpy(dtype)
@@ -135,8 +147,7 @@ def imaml_objective_torchopt(params, meta_params, data):
regularization_loss = 0
for p1, p2 in zip(params, meta_params):
regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2))
- loss = F.cross_entropy(y_pred, y) + regularization_loss
- return loss
+ return F.cross_entropy(y_pred, y) + regularization_loss
@torchopt.diff.implicit.custom_root(
functorch.grad(imaml_objective_torchopt, argnums=0),
@@ -191,10 +202,9 @@ def compute_loss(params, meta_params, x, y):
regularization_loss = 0
for p1, p2 in zip(params.values(), meta_params.values()):
regularization_loss += 0.5 * jnp.sum(jnp.square(p1 - p2))
- final_loss = loss + regularization_loss
- return final_loss
+ return loss + regularization_loss
- for i in range(inner_update):
+ for _ in range(inner_update):
grads = jax.grad(compute_loss)(params, meta_params, x, y) # compute gradients
updates, opt_state = optimizer.update(grads, opt_state) # get updates
params = optax.apply_updates(params, updates)
@@ -204,7 +214,8 @@ def compute_loss(params, meta_params, x, y):
xs = xs.to(dtype=dtype)
data = (xs, ys, fmodel)
inner_params = pytree.tree_map(
- lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), params
+ lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad),
+ params,
)
optimal_inner_params, aux = inner_solver_torchopt(inner_params, params, data)
assert aux == (0, {'a': 1, 'b': 2})
@@ -220,8 +231,7 @@ def compute_loss(params, meta_params, x, y):
def outer_level(p, xs, ys):
optimal_params, aux = inner_solver_jax(copy.deepcopy(p), p, xs, ys)
assert aux == (0, {'a': 1, 'b': 2})
- outer_loss = jax_model(optimal_params, xs).mean()
- return outer_loss
+ return jax_model(optimal_params, xs).mean()
grads_jax = jax.grad(outer_level, argnums=0)(jax_params, xs, ys)
updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax) # get updates
@@ -234,6 +244,7 @@ def outer_level(p, xs, ys):
helpers.assert_pytree_all_close(params, jax_params_as_tensor)
+@pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed')
@helpers.parametrize(
dtype=[torch.float64, torch.float32],
lr=[1e-3, 1e-4],
@@ -266,8 +277,7 @@ def imaml_objective_torchopt(params, meta_params, data):
regularization_loss = 0
for p1, p2 in zip(params, meta_params):
regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2))
- loss = F.cross_entropy(y_pred, y) + regularization_loss
- return loss
+ return F.cross_entropy(y_pred, y) + regularization_loss
@torchopt.diff.implicit.custom_root(
functorch.grad(imaml_objective_torchopt, argnums=0),
@@ -320,10 +330,9 @@ def compute_loss(params, meta_params, x, y):
regularization_loss = 0
for p1, p2 in zip(params.values(), meta_params.values()):
regularization_loss += 0.5 * jnp.sum(jnp.square(p1 - p2))
- final_loss = loss + regularization_loss
- return final_loss
+ return loss + regularization_loss
- for i in range(inner_update):
+ for _ in range(inner_update):
grads = jax.grad(compute_loss)(params, meta_params, x, y) # compute gradients
updates, opt_state = optimizer.update(grads, opt_state) # get updates
params = optax.apply_updates(params, updates)
@@ -333,7 +342,8 @@ def compute_loss(params, meta_params, x, y):
xs = xs.to(dtype=dtype)
data = (xs, ys, fmodel)
inner_params = pytree.tree_map(
- lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), params
+ lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad),
+ params,
)
optimal_inner_params = inner_solver_torchopt(inner_params, params, data)
outer_loss = fmodel(optimal_inner_params, xs).mean()
@@ -347,8 +357,7 @@ def compute_loss(params, meta_params, x, y):
def outer_level(p, xs, ys):
optimal_params = inner_solver_jax(copy.deepcopy(p), p, xs, ys)
- outer_loss = jax_model(optimal_params, xs).mean()
- return outer_loss
+ return jax_model(optimal_params, xs).mean()
grads_jax = jax.grad(outer_level, argnums=0)(jax_params, xs, ys)
updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax) # get updates
@@ -361,6 +370,7 @@ def outer_level(p, xs, ys):
helpers.assert_pytree_all_close(params, jax_params_as_tensor)
+@pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed')
@helpers.parametrize(
dtype=[torch.float64, torch.float32],
lr=[1e-3, 1e-4],
@@ -432,10 +442,9 @@ def compute_loss(params, meta_params, x, y):
regularization_loss = 0
for p1, p2 in zip(params.values(), meta_params.values()):
regularization_loss += 0.5 * jnp.sum(jnp.square(p1 - p2))
- final_loss = loss + regularization_loss
- return final_loss
+ return loss + regularization_loss
- for i in range(inner_update):
+ for _ in range(inner_update):
grads = jax.grad(compute_loss)(params, meta_params, x, y) # compute gradients
updates, opt_state = optimizer.update(grads, opt_state) # get updates
params = optax.apply_updates(params, updates)
@@ -458,8 +467,7 @@ def compute_loss(params, meta_params, x, y):
def outer_level(p, xs, ys):
optimal_params, aux = inner_solver_jax(copy.deepcopy(p), p, xs, ys)
assert aux == (0, {'a': 1, 'b': 2})
- outer_loss = jax_model(optimal_params, xs).mean()
- return outer_loss
+ return jax_model(optimal_params, xs).mean()
grads_jax = jax.grad(outer_level, argnums=0)(jax_params, xs, ys)
updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax) # get updates
@@ -472,6 +480,7 @@ def outer_level(p, xs, ys):
helpers.assert_pytree_all_close(tuple(model.parameters()), jax_params_as_tensor)
+@pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed')
@helpers.parametrize(
dtype=[torch.float64, torch.float32],
lr=[1e-3, 1e-4],
@@ -563,8 +572,7 @@ def matvec(u):
def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq):
w_fit = ridge_solver_jax_cg(params_jax, l2reg_jax, xs, ys)
y_pred = xq @ w_fit
- loss_value = jnp.mean(jnp.square(y_pred - yq))
- return loss_value
+ return jnp.mean(jnp.square(y_pred - yq))
grads_jax = jax.grad(outer_level, argnums=1)(init_params_jax, l2reg_jax, xs, ys, xq, yq)
updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax) # get updates
@@ -574,6 +582,7 @@ def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq):
helpers.assert_all_close(l2reg_torch, l2reg_jax_as_tensor)
+@pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed')
@helpers.parametrize(
dtype=[torch.float64, torch.float32],
lr=[1e-3, 1e-4],
@@ -668,8 +677,7 @@ def matvec(u):
def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq):
w_fit = ridge_solver_jax_inv(params_jax, l2reg_jax, xs, ys)
y_pred = xq @ w_fit
- loss_value = jnp.mean(jnp.square(y_pred - yq))
- return loss_value
+ return jnp.mean(jnp.square(y_pred - yq))
grads_jax = jax.grad(outer_level, argnums=1)(init_params_jax, l2reg_jax, xs, ys, xq, yq)
updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax) # get updates
@@ -677,3 +685,184 @@ def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq):
l2reg_jax_as_tensor = torch.tensor(np.asarray(l2reg_jax), dtype=dtype)
helpers.assert_all_close(l2reg_torch, l2reg_jax_as_tensor)
+
+
+def test_module_empty_parameters() -> None:
+ class EmptyParameters(ImplicitMetaGradientModule):
+ def __init__(self, x):
+ super().__init__()
+ self.x = x
+
+ def objective(self):
+ return self.x.mean()
+
+ def solve(self):
+ pass
+
+ model = EmptyParameters(torch.zeros(8))
+ with pytest.raises(RuntimeError, match='The module has no parameters.'):
+ model.solve()
+
+ model = EmptyParameters(torch.zeros(8))
+ model.register_parameter('y', torch.zeros(8, requires_grad=True))
+ with pytest.raises(RuntimeError, match='The module has no meta-parameters.'):
+ model.solve()
+
+ model = EmptyParameters(torch.zeros(8, requires_grad=True))
+ with pytest.raises(RuntimeError, match='The module has no parameters.'):
+ model.solve()
+
+ model = EmptyParameters(torch.zeros(8, requires_grad=True))
+ with pytest.raises(RuntimeError, match='The module has no parameters.'):
+ model.optimality()
+
+ model = EmptyParameters(torch.zeros(8))
+ model.register_parameter('y', torch.zeros(8, requires_grad=True))
+ with pytest.raises(RuntimeError, match='The module has no meta-parameters.'):
+ model.optimality()
+
+ model = EmptyParameters(torch.zeros(8, requires_grad=True))
+ model.register_parameter('y', torch.zeros(8, requires_grad=True))
+ model.solve()
+
+ model = EmptyParameters(nn.Linear(8, 8).eval())
+ with pytest.raises(RuntimeError, match='The module has no meta-parameters.'):
+ model.solve()
+
+ model = EmptyParameters(nn.Linear(8, 8))
+ model.register_parameter('y', torch.zeros(8, requires_grad=True))
+ model.solve()
+
+
+def test_module_enable_implicit_gradients_twice() -> None:
+ class MyModule1(torchopt.nn.ImplicitMetaGradientModule):
+ def objective(self):
+ return torch.tensor(0.0)
+
+ def solve(self):
+ pass
+
+ from torchopt.diff.implicit.nn.module import (
+ enable_implicit_gradients,
+ make_optimality_from_objective,
+ )
+
+ with pytest.raises(
+ TypeError,
+ match='Implicit gradients are already enabled for the `solve` method.',
+ ):
+ enable_implicit_gradients(MyModule1)
+
+ class MyModule2(torchopt.nn.ImplicitMetaGradientModule):
+ def optimality(self):
+ return torch.tensor(0.0)
+
+ def solve(self):
+ pass
+
+ with pytest.raises(
+ TypeError,
+ match='The objective function is not defined.',
+ ):
+ make_optimality_from_objective(MyModule2)
+
+
+def test_module_abstract_methods() -> None:
+ class MyModule1(torchopt.nn.ImplicitMetaGradientModule):
+ def objective(self):
+ return torch.tensor(0.0)
+
+ with pytest.raises(TypeError, match="Can't instantiate abstract class"):
+ MyModule1()
+
+ with pytest.raises(
+ TypeError,
+ match=re.escape(
+ 'ImplicitMetaGradientModule requires either an optimality() method or an objective() method',
+ ),
+ ):
+
+ class MyModule2(torchopt.nn.ImplicitMetaGradientModule):
+ def solve(self):
+ pass
+
+ class MyModule3(torchopt.nn.ImplicitMetaGradientModule):
+ def optimality(self):
+ return ()
+
+ def solve(self):
+ pass
+
+ with pytest.raises(
+ TypeError,
+ match=re.escape('method optimality() must not be a staticmethod.'),
+ ):
+
+ class MyModule4(torchopt.nn.ImplicitMetaGradientModule):
+ @staticmethod
+ def optimality():
+ return ()
+
+ def solve(self):
+ pass
+
+ with pytest.raises(
+ TypeError,
+ match=re.escape('method optimality() must not be a classmethod.'),
+ ):
+
+ class MyModule5(torchopt.nn.ImplicitMetaGradientModule):
+ @classmethod
+ def optimality(self):
+ return ()
+
+ def solve(self):
+ pass
+
+ with pytest.raises(
+ TypeError,
+ match=re.escape('method optimality() must be callable.'),
+ ):
+
+ class MyModule6(torchopt.nn.ImplicitMetaGradientModule):
+ optimality = 0
+
+ def solve(self):
+ pass
+
+ with pytest.raises(
+ TypeError,
+ match=re.escape('method objective() must not be a staticmethod.'),
+ ):
+
+ class MyModule7(torchopt.nn.ImplicitMetaGradientModule):
+ @staticmethod
+ def objective():
+ return ()
+
+ def solve(self):
+ pass
+
+ with pytest.raises(
+ TypeError,
+ match=re.escape('method objective() must not be a classmethod.'),
+ ):
+
+ class MyModule8(torchopt.nn.ImplicitMetaGradientModule):
+ @classmethod
+ def objective(self):
+ return ()
+
+ def solve(self):
+ pass
+
+ with pytest.raises(
+ TypeError,
+ match=re.escape('method objective() must be callable.'),
+ ):
+
+ class MyModule9(torchopt.nn.ImplicitMetaGradientModule):
+ objective = 0
+
+ def solve(self):
+ pass
diff --git a/tests/test_import.py b/tests/test_import.py
index 30cf914e..1b6dea38 100644
--- a/tests/test_import.py
+++ b/tests/test_import.py
@@ -25,6 +25,7 @@ def test_accelerated_op_import() -> None:
def test_alias_import() -> None:
+ torchopt.adagrad
torchopt.adam
torchopt.adamw
torchopt.rmsprop
@@ -33,8 +34,8 @@ def test_alias_import() -> None:
torchopt.alias.adamw
torchopt.alias.rmsprop
torchopt.alias.sgd
- from torchopt import adam, adamw, rmsprop, sgd
- from torchopt.alias import adam, adamw, rmsprop, sgd
+ from torchopt import adagrad, adam, adamw, rmsprop, sgd
+ from torchopt.alias import adagrad, adam, adamw, rmsprop, sgd
def test_diff_import() -> None:
@@ -107,17 +108,23 @@ def test_nn_import() -> None:
def test_optim_import() -> None:
torchopt.FuncOptimizer
+ torchopt.MetaAdaGrad
+ torchopt.MetaAdagrad
torchopt.MetaAdam
torchopt.MetaAdamW
torchopt.MetaRMSProp
torchopt.MetaRMSprop
torchopt.MetaSGD
+ torchopt.AdaGrad
+ torchopt.Adagrad
torchopt.Adam
torchopt.AdamW
torchopt.Optimizer
torchopt.RMSProp
torchopt.RMSprop
torchopt.SGD
+ torchopt.optim.meta.MetaAdaGrad
+ torchopt.optim.meta.MetaAdagrad
torchopt.optim.meta.MetaAdam
torchopt.optim.meta.MetaAdamW
torchopt.optim.meta.MetaRMSProp
@@ -132,14 +139,18 @@ def test_optim_import() -> None:
torchopt.optim.func.FuncOptimizer
from torchopt import (
SGD,
+ AdaGrad,
+ Adagrad,
Adam,
AdamW,
FuncOptimizer,
+ MetaAdaGrad,
+ MetaAdagrad,
MetaAdam,
MetaAdamW,
MetaOptimizer,
- MetaRMSProp,
MetaRMSprop,
+ MetaRMSProp,
MetaSGD,
Optimizer,
RMSProp,
@@ -147,6 +158,8 @@ def test_optim_import() -> None:
from torchopt.optim import SGD, Adam, AdamW, FuncOptimizer, Optimizer, RMSProp
from torchopt.optim.func import FuncOptimizer
from torchopt.optim.meta import (
+ MetaAdaGrad,
+ MetaAdagrad,
MetaAdam,
MetaAdamW,
MetaOptimizer,
diff --git a/tests/test_nn.py b/tests/test_nn.py
index 1b48c06b..8e89bdb5 100644
--- a/tests/test_nn.py
+++ b/tests/test_nn.py
@@ -69,7 +69,13 @@ def test_register_tensors() -> None:
assert m._meta_parameters['x'] is x
assert m._parameters['y'] is y
- assert hasattr(m, 'z') and m.z is z and 'z' not in m._buffers
+ assert (
+ hasattr(m, 'z')
+ and m.z is z
+ and 'z' not in m._meta_parameters
+ and 'z' not in m._parameters
+ and 'z' not in m._buffers
+ )
del m.x
object.__setattr__(m, 'x', x)
@@ -82,56 +88,122 @@ def test_register_tensors() -> None:
m.b = b
assert m.b is b and 'b' in m._buffers
+ m = torchopt.nn.MetaGradientModule(x, b)
+
+ with pytest.raises(
+ TypeError,
+ match=re.escape('parameter name should be a string. Got bytes'),
+ ):
+ m.register_meta_parameter(b'x', x)
+
+ with pytest.raises(
+ KeyError,
+ match=re.escape("parameter name can't contain '.'"),
+ ):
+ m.register_meta_parameter('x.x', x)
+
+ with pytest.raises(
+ KeyError,
+ match=re.escape("parameter name can't be empty string ''"),
+ ):
+ m.register_meta_parameter('', x)
+
+ m.register_buffer('z', None)
+ with pytest.raises(
+ KeyError,
+ match=re.escape("attribute 'z' already exists"),
+ ):
+ m.register_meta_parameter('z', x)
+
+ with pytest.raises(
+ ValueError,
+ match=re.escape(
+ "cannot assign Tensor that is a meta-parameter to parameter 'x'. "
+ 'Use self.register_meta_parameter() instead.',
+ ),
+ ):
+ m.register_parameter('x', x)
+
+ m.x = x
+ with pytest.raises(
+ KeyError,
+ match=re.escape("attribute 'x' already exists"),
+ ):
+ m.register_parameter('x', x)
+
+ with pytest.raises(
+ TypeError,
+ match=re.escape('parameter name should be a string. Got bytes'),
+ ):
+ m.register_parameter(b'y', y)
+
+ with pytest.raises(
+ KeyError,
+ match=re.escape("parameter name can't contain '.'"),
+ ):
+ m.register_parameter('y.x', y)
+
+ with pytest.raises(
+ KeyError,
+ match=re.escape("parameter name can't be empty string ''"),
+ ):
+ m.register_parameter('', y)
+
def test_no_super_init() -> None:
class NoSuper1(torchopt.nn.MetaGradientModule):
- def __init__(self, x):
+ def __init__(self, x) -> None:
self.x = x
with pytest.raises(
- AttributeError, match=re.escape('cannot assign parameters before Module.__init__() call')
+ AttributeError,
+ match=re.escape('cannot assign parameters before Module.__init__() call'),
):
NoSuper1(torch.tensor(1.0, requires_grad=True))
class NoSuper2(torchopt.nn.MetaGradientModule):
- def __init__(self):
+ def __init__(self) -> None:
self.x = torch.tensor(1.0, requires_grad=True)
with pytest.raises(
- AttributeError, match=re.escape('cannot assign parameters before Module.__init__() call')
+ AttributeError,
+ match=re.escape('cannot assign parameters before Module.__init__() call'),
):
NoSuper2()
class NoSuper3(torchopt.nn.MetaGradientModule):
- def __init__(self):
+ def __init__(self) -> None:
self.register_buffer('x', torch.tensor(1.0))
with pytest.raises(
- AttributeError, match=re.escape('cannot assign buffer before Module.__init__() call')
+ AttributeError,
+ match=re.escape('cannot assign buffer before Module.__init__() call'),
):
NoSuper3()
class NoSuper4(torchopt.nn.MetaGradientModule):
- def __init__(self):
+ def __init__(self) -> None:
self.x = torch.tensor(1.0, requires_grad=False)
NoSuper4() # no error
class NoSuper5(torchopt.nn.MetaGradientModule):
- def __init__(self, x):
+ def __init__(self, x) -> None:
self.x = x
with pytest.raises(
- AttributeError, match=re.escape('cannot assign module before Module.__init__() call')
+ AttributeError,
+ match=re.escape('cannot assign module before Module.__init__() call'),
):
NoSuper5(nn.Linear(1, 1))
class NoSuper6(torchopt.nn.MetaGradientModule):
- def __init__(self):
+ def __init__(self) -> None:
self.x = nn.Linear(1, 1)
with pytest.raises(
- AttributeError, match=re.escape('cannot assign module before Module.__init__() call')
+ AttributeError,
+ match=re.escape('cannot assign module before Module.__init__() call'),
):
NoSuper6()
diff --git a/tests/test_optim.py b/tests/test_optim.py
index b2be7500..6ec81918 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -268,6 +268,63 @@ def test_Adam_accelerated_cuda(
helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype)
+@helpers.parametrize(
+ dtype=[torch.float64],
+ lr=[1e-2, 1e-3, 1e-4],
+ lr_decay=[0.0, 1e-2],
+ initial_accumulator_value=[0.0, 1e-1],
+ eps=[1e-8],
+ weight_decay=[0.0, 1e-2],
+ maximize=[False, True],
+)
+def test_AdaGrad(
+ dtype: torch.dtype,
+ lr: float,
+ lr_decay: float,
+ initial_accumulator_value: float,
+ eps: float,
+ weight_decay: float,
+ maximize: bool,
+) -> None:
+ model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)
+
+ optim = torchopt.AdaGrad(
+ model.parameters(),
+ lr=lr,
+ lr_decay=lr_decay,
+ weight_decay=weight_decay,
+ initial_accumulator_value=initial_accumulator_value,
+ eps=eps,
+ maximize=maximize,
+ )
+ optim_ref = torch.optim.Adagrad(
+ model_ref.parameters(),
+ lr=lr,
+ lr_decay=lr_decay,
+ weight_decay=weight_decay,
+ initial_accumulator_value=initial_accumulator_value,
+ eps=eps,
+ maximize=maximize,
+ )
+
+ for xs, ys in loader:
+ xs = xs.to(dtype=dtype)
+ pred = model(xs)
+ pred_ref = model_ref(xs)
+ loss = F.cross_entropy(pred, ys)
+ loss_ref = F.cross_entropy(pred_ref, ys)
+
+ optim.zero_grad()
+ loss.backward()
+ optim.step()
+
+ optim_ref.zero_grad()
+ loss_ref.backward()
+ optim_ref.step()
+
+ helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype)
+
+
@helpers.parametrize(
dtype=[torch.float64],
lr=[1e-2, 1e-3, 1e-4],
@@ -330,10 +387,11 @@ def test_RMSProp(
dtype=[torch.float64, torch.float32],
lr=[1e-2, 1e-3],
optimizers=[
- (torchopt.sgd, torch.optim.SGD),
- (torchopt.adam, torch.optim.Adam),
- (torchopt.adamw, torch.optim.AdamW),
- (torchopt.rmsprop, torch.optim.RMSprop),
+ (torchopt.sgd, torch.optim.SGD, {}),
+ (torchopt.adam, torch.optim.Adam, {}),
+ (torchopt.adamw, torch.optim.AdamW, {}),
+ (torchopt.adagrad, torch.optim.Adagrad, {'eps': 1e-8}),
+ (torchopt.rmsprop, torch.optim.RMSprop, {}),
],
inplace=[True, False],
weight_decay=[0.0, 1e-2],
@@ -347,13 +405,14 @@ def test_FuncOptimizer(
) -> None:
model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)
- torchopt_optimizer, torch_optimizer = optimizers
+ torchopt_optimizer, torch_optimizer, optimizer_kwargs = optimizers
fmodel, params, buffers = functorch.make_functional_with_buffers(model)
optim = torchopt.FuncOptimizer(
torchopt_optimizer(
lr=lr,
weight_decay=weight_decay,
+ **optimizer_kwargs,
),
inplace=inplace,
)
@@ -361,6 +420,7 @@ def test_FuncOptimizer(
model_ref.parameters(),
lr,
weight_decay=weight_decay,
+ **optimizer_kwargs,
)
for xs, ys in loader:
diff --git a/tests/test_pytree.py b/tests/test_pytree.py
index 5594e30b..d82d81f2 100644
--- a/tests/test_pytree.py
+++ b/tests/test_pytree.py
@@ -177,12 +177,14 @@ def test_tree_vdot_real() -> None:
helpers.assert_pytree_all_close(actual, expected)
tree_a_complex, tree_b_complex = pytree.tree_map(
- lambda x: torch.randn(x.size(), dtype=torch.cfloat), (tree_a, tree_b)
+ lambda x: torch.randn(x.size(), dtype=torch.cfloat),
+ (tree_a, tree_b),
)
expected = (
torch.vdot(tree_a_complex[0].contiguous().view(-1), tree_b_complex[0].contiguous().view(-1))
+ torch.vdot(
- tree_a_complex[1].contiguous().view(-1), tree_b_complex[1].contiguous().view(-1)
+ tree_a_complex[1].contiguous().view(-1),
+ tree_b_complex[1].contiguous().view(-1),
)
).real
actual = torch.tensor(pytree.tree_vdot_real(tree_a_complex, tree_b_complex))
@@ -197,14 +199,15 @@ def test_tree_vdot_real() -> None:
'tree_b_dict',
'tensor_a',
'tensor_b',
- ]
+ ],
)
def test_tree_wait(tree_name: str) -> None:
tree = globals()[tree_name]
future_tree = pytree.tree_map(lambda x: torch.futures.Future(), tree)
new_future_tree = pytree.tree_map(
- lambda fut: fut.then(lambda f: torch.square(f.wait()) + 1.0), future_tree
+ lambda fut: fut.then(lambda f: torch.square(f.wait()) + 1.0),
+ future_tree,
)
pytree.tree_map_(lambda fut, x: fut.set_result(x), future_tree, tree)
diff --git a/tests/test_schedule.py b/tests/test_schedule.py
index ae714875..1fdc4669 100644
--- a/tests/test_schedule.py
+++ b/tests/test_schedule.py
@@ -15,7 +15,7 @@
from __future__ import annotations
-from typing import Callable
+from typing import Any, Callable
import functorch
import numpy as np
@@ -27,6 +27,43 @@
from torchopt.alias.utils import _set_use_chain_flat
+@helpers.parametrize(
+ init_value=[1.0, 1e-1],
+ decay_rate=[1e-2, 1e-3],
+ transition_begin=[1, 5],
+ transition_steps=[10, 100],
+ staircase=[False, True],
+ end_value=[0.0, None, 8e-1],
+)
+def test_exponential_decay(
+ init_value: float,
+ decay_rate: float,
+ transition_begin: int,
+ transition_steps: int | None,
+ staircase: bool,
+ end_value: float | None,
+) -> None:
+ schedule = torchopt.schedule.exponential_decay(
+ init_value=init_value,
+ decay_rate=decay_rate,
+ transition_steps=transition_steps,
+ transition_begin=transition_begin,
+ staircase=staircase,
+ end_value=end_value,
+ )
+ if end_value is not None:
+ clip_fn = max if decay_rate < 1.0 else min
+ for i in range(transition_begin, transition_steps):
+ lr = schedule(i)
+ if staircase:
+ lr_gt = init_value * (decay_rate ** np.floor((i - transition_begin) / transition_steps))
+ else:
+ lr_gt = init_value * (decay_rate ** ((i - transition_begin) / transition_steps))
+ if end_value is not None:
+ lr_gt = clip_fn(lr_gt, end_value)
+ assert np.allclose(lr, lr_gt)
+
+
def test_linear_schedule() -> None:
init_value = 1.0
end_value = 0.0
@@ -51,10 +88,11 @@ def test_linear_schedule() -> None:
lr=[1e-2, 1e-3],
total_iters=[helpers.NUM_UPDATES, helpers.NUM_UPDATES * 2],
optimizers=[
- (torchopt.sgd, torch.optim.SGD),
- (torchopt.adam, torch.optim.Adam),
- (torchopt.adamw, torch.optim.AdamW),
- (torchopt.rmsprop, torch.optim.RMSprop),
+ (torchopt.sgd, torch.optim.SGD, {}),
+ (torchopt.adam, torch.optim.Adam, {}),
+ (torchopt.adamw, torch.optim.AdamW, {}),
+ (torchopt.adagrad, torch.optim.Adagrad, {'eps': 1e-8}),
+ (torchopt.rmsprop, torch.optim.RMSprop, {}),
],
inplace=[True, False],
weight_decay=[0.0, 1e-2],
@@ -64,7 +102,7 @@ def test_lr_linear_schedule(
dtype: torch.dtype,
lr: float,
total_iters: int,
- optimizers: tuple[Callable, torch.optim.Optimizer],
+ optimizers: tuple[Callable, torch.optim.Optimizer, dict[str, Any]],
inplace: bool,
weight_decay: float,
use_chain_flat: bool,
@@ -73,23 +111,31 @@ def test_lr_linear_schedule(
model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)
- torchopt_optimizer, torch_optimizer = optimizers
+ torchopt_optimizer, torch_optimizer, optimizer_kwargs = optimizers
fmodel, params, buffers = functorch.make_functional_with_buffers(model)
optim = torchopt_optimizer(
torchopt.schedule.linear_schedule(
- init_value=lr, end_value=0.1 * lr, transition_steps=total_iters, transition_begin=0
+ init_value=lr,
+ end_value=0.1 * lr,
+ transition_steps=total_iters,
+ transition_begin=0,
),
weight_decay=weight_decay,
+ **optimizer_kwargs,
)
optim_state = optim.init(params)
optim_ref = torch_optimizer(
model_ref.parameters(),
lr,
weight_decay=weight_decay,
+ **optimizer_kwargs,
)
torch_scheduler = torch.optim.lr_scheduler.LinearLR(
- optim_ref, start_factor=1.0, end_factor=0.1, total_iters=total_iters
+ optim_ref,
+ start_factor=1.0,
+ end_factor=0.1,
+ total_iters=total_iters,
)
for xs, ys in loader:
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 0c80cec0..d1be7c6f 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -96,7 +96,7 @@ def test_extract_state_dict():
state_dict = torchopt.extract_state_dict(fc, by='deepcopy', device=torch.device('meta'))
for param_dict in state_dict.params:
- for k, v in param_dict.items():
+ for v in param_dict.values():
assert v.is_meta
assert v.grad_fn is None
diff --git a/tests/test_zero_order.py b/tests/test_zero_order.py
index ac7ae840..61f75f9a 100644
--- a/tests/test_zero_order.py
+++ b/tests/test_zero_order.py
@@ -14,6 +14,7 @@
# ==============================================================================
import functorch
+import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -58,17 +59,20 @@ def test_zero_order(lr: float, method: str, sigma: float) -> None:
distribution = torch.distributions.Normal(loc=0, scale=1)
@torchopt.diff.zero_order(
- distribution=distribution, method=method, argnums=0, sigma=sigma, num_samples=num_samples
+ distribution=distribution,
+ method=method,
+ argnums=0,
+ sigma=sigma,
+ num_samples=num_samples,
)
def forward_process(params, fn, x, y):
y_pred = fn(params, x)
- loss = F.mse_loss(y_pred, y)
- return loss
+ return F.mse_loss(y_pred, y)
optimizer = torchopt.adam(lr=lr)
opt_state = optimizer.init(params) # init optimizer
- for i in range(num_iterations):
+ for _ in range(num_iterations):
loss = forward_process(params, fmodel, x, y) # compute loss
grads = torch.autograd.grad(loss, params) # compute gradients
@@ -91,7 +95,10 @@ def test_zero_order_module(lr: float, method: str, sigma: float) -> None:
num_samples = 500
class FcNetWithLoss(
- torchopt.nn.ZeroOrderGradientModule, method=method, sigma=sigma, num_samples=num_samples
+ torchopt.nn.ZeroOrderGradientModule,
+ method=method,
+ sigma=sigma,
+ num_samples=num_samples,
):
def __init__(self, dim, out):
super().__init__()
@@ -102,7 +109,7 @@ def __init__(self, dim, out):
def forward(self, x, y):
return self.loss(self.net(x), y)
- def sample(self, sample_shape=torch.Size()):
+ def sample(self, sample_shape=torch.Size()): # noqa: B008
return self.distribution.sample(sample_shape)
x = torch.randn(batch_size, input_size) * coef
@@ -111,9 +118,55 @@ def sample(self, sample_shape=torch.Size()):
optimizer = torchopt.Adam(model_with_loss.parameters(), lr=lr)
- for i in range(num_iterations):
+ for _ in range(num_iterations):
loss = model_with_loss(x, y) # compute loss
optimizer.zero_grad()
loss.backward() # compute gradients
optimizer.step() # update network parameters
+
+
+def test_module_enable_zero_order_gradients_twice() -> None:
+ class MyModule(torchopt.nn.ZeroOrderGradientModule):
+ def forward(self):
+ return torch.tensor(0.0)
+
+ def sample(self, sample_shape):
+ return torch.tensor(0.0)
+
+ from torchopt.diff.zero_order.nn.module import enable_zero_order_gradients
+
+ with pytest.raises(
+ TypeError,
+ match='Zero-order gradient estimation is already enabled for the `forward` method.',
+ ):
+ enable_zero_order_gradients(MyModule)
+
+
+def test_module_empty_parameters() -> None:
+ class MyModule(torchopt.nn.ZeroOrderGradientModule):
+ def forward(self):
+ return torch.tensor(0.0)
+
+ def sample(self, sample_shape):
+ return torch.tensor(0.0)
+
+ m = MyModule()
+ with pytest.raises(RuntimeError, match='The module has no parameters.'):
+ m()
+
+
+def test_module_abstract_methods() -> None:
+ class MyModule1(torchopt.nn.ZeroOrderGradientModule):
+ def forward(self):
+ return torch.tensor(0.0)
+
+ with pytest.raises(TypeError, match="Can't instantiate abstract class"):
+ MyModule1()
+
+ class MyModule2(torchopt.nn.ZeroOrderGradientModule):
+ def sample(self, sample_shape):
+ return torch.tensor(0.0)
+
+ with pytest.raises(TypeError, match="Can't instantiate abstract class"):
+ MyModule2()
diff --git a/torchopt/_C/adam_op.pyi b/torchopt/_C/adam_op.pyi
index 7ecfe7c2..04f141fd 100644
--- a/torchopt/_C/adam_op.pyi
+++ b/torchopt/_C/adam_op.pyi
@@ -15,8 +15,6 @@
# pylint: disable=all
-from __future__ import annotations
-
import torch
def forward_(
@@ -41,10 +39,16 @@ def forward_updates(
count: int,
) -> torch.Tensor: ...
def backward_mu(
- dmu: torch.Tensor, updates: torch.Tensor, mu: torch.Tensor, b1: float
+ dmu: torch.Tensor,
+ updates: torch.Tensor,
+ mu: torch.Tensor,
+ b1: float,
) -> tuple[torch.Tensor, torch.Tensor]: ...
def backward_nu(
- dnu: torch.Tensor, updates: torch.Tensor, nu: torch.Tensor, b2: float
+ dnu: torch.Tensor,
+ updates: torch.Tensor,
+ nu: torch.Tensor,
+ b2: float,
) -> tuple[torch.Tensor, torch.Tensor]: ...
def backward_updates(
dupdates: torch.Tensor,
diff --git a/torchopt/__init__.py b/torchopt/__init__.py
index 0c36ac07..a8c9fa1d 100644
--- a/torchopt/__init__.py
+++ b/torchopt/__init__.py
@@ -33,13 +33,15 @@
visual,
)
from torchopt.accelerated_op import is_available as accelerated_op_available
-from torchopt.alias import adam, adamw, rmsprop, sgd
+from torchopt.alias import adagrad, adam, adamw, rmsprop, sgd
from torchopt.clip import clip_grad_norm
from torchopt.combine import chain
from torchopt.hook import register_hook
-from torchopt.optim import SGD, Adam, AdamW, Optimizer, RMSProp, RMSprop
+from torchopt.optim import SGD, AdaGrad, Adagrad, Adam, AdamW, Optimizer, RMSProp, RMSprop
from torchopt.optim.func import FuncOptimizer
from torchopt.optim.meta import (
+ MetaAdaGrad,
+ MetaAdagrad,
MetaAdam,
MetaAdamW,
MetaOptimizer,
@@ -63,6 +65,7 @@
'accelerated_op_available',
'adam',
'adamw',
+ 'adagrad',
'rmsprop',
'sgd',
'clip_grad_norm',
@@ -73,12 +76,16 @@
'SGD',
'Adam',
'AdamW',
+ 'AdaGrad',
+ 'Adagrad',
'RMSProp',
'RMSprop',
'MetaOptimizer',
'MetaSGD',
'MetaAdam',
'MetaAdamW',
+ 'MetaAdaGrad',
+ 'MetaAdagrad',
'MetaRMSProp',
'MetaRMSprop',
'FuncOptimizer',
diff --git a/torchopt/accelerated_op/__init__.py b/torchopt/accelerated_op/__init__.py
index ede60009..3ac943e3 100644
--- a/torchopt/accelerated_op/__init__.py
+++ b/torchopt/accelerated_op/__init__.py
@@ -43,5 +43,5 @@ def is_available(devices: Device | Iterable[Device] | None = None) -> bool:
updates = torch.tensor(1.0, device=device)
op(updates, updates, updates, 1)
return True
- except Exception: # pylint: disable=broad-except
+ except Exception: # noqa: BLE001 # pylint: disable=broad-except
return False
diff --git a/torchopt/accelerated_op/_src/adam_op.py b/torchopt/accelerated_op/_src/adam_op.py
index ab5ea195..c8fc8898 100644
--- a/torchopt/accelerated_op/_src/adam_op.py
+++ b/torchopt/accelerated_op/_src/adam_op.py
@@ -36,8 +36,8 @@ def forward_(
nu = nu.mul_(b2).addcmul_(updates, updates, value=1.0 - b2)
updates.copy_(
mu.div(1.0 - pow(b1, count)).div_(
- nu.div(1.0 - pow(b2, count)).add_(eps_root).sqrt_().add_(eps)
- )
+ nu.div(1.0 - pow(b2, count)).add_(eps_root).sqrt_().add_(eps),
+ ),
)
return updates, mu, nu
@@ -71,7 +71,7 @@ def forward_updates(
) -> torch.Tensor:
"""Adam forward updates."""
return new_mu.div(1.0 - pow(b1, count)).div_(
- new_nu.div(1.0 - pow(b2, count)).add_(eps_root).sqrt_().add_(eps)
+ new_nu.div(1.0 - pow(b2, count)).add_(eps_root).sqrt_().add_(eps),
)
diff --git a/torchopt/accelerated_op/adam_op.py b/torchopt/accelerated_op/adam_op.py
index 232513d6..d6f9e9f9 100644
--- a/torchopt/accelerated_op/adam_op.py
+++ b/torchopt/accelerated_op/adam_op.py
@@ -109,7 +109,14 @@ def backward(ctx: Any, *args: Any) -> Any:
updates, new_mu, new_nu = ctx.saved_tensors
b1, b2, _, eps_root, count = ctx.others
result = adam_op.backward_updates(
- dupdates, updates, new_mu, new_nu, b1, b2, eps_root, count
+ dupdates,
+ updates,
+ new_mu,
+ new_nu,
+ b1,
+ b2,
+ eps_root,
+ count,
)
return result[0], result[1], None
diff --git a/torchopt/alias/__init__.py b/torchopt/alias/__init__.py
index b00b3c35..ae7dd2b5 100644
--- a/torchopt/alias/__init__.py
+++ b/torchopt/alias/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,10 +31,11 @@
# ==============================================================================
r"""The aliases of preset :class:`GradientTransformation`\s for optimizers."""
+from torchopt.alias.adagrad import adagrad
from torchopt.alias.adam import adam
from torchopt.alias.adamw import adamw
from torchopt.alias.rmsprop import rmsprop
from torchopt.alias.sgd import sgd
-__all__ = ['adam', 'adamw', 'rmsprop', 'sgd']
+__all__ = ['adagrad', 'adam', 'adamw', 'rmsprop', 'sgd']
diff --git a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py
new file mode 100644
index 00000000..25910abd
--- /dev/null
+++ b/torchopt/alias/adagrad.py
@@ -0,0 +1,166 @@
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# This file is modified from:
+# https://github.com/deepmind/optax/blob/master/optax/_src/alias.py
+# ==============================================================================
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Preset :class:`GradientTransformation` for the AdaGrad optimizer."""
+
+import logging
+
+from torchopt.alias.utils import (
+ _get_use_chain_flat,
+ flip_sign_and_add_weight_decay,
+ scale_by_neg_lr,
+)
+from torchopt.combine import chain
+from torchopt.transform import scale_by_rss, scale_by_schedule
+from torchopt.typing import GradientTransformation, Numeric, Scalar, ScalarOrSchedule, Schedule
+
+
+__all__ = ['adagrad']
+
+
+def _adagrad_lr_schedule(
+ decay_rate: Scalar,
+ transition_begin: int = 0,
+) -> Schedule:
+ """Construct a schedule dedicated to AdaGrad optimizer.
+
+ This function applies an learning rate decay function to a provided initial value. The function
+ returns the decayed value as follows:
+
+ .. code-block:: python
+
+ decayed_value = init_value / (1 + count * decay_rate)
+
+ Args:
+ decay_rate (float): The decay rate.
+ transition_begin (int, optional): Must be *positive*. After how many steps to start
+ annealing. (default: :const:`0`)
+
+ Returns:
+ schedule: A function that maps step counts to values.
+ """
+ if transition_begin < 0: # pragma: no cover
+ logging.info(
+ 'The AdaGrad learning rate schedule was set with a negative `transition_begin` '
+ 'value; this will result in `transition_begin` falling back to `0`.',
+ )
+ transition_begin = 0
+
+ def schedule(count: Numeric) -> Numeric:
+ decreased_count = count - transition_begin
+ return 1 / (1 + decay_rate * decreased_count)
+
+ return schedule
+
+
+# pylint: disable-next=too-many-arguments
+def adagrad(
+ lr: ScalarOrSchedule = 1e-2,
+ lr_decay: float = 0.0,
+ weight_decay: float = 0.0,
+ initial_accumulator_value: float = 0.0,
+ eps: float = 1e-10,
+ *,
+ maximize: bool = False,
+) -> GradientTransformation:
+ """The functional AdaGrad optimizer.
+
+ AdaGrad is an algorithm for gradient based optimization that anneals the learning rate for each
+ parameter during the course of training.
+
+ .. warning::
+ AdaGrad's main limit is the monotonic accumulation of squared gradients in the denominator.
+ Since all terms are ``> 0``, the sum keeps growing during training, and the learning rate
+ eventually becomes very small.
+
+ References:
+ Duchi et al., 2011: https://jmlr.org/papers/v12/duchi11a.html
+
+ Args:
+ lr (float or callable, optional): This is a fixed global scaling factor or a learning rate
+ scheduler. (default: :const:`1e-2`)
+ lr_decay (float, optional): Learning rate decay. (default: :const:`0.0`)
+ weight_decay (float, optional): Weight decay, add L2 penalty to parameters.
+ (default: :const:`0.0`)
+ initial_accumulator_value (float, optional): Initial value for the accumulator.
+ (default: :const:`0.0`)
+ eps (float, optional): A small constant applied to denominator outside of the square root
+ (as in the Adam paper) to avoid dividing by zero when rescaling.
+ (default: :const:`1e-10`)
+ maximize (bool, optional): Maximize the params based on the objective, instead of minimizing.
+ (default: :data:`False`)
+
+ Returns:
+ The corresponding :class:`GradientTransformation` instance.
+
+ See Also:
+ The functional optimizer wrapper :class:`torchopt.FuncOptimizer`.
+ """
+ # pylint: disable=unneeded-not
+ if not (callable(lr) or lr >= 0.0): # pragma: no cover
+ raise ValueError(f'Invalid learning rate: {lr}')
+ if not lr_decay >= 0.0: # pragma: no cover
+ raise ValueError(f'Invalid lr_decay value: {lr_decay}')
+ if not weight_decay >= 0.0: # pragma: no cover
+ raise ValueError(f'Invalid weight_decay value: {weight_decay}')
+ if not initial_accumulator_value >= 0.0: # pragma: no cover
+ raise ValueError(f'Invalid initial_accumulator_value value: {initial_accumulator_value}')
+ if not eps >= 0.0: # pragma: no cover
+ raise ValueError(f'Invalid epsilon value: {eps}')
+ # pylint: enable=unneeded-not
+
+ chain_fn = chain
+ flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay
+ adagrad_scaler_fn = scale_by_rss
+ scale_by_neg_lr_fn = scale_by_neg_lr
+ scale_by_schedule_fn = scale_by_schedule
+
+ if _get_use_chain_flat(): # default behavior
+ chain_fn = chain_fn.flat # type: ignore[attr-defined]
+ flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined]
+ adagrad_scaler_fn = adagrad_scaler_fn.flat # type: ignore[attr-defined]
+ scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined]
+ scale_by_schedule_fn = scale_by_schedule_fn.flat # type: ignore[attr-defined]
+
+ return chain_fn(
+ flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=maximize),
+ adagrad_scaler_fn(
+ initial_accumulator_value=initial_accumulator_value,
+ eps=eps,
+ ),
+ scale_by_schedule_fn(
+ step_size_fn=_adagrad_lr_schedule(
+ decay_rate=lr_decay,
+ transition_begin=0,
+ ),
+ ),
+ scale_by_neg_lr_fn(lr),
+ )
diff --git a/torchopt/alias/adam.py b/torchopt/alias/adam.py
index 08654577..dc889285 100644
--- a/torchopt/alias/adam.py
+++ b/torchopt/alias/adam.py
@@ -65,7 +65,7 @@ def adam(
exponential moving averages).
References:
- - Kingma et al, 2014: https://arxiv.org/abs/1412.6980
+ - Kingma et al., 2014: https://arxiv.org/abs/1412.6980
Args:
lr (float or callable, optional): This is a fixed global scaling factor or a learning rate
@@ -96,24 +96,21 @@ def adam(
"""
b1, b2 = betas # pylint: disable=invalid-name
# pylint: disable=unneeded-not
- if not (callable(lr) or 0.0 <= lr): # pragma: no cover
+ if not (callable(lr) or lr >= 0.0): # pragma: no cover
raise ValueError(f'Invalid learning rate: {lr}')
- if not 0.0 <= eps: # pragma: no cover
+ if not eps >= 0.0: # pragma: no cover
raise ValueError(f'Invalid epsilon value: {eps}')
if not 0.0 <= b1 < 1.0: # pragma: no cover
raise ValueError(f'Invalid beta parameter at index 0: {b1}')
if not 0.0 <= b2 < 1.0: # pragma: no cover
raise ValueError(f'Invalid beta parameter at index 1: {b2}')
- if not 0.0 <= weight_decay: # pragma: no cover
+ if not weight_decay >= 0.0: # pragma: no cover
raise ValueError(f'Invalid weight_decay value: {weight_decay}')
# pylint: enable=unneeded-not
chain_fn = chain
flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay
- if use_accelerated_op:
- adam_scaler_fn = scale_by_accelerated_adam
- else:
- adam_scaler_fn = scale_by_adam
+ adam_scaler_fn = scale_by_accelerated_adam if use_accelerated_op else scale_by_adam
scale_by_neg_lr_fn = scale_by_neg_lr
if _get_use_chain_flat(): # default behavior
diff --git a/torchopt/alias/adamw.py b/torchopt/alias/adamw.py
index 21ef84ef..e8bed2ab 100644
--- a/torchopt/alias/adamw.py
+++ b/torchopt/alias/adamw.py
@@ -69,7 +69,7 @@ def adamw(
does not behave as intended for adaptive gradient algorithms such as Adam.
References:
- - Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101
+ - Loshchilov et al., 2019: https://arxiv.org/abs/1711.05101
Args:
lr (float or callable, optional): This is a fixed global scaling factor or a learning rate
@@ -81,7 +81,7 @@ def adamw(
(default: :const:`1e-8`)
weight_decay (float, optional): Strength of the weight decay regularization. Note that this
weight decay is multiplied with the learning rate. This is consistent with other
- frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight
+ frameworks such as PyTorch, but different from (Loshchilov et al., 2019) where the weight
decay is only multiplied with the "schedule multiplier", but not the base learning rate.
(default: :const:`1e-2`)
eps_root (float, optional): A small constant applied to denominator inside the square root
@@ -109,24 +109,21 @@ def adamw(
"""
b1, b2 = betas # pylint: disable=invalid-name
# pylint: disable=unneeded-not
- if not (callable(lr) or 0.0 <= lr): # pragma: no cover
+ if not (callable(lr) or lr >= 0.0): # pragma: no cover
raise ValueError(f'Invalid learning rate: {lr}')
- if not 0.0 <= eps: # pragma: no cover
+ if not eps >= 0.0: # pragma: no cover
raise ValueError(f'Invalid epsilon value: {eps}')
if not 0.0 <= b1 < 1.0: # pragma: no cover
raise ValueError(f'Invalid beta parameter at index 0: {b1}')
if not 0.0 <= b2 < 1.0: # pragma: no cover
raise ValueError(f'Invalid beta parameter at index 1: {b2}')
- if not 0.0 <= weight_decay: # pragma: no cover
+ if not weight_decay >= 0.0: # pragma: no cover
raise ValueError(f'Invalid weight_decay value: {weight_decay}')
# pylint: enable=unneeded-not
chain_fn = chain
flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay
- if use_accelerated_op:
- adam_scaler_fn = scale_by_accelerated_adam
- else:
- adam_scaler_fn = scale_by_adam
+ adam_scaler_fn = scale_by_accelerated_adam if use_accelerated_op else scale_by_adam
add_decayed_weights_fn = add_decayed_weights
scale_by_neg_lr_fn = scale_by_neg_lr
diff --git a/torchopt/alias/rmsprop.py b/torchopt/alias/rmsprop.py
index f0eb92cd..96092548 100644
--- a/torchopt/alias/rmsprop.py
+++ b/torchopt/alias/rmsprop.py
@@ -96,24 +96,21 @@ def rmsprop(
The functional optimizer wrapper :class:`torchopt.FuncOptimizer`.
"""
# pylint: disable=unneeded-not
- if not (callable(lr) or 0.0 <= lr): # pragma: no cover
+ if not (callable(lr) or lr >= 0.0): # pragma: no cover
raise ValueError(f'Invalid learning rate: {lr}')
- if not 0.0 <= alpha: # pragma: no cover
+ if not alpha >= 0.0: # pragma: no cover
raise ValueError(f'Invalid alpha value: {alpha}')
- if not 0.0 <= eps: # pragma: no cover
+ if not eps >= 0.0: # pragma: no cover
raise ValueError(f'Invalid epsilon value: {eps}')
- if not 0.0 <= momentum: # pragma: no cover
+ if not momentum >= 0.0: # pragma: no cover
raise ValueError(f'Invalid momentum value: {momentum}')
- if not 0.0 <= weight_decay: # pragma: no cover
+ if not weight_decay >= 0.0: # pragma: no cover
raise ValueError(f'Invalid weight_decay value: {weight_decay}')
# pylint: enable=unneeded-not
chain_fn = chain
flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay
- if centered:
- rmsprop_scaler_fn = scale_by_stddev
- else:
- rmsprop_scaler_fn = scale_by_rms
+ rmsprop_scaler_fn = scale_by_stddev if centered else scale_by_rms
trace_fn = trace
scale_by_neg_lr_fn = scale_by_neg_lr
diff --git a/torchopt/alias/sgd.py b/torchopt/alias/sgd.py
index 7d86b538..6fb3c6db 100644
--- a/torchopt/alias/sgd.py
+++ b/torchopt/alias/sgd.py
@@ -61,7 +61,7 @@ def sgd(
deep neural networks.
References:
- - Sutskever et al, 2013: http://proceedings.mlr.press/v28/sutskever13.pdf
+ - Sutskever et al., 2013: http://proceedings.mlr.press/v28/sutskever13.pdf
Args:
lr (float or callable): This is a fixed global scaling factor or a learning rate
@@ -85,11 +85,11 @@ def sgd(
The functional optimizer wrapper :class:`torchopt.FuncOptimizer`.
"""
# pylint: disable=unneeded-not
- if not (callable(lr) or 0.0 <= lr): # pragma: no cover
+ if not (callable(lr) or lr >= 0.0): # pragma: no cover
raise ValueError(f'Invalid learning rate: {lr}')
- if not 0.0 <= momentum: # pragma: no cover
+ if not momentum >= 0.0: # pragma: no cover
raise ValueError(f'Invalid momentum value: {momentum}')
- if not 0.0 <= weight_decay: # pragma: no cover
+ if not weight_decay >= 0.0: # pragma: no cover
raise ValueError(f'Invalid weight_decay value: {weight_decay}')
if nesterov and (momentum <= 0.0 or dampening != 0.0): # pragma: no cover
raise ValueError('Nesterov momentum requires a momentum and zero dampening')
diff --git a/torchopt/alias/utils.py b/torchopt/alias/utils.py
index b5088164..1e626810 100644
--- a/torchopt/alias/utils.py
+++ b/torchopt/alias/utils.py
@@ -17,11 +17,13 @@
import threading
+import torch
+
from torchopt import pytree
from torchopt.base import EmptyState, GradientTransformation, identity
from torchopt.transform import scale, scale_by_schedule
from torchopt.transform.utils import tree_map_flat, tree_map_flat_
-from torchopt.typing import OptState, Params, ScalarOrSchedule, Updates
+from torchopt.typing import Numeric, OptState, Params, ScalarOrSchedule, Updates
__all__ = ['flip_sign_and_add_weight_decay', 'scale_by_neg_lr']
@@ -43,7 +45,8 @@ def _get_use_chain_flat() -> bool: # only used for testing purposes
def flip_sign_and_add_weight_decay(
- weight_decay: float = 0.0, maximize=False
+ weight_decay: float = 0.0,
+ maximize: bool = False,
) -> GradientTransformation:
"""Flip the sign of the updates and adds weight decay."""
return _flip_sign_and_add_weight_decay(
@@ -54,7 +57,8 @@ def flip_sign_and_add_weight_decay(
def _flip_sign_and_add_weight_decay_flat(
- weight_decay: float = 0.0, maximize=False
+ weight_decay: float = 0.0,
+ maximize: bool = False,
) -> GradientTransformation:
"""Flip the sign of the updates and adds weight decay."""
return _flip_sign_and_add_weight_decay(
@@ -66,13 +70,13 @@ def _flip_sign_and_add_weight_decay_flat(
def _flip_sign_and_add_weight_decay(
weight_decay: float = 0.0,
- maximize=False,
+ maximize: bool = False,
*,
already_flattened: bool = False,
) -> GradientTransformation:
"""Flip the sign of the updates and adds weight decay."""
# pylint: disable-next=unneeded-not
- if not 0.0 <= weight_decay: # pragma: no cover
+ if not weight_decay >= 0.0: # pragma: no cover
raise ValueError(f'Invalid weight_decay value: {weight_decay}')
if not maximize and weight_decay == 0.0:
@@ -104,7 +108,7 @@ def update_fn(
if inplace:
- def f(g, p):
+ def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
if g.requires_grad:
return g.add_(p, alpha=weight_decay)
return g.add_(p.data, alpha=weight_decay)
@@ -113,7 +117,7 @@ def f(g, p):
else:
- def f(g, p):
+ def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
return g.add(p, alpha=weight_decay)
updates = tree_map(f, updates, params)
@@ -132,14 +136,14 @@ def update_fn(
) -> tuple[Updates, OptState]:
if inplace:
- def f(g):
+ def f(g: torch.Tensor) -> torch.Tensor:
return g.neg_()
updates = tree_map_(f, updates)
else:
- def f(g):
+ def f(g: torch.Tensor) -> torch.Tensor:
return g.neg()
updates = tree_map(f, updates)
@@ -162,7 +166,7 @@ def update_fn(
if inplace:
- def f(g, p):
+ def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
if g.requires_grad:
return g.neg_().add_(p, alpha=weight_decay)
return g.neg_().add_(p.data, alpha=weight_decay)
@@ -171,7 +175,7 @@ def f(g, p):
else:
- def f(g, p):
+ def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
return g.neg().add_(p, alpha=weight_decay)
updates = tree_map(f, updates, params)
@@ -194,13 +198,17 @@ def _scale_by_neg_lr_flat(lr: ScalarOrSchedule) -> GradientTransformation:
return _scale_by_neg_lr(lr=lr, already_flattened=True)
-def _scale_by_neg_lr(lr: ScalarOrSchedule, *, already_flattened=False) -> GradientTransformation:
- if not (callable(lr) or 0.0 <= lr): # pragma: no cover
+def _scale_by_neg_lr(
+ lr: ScalarOrSchedule,
+ *,
+ already_flattened: bool = False,
+) -> GradientTransformation:
+ if not (callable(lr) or lr >= 0.0): # pragma: no cover
raise ValueError(f'Invalid learning rate: {lr}')
if callable(lr):
- def schedule_wrapper(count):
+ def schedule_wrapper(count: Numeric) -> Numeric:
return -lr(count) # type: ignore[operator]
return scale_by_schedule.impl( # type: ignore[attr-defined]
diff --git a/torchopt/base.py b/torchopt/base.py
index b250c387..b0a40afa 100644
--- a/torchopt/base.py
+++ b/torchopt/base.py
@@ -35,11 +35,10 @@
import itertools
from abc import abstractmethod
-from typing import TYPE_CHECKING, Callable, NamedTuple
-from typing_extensions import Protocol # Python 3.8+
+from typing import TYPE_CHECKING, Callable, NamedTuple, Protocol
-if TYPE_CHECKING: # pragma: no cover
+if TYPE_CHECKING:
from torchopt.typing import OptState, Params, Updates
@@ -168,7 +167,7 @@ def __new__(cls, *transformations: GradientTransformation) -> ChainedGradientTra
if isinstance(t, ChainedGradientTransformation)
else ((t,) if not isinstance(t, IdentityGradientTransformation) else ())
for t in transformations
- )
+ ),
)
if len(transformations) == 0:
@@ -189,7 +188,7 @@ def update_fn(
if len(update_fns) != len(state):
raise ValueError(
'The number of updates and states has to be the same in chain! Make sure you'
- 'have called init first!'
+ 'have called init first!',
)
new_state = []
for s, fn in zip(state, update_fns): # pylint: disable=invalid-name
@@ -236,7 +235,7 @@ def __reduce__(self) -> tuple[Callable, tuple[tuple[GradientTransformation, ...]
class IdentityGradientTransformation(GradientTransformation):
"""A gradient transformation that does nothing."""
- def __new__(cls):
+ def __new__(cls) -> IdentityGradientTransformation:
"""Create a new gradient transformation that does nothing."""
return super().__new__(cls, init=cls.init_fn, update=cls.update_fn)
diff --git a/torchopt/clip.py b/torchopt/clip.py
index b2aafb48..69da9afd 100644
--- a/torchopt/clip.py
+++ b/torchopt/clip.py
@@ -78,7 +78,7 @@ def update_fn(
raise RuntimeError(
f'The total norm of order {norm_type} for gradients from `parameters` is '
f'non-finite, so it cannot be clipped. To disable this error and scale the '
- f'gradients by the non-finite norm anyway, set `error_if_nonfinite=False`'
+ f'gradients by the non-finite norm anyway, set `error_if_nonfinite=False`',
)
clip_coefficient = max_norm / (float(total_norm) + 1e-6)
# Note: multiplying by the clamped coefficient is redundant when the coefficient is
@@ -88,12 +88,12 @@ def update_fn(
clip_coefficient_clamped = min(clip_coefficient, 1.0)
if inplace:
- def f(g):
+ def f(g: torch.Tensor) -> torch.Tensor:
return g.mul_(clip_coefficient_clamped)
else:
- def f(g):
+ def f(g: torch.Tensor) -> torch.Tensor:
return g.mul(clip_coefficient_clamped)
new_updates = pytree.tree_map(f, updates)
diff --git a/torchopt/combine.py b/torchopt/combine.py
index 0f1ed8ec..fc1a7152 100644
--- a/torchopt/combine.py
+++ b/torchopt/combine.py
@@ -74,10 +74,7 @@ def chain_flat(*transformations: GradientTransformation) -> GradientTransformati
"""
if len(transformations) == 0:
return identity()
- if len(transformations) == 1:
- inner = transformations[0]
- else:
- inner = chain(*transformations)
+ inner = transformations[0] if len(transformations) == 1 else chain(*transformations)
def init_fn(params: Params) -> OptState:
return inner.init(pytree.tree_leaves(params, none_is_leaf=True))
@@ -90,10 +87,7 @@ def update_fn(
inplace: bool = True,
) -> tuple[Updates, OptState]:
flat_updates, treespec = pytree.tree_flatten(updates, none_is_leaf=True)
- if params is not None:
- flat_params = pytree.tree_leaves(params, none_is_leaf=True)
- else:
- flat_params = None
+ flat_params = pytree.tree_leaves(params, none_is_leaf=True) if params is not None else None
flat_updates, state = inner.update(flat_updates, state, params=flat_params, inplace=inplace)
updates: Updates
diff --git a/torchopt/diff/implicit/decorator.py b/torchopt/diff/implicit/decorator.py
index a5908963..031aa11f 100644
--- a/torchopt/diff/implicit/decorator.py
+++ b/torchopt/diff/implicit/decorator.py
@@ -91,7 +91,7 @@ def _root_vjp(
grad_outputs: TupleOfTensors,
output_is_tensor: bool,
argnums: tuple[int, ...],
- solve: Callable[..., TensorOrTensors] = linear_solve.solve_normal_cg(),
+ solve: Callable[..., TensorOrTensors],
) -> TupleOfOptionalTensors:
if output_is_tensor:
@@ -123,18 +123,20 @@ def matvec(u: TupleOfTensors) -> TupleOfTensors:
u: TupleOfTensors = solve(matvec, v) # type: ignore[assignment]
masked_optimality_fn = MaskedOptimalityFn(
- optimality_fn, solution, output_is_tensor, argnums, *args
+ optimality_fn,
+ solution,
+ output_is_tensor,
+ argnums,
+ *args,
)
_, optimality_vjp_fn, *_ = functorch.vjp(
- masked_optimality_fn, *masked_optimality_fn.post_filled
+ masked_optimality_fn,
+ *masked_optimality_fn.post_filled,
)
output: TupleOfTensors
- if output_is_tensor:
- output = optimality_vjp_fn(u[0])
- else:
- output = optimality_vjp_fn(u)
+ output = optimality_vjp_fn(u[0]) if output_is_tensor else optimality_vjp_fn(u)
# Prepend None as the vjp for init_params.
true_output: ListOfOptionalTensors = [None]
@@ -161,7 +163,9 @@ def _signature_bind(signature: inspect.Signature, *args: Any, **kwargs: Any) ->
def _signature_bind_and_match(
- signature: inspect.Signature, *args: Any, **kwargs: Any
+ signature: inspect.Signature,
+ *args: Any,
+ **kwargs: Any,
) -> tuple[Args, KwArgs, Callable[[Args], tuple[Args, KwArgs]]]:
# We want to bind *args and **kwargs based on the provided signature, but also to associate the
# resulting positional arguments back. To achieve this, we lift arguments to a triple:
@@ -179,7 +183,7 @@ def _signature_bind_and_match(
mapping = [(was_kwarg, ref) for was_kwarg, ref, _ in bound.args]
- def map_args_back(out_args):
+ def map_args_back(out_args: Args) -> tuple[Args, KwArgs]:
src_args = [None] * len(args)
src_kwargs = {}
for (was_kwarg, ref), out_arg in zip(mapping, out_args):
@@ -187,7 +191,7 @@ def map_args_back(out_args):
src_kwargs[ref] = out_arg
else:
src_args[ref] = out_arg
- return src_args, src_kwargs
+ return tuple(src_args), src_kwargs
out_args = tuple(v for _, _, v in bound.args)
out_kwargs = {k: v for k, (_, _, v) in bound.kwargs.items()}
@@ -259,7 +263,8 @@ def make_custom_vjp_solver_fn(
class ImplicitMetaGradient(Function):
@staticmethod
def forward( # type: ignore[override] # pylint: disable=arguments-differ
- ctx: Any, *flat_args: Any
+ ctx: Any,
+ *flat_args: Any,
) -> tuple[Any, ...]:
output, aux, output_is_tensor = None, None, False
@@ -278,7 +283,7 @@ def forward( # type: ignore[override] # pylint: disable=arguments-differ
raise RuntimeError(
f'custom_root(optimality_fn)(solver_fn)(*args): output of function '
f'solver_fn should be a tuple: (output, aux) if has_aux is True. '
- f'Got {output}'
+ f'Got {output}',
)
output, aux = output
if isinstance(output, torch.Tensor):
@@ -288,7 +293,7 @@ def forward( # type: ignore[override] # pylint: disable=arguments-differ
raise RuntimeError(
f'custom_root(optimality_fn)(solver_fn)(*args): output of function '
f'solver_fn should be a torch.Tensor or a tuple of torch.Tensor. '
- f'Got {output}'
+ f'Got {output}',
)
output = tuple(t.data for t in output)
@@ -309,7 +314,8 @@ def forward( # type: ignore[override] # pylint: disable=arguments-differ
@staticmethod
def backward( # pylint: disable=too-many-locals
- ctx: Any, *grad_outputs: Any
+ ctx: Any,
+ *grad_outputs: Any,
) -> TupleOfTensors:
grad_outputs: TupleOfTensors = grad_outputs[:-3]
@@ -320,13 +326,18 @@ def backward( # pylint: disable=too-many-locals
args_is_tensor_mask = ctx.args_is_tensor_mask
args_non_tensors = ctx.args_non_tensors
args = _merge_tensor_and_others(
- args_treespec, args_is_tensor_mask, args_tensors, args_non_tensors
+ args_treespec,
+ args_is_tensor_mask,
+ args_tensors,
+ args_non_tensors,
)
args, kwargs = _extract_kwargs(kwarg_keys, args)
bound_args, bound_kwargs, map_args_back = _signature_bind_and_match(
- reference_signature, *args, **kwargs # type: ignore[arg-type]
+ reference_signature, # type: ignore[arg-type]
+ *args,
+ **kwargs,
)
if bound_kwargs:
raise TypeError(
@@ -334,7 +345,7 @@ def backward( # pylint: disable=too-many-locals
f'arguments based on the signature {reference_signature}. This can '
f'happen under custom_root if optimality_fn takes catch-all **kwargs, or '
f'under custom_fixed_point if fixed_point_fn takes catch-all **kwargs, '
- f'both of which are currently unsupported.'
+ f'both of which are currently unsupported.',
)
# Compute VJPs w.r.t. args.
@@ -349,7 +360,7 @@ def backward( # pylint: disable=too-many-locals
)
args_vjps, kwargs_vjps = map_args_back(vjps)
- ordered_vjps = tuple(args_vjps) + tuple(kwargs_vjps[k] for k in kwargs.keys())
+ ordered_vjps = tuple(args_vjps) + tuple(kwargs_vjps[k] for k in kwargs)
true_vjps = []
for (_, _, arg_seq_type), vjp in zip(args_signs, ordered_vjps):
if arg_seq_type is not None:
@@ -362,7 +373,8 @@ def backward( # pylint: disable=too-many-locals
@functools.wraps(solver_fn)
def wrapped_solver_fn(
- *args: Any, **kwargs: Any
+ *args: Any,
+ **kwargs: Any,
) -> TensorOrTensors | tuple[TensorOrTensors, Any]:
args, kwargs = _signature_bind(solver_fn_signature, *args, **kwargs)
keys, vals = list(kwargs.keys()), list(kwargs.values())
@@ -379,7 +391,7 @@ def wrapped_solver_fn(
elif isinstance(arg, (tuple, list)) and all(map(torch.is_tensor, arg)):
nargs = len(arg)
args_signs.append(
- (args_offset, nargs, type(arg)) # start position, sequence type
+ (args_offset, nargs, type(arg)), # start position, sequence type
)
flat_args.extend(arg)
args_offset += nargs
@@ -387,7 +399,7 @@ def wrapped_solver_fn(
raise RuntimeError(
'custom_root(optimality_fn)(solver_fn)(*args): argument of function '
'solver_fn specified with `argnums` should be a torch.Tensor or a tuple of '
- 'torch.Tensor'
+ 'torch.Tensor',
)
else:
args_signs.append((args_offset, 1, None)) # start position, None
@@ -399,10 +411,7 @@ def wrapped_solver_fn(
result = make_custom_vjp_solver_fn(solver_fn, keys, args_signs).apply(*flat_args, *vals)
*output, aux, output_is_tensor, output_type = result
- if output_is_tensor:
- output = output[0]
- else:
- output = output_type(output)
+ output = output[0] if output_is_tensor else output_type(output)
if has_aux:
return output, aux
return output
@@ -414,7 +423,7 @@ def custom_root(
optimality_fn: Callable[..., TensorOrTensors],
argnums: int | tuple[int, ...],
has_aux: bool = False,
- solve: Callable[..., TensorOrTensors] = linear_solve.solve_normal_cg(),
+ solve: Callable[..., TensorOrTensors] | None = None,
) -> Callable[
[Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]]],
Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]],
@@ -465,6 +474,9 @@ def solver_fn(params, arg1, arg2, ...):
else:
assert 0 not in argnums
+ if solve is None:
+ solve = linear_solve.solve_normal_cg()
+
return functools.partial(
_custom_root,
optimality_fn=optimality_fn,
diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py
index bbae37c9..a72e5304 100644
--- a/torchopt/diff/implicit/nn/module.py
+++ b/torchopt/diff/implicit/nn/module.py
@@ -20,6 +20,7 @@
import abc
import functools
+import inspect
import itertools
from typing import Any, Iterable
@@ -36,38 +37,40 @@
def _stateless_objective_fn(
- __flat_params: TupleOfTensors,
- __flat_meta_params: TupleOfTensors,
- __params_names: Iterable[str],
- __meta_params_names: Iterable[str],
+ flat_params: TupleOfTensors,
+ flat_meta_params: TupleOfTensors,
+ params_names: Iterable[str],
+ meta_params_names: Iterable[str],
self: ImplicitMetaGradientModule,
- *input,
- **kwargs,
+ /,
+ *input: Any,
+ **kwargs: Any,
) -> torch.Tensor:
with reparametrize(
self,
itertools.chain(
- zip(__params_names, __flat_params),
- zip(__meta_params_names, __flat_meta_params),
+ zip(params_names, flat_params),
+ zip(meta_params_names, flat_meta_params),
),
):
return self.objective(*input, **kwargs)
def _stateless_optimality_fn(
- __flat_params: TupleOfTensors,
- __flat_meta_params: TupleOfTensors,
- __params_names: Iterable[str],
- __meta_params_names: Iterable[str],
+ flat_params: TupleOfTensors,
+ flat_meta_params: TupleOfTensors,
+ params_names: Iterable[str],
+ meta_params_names: Iterable[str],
self: ImplicitMetaGradientModule,
- *input,
- **kwargs,
+ /,
+ *input: Any,
+ **kwargs: Any,
) -> TupleOfTensors:
with reparametrize(
self,
itertools.chain(
- zip(__params_names, __flat_params),
- zip(__meta_params_names, __flat_meta_params),
+ zip(params_names, flat_params),
+ zip(meta_params_names, flat_meta_params),
),
):
return self.optimality(*input, **kwargs)
@@ -76,19 +79,24 @@ def _stateless_optimality_fn(
def make_optimality_from_objective(
cls: type[ImplicitMetaGradientModule],
) -> type[ImplicitMetaGradientModule]:
- """Derives the optimality function of the objective function."""
- if (
- getattr(cls, 'objective', ImplicitMetaGradientModule.objective)
- is ImplicitMetaGradientModule.objective
- ):
+ """Derive the optimality function of the objective function."""
+ static_super_objective = inspect.getattr_static(ImplicitMetaGradientModule, 'objective')
+ static_cls_objective = inspect.getattr_static(cls, 'objective', static_super_objective)
+ if static_cls_objective is static_super_objective:
raise TypeError('The objective function is not defined.')
- def optimality(self: ImplicitMetaGradientModule, *input, **kwargs) -> TupleOfTensors:
- params_names, flat_params = tuple(zip(*self.named_parameters()))
- meta_params_names, flat_meta_params = tuple(zip(*self.named_meta_parameters()))
+ def optimality(self: ImplicitMetaGradientModule, *input: Any, **kwargs: Any) -> TupleOfTensors:
+ named_params = tuple(self.named_parameters())
+ named_meta_params = tuple(self.named_meta_parameters())
+ if len(named_params) == 0:
+ raise RuntimeError('The module has no parameters.')
+ if len(named_meta_params) == 0:
+ raise RuntimeError('The module has no meta-parameters.')
+ params_names, flat_params = tuple(zip(*named_params))
+ meta_params_names, flat_meta_params = tuple(zip(*named_meta_params))
objective_grad_fn = functorch.grad(_stateless_objective_fn, argnums=0)
- flat_grads = objective_grad_fn(
+ return objective_grad_fn(
flat_params,
flat_meta_params,
params_names,
@@ -97,9 +105,8 @@ def optimality(self: ImplicitMetaGradientModule, *input, **kwargs) -> TupleOfTen
*input,
**kwargs,
)
- return flat_grads
- cls.optimality = optimality # type: ignore[assignment]
+ cls.optimality = optimality # type: ignore[method-assign]
return cls
@@ -111,22 +118,20 @@ def enable_implicit_gradients(
if getattr(cls_solve, '__implicit_gradients_enabled__', False):
raise TypeError('Implicit gradients are already enabled for the `solve` method.')
- if cls.linear_solve is not None:
- solve_kwargs = {'solve': cls.linear_solve}
- else:
- solve_kwargs = {}
+ solve_kwargs = {'solve': cls.linear_solve} if cls.linear_solve is not None else {}
@custom_root(_stateless_optimality_fn, argnums=1, has_aux=True, **solve_kwargs)
def stateless_solver_fn(
# pylint: disable=unused-argument
- __flat_params: TupleOfTensors,
- __flat_meta_params: TupleOfTensors,
- __params_names: Iterable[str],
- __meta_params_names: Iterable[str],
+ flat_params: TupleOfTensors,
+ flat_meta_params: TupleOfTensors,
+ params_names: Iterable[str],
+ meta_params_names: Iterable[str],
# pylint: enable=unused-argument
self: ImplicitMetaGradientModule,
- *input,
- **kwargs,
+ /,
+ *input: Any,
+ **kwargs: Any,
) -> tuple[TupleOfTensors, Any]:
"""Solve the optimization problem."""
output = cls_solve(self, *input, **kwargs)
@@ -134,10 +139,16 @@ def stateless_solver_fn(
return flat_optimal_params, output
@functools.wraps(cls_solve)
- def wrapped(self: ImplicitMetaGradientModule, *input, **kwargs) -> Any:
+ def wrapped(self: ImplicitMetaGradientModule, *input: Any, **kwargs: Any) -> Any:
"""Solve the optimization problem."""
- params_names, flat_params = tuple(zip(*self.named_parameters()))
- meta_params_names, flat_meta_params = tuple(zip(*self.named_meta_parameters()))
+ named_params = tuple(self.named_parameters())
+ named_meta_params = tuple(self.named_meta_parameters())
+ if len(named_params) == 0:
+ raise RuntimeError('The module has no parameters.')
+ if len(named_meta_params) == 0:
+ raise RuntimeError('The module has no meta-parameters.')
+ params_names, flat_params = tuple(zip(*named_params))
+ meta_params_names, flat_meta_params = tuple(zip(*named_meta_params))
flat_optimal_params, output = stateless_solver_fn(
flat_params,
@@ -152,11 +163,11 @@ def wrapped(self: ImplicitMetaGradientModule, *input, **kwargs) -> Any:
return output
wrapped.__implicit_gradients_enabled__ = True # type: ignore[attr-defined]
- cls.solve = wrapped # type: ignore[assignment]
+ cls.solve = wrapped # type: ignore[method-assign]
return cls
-class ImplicitMetaGradientModule(MetaGradientModule):
+class ImplicitMetaGradientModule(MetaGradientModule, metaclass=abc.ABCMeta):
"""The base class for differentiable implicit meta-gradient models."""
_custom_optimality: bool
@@ -168,28 +179,30 @@ def __init_subclass__(cls, linear_solve: LinearSolver | None = None) -> None:
super().__init_subclass__()
cls.linear_solve = linear_solve
- optimality = getattr(cls, 'optimality', ImplicitMetaGradientModule.optimality)
- objective = getattr(cls, 'objective', ImplicitMetaGradientModule.objective)
- cls._custom_optimality = optimality is not ImplicitMetaGradientModule.optimality
- cls._custom_objective = objective is not ImplicitMetaGradientModule.objective
+ static_super_optimality = inspect.getattr_static(ImplicitMetaGradientModule, 'optimality')
+ static_super_objective = inspect.getattr_static(ImplicitMetaGradientModule, 'objective')
+ static_cls_optimality = inspect.getattr_static(cls, 'optimality')
+ static_cls_objective = inspect.getattr_static(cls, 'objective')
+ cls._custom_optimality = static_cls_optimality is not static_super_optimality
+ cls._custom_objective = static_cls_objective is not static_super_objective
if cls._custom_optimality:
- if isinstance(optimality, staticmethod):
+ if isinstance(static_cls_optimality, staticmethod):
raise TypeError('method optimality() must not be a staticmethod.')
- if isinstance(optimality, classmethod):
+ if isinstance(static_cls_optimality, classmethod):
raise TypeError('method optimality() must not be a classmethod.')
- if not callable(optimality):
+ if not callable(static_cls_optimality):
raise TypeError('method optimality() must be callable.')
elif not cls._custom_objective:
raise TypeError(
- 'ImplicitMetaGradientModule requires either an optimality() method or an objective() method'
+ 'ImplicitMetaGradientModule requires either an optimality() method or an objective() method',
)
else:
- if isinstance(objective, staticmethod):
+ if isinstance(static_cls_objective, staticmethod):
raise TypeError('method objective() must not be a staticmethod.')
- if isinstance(objective, classmethod):
+ if isinstance(static_cls_objective, classmethod):
raise TypeError('method objective() must not be a classmethod.')
- if not callable(objective):
+ if not callable(static_cls_objective):
raise TypeError('method objective() must be callable.')
make_optimality_from_objective(cls)
@@ -197,7 +210,7 @@ def __init_subclass__(cls, linear_solve: LinearSolver | None = None) -> None:
enable_implicit_gradients(cls)
@abc.abstractmethod
- def solve(self, *input, **kwargs) -> Any:
+ def solve(self, *input: Any, **kwargs: Any) -> Any:
"""Solve the inner optimization problem.
.. warning::
@@ -207,24 +220,25 @@ def solve(self, *input, **kwargs) -> Any:
(including the meta-parameters) that were used to compute the objective output.
Alternatively, please use :func:`torch.autograd.grad` instead.
- Example::
-
- def solve(self, batch, labels):
- parameters = tuple(self.parameters())
- optimizer = torch.optim.Adam(parameters, lr=1e-3)
- with torch.enable_grad():
- for _ in range(100):
- loss = self.objective(batch, labels)
- optimizer.zero_grad()
- # Only update the `.grad` attribute for parameters
- # and leave the meta-parameters unchanged
- loss.backward(inputs=parameters)
- optimizer.step()
- return self
+ Examples:
+ .. code-block:: python
+
+ def solve(self, batch, labels):
+ parameters = tuple(self.parameters())
+ optimizer = torch.optim.Adam(parameters, lr=1e-3)
+ with torch.enable_grad():
+ for _ in range(100):
+ loss = self.objective(batch, labels)
+ optimizer.zero_grad()
+ # Only update the `.grad` attribute for parameters
+ # and leave the meta-parameters unchanged
+ loss.backward(inputs=parameters)
+ optimizer.step()
+ return self
"""
raise NotImplementedError # update parameters
- def optimality(self, *input, **kwargs) -> TupleOfTensors:
+ def optimality(self, *input: Any, **kwargs: Any) -> TupleOfTensors:
r"""Compute the optimality residual.
This method stands for the optimality residual to the optimal parameters after solving the
@@ -267,7 +281,7 @@ def optimality(self, *input, **kwargs) -> TupleOfTensors:
""" # pylint: disable=line-too-long
raise NotImplementedError
- def objective(self, *input, **kwargs) -> torch.Tensor:
+ def objective(self, *input: Any, **kwargs: Any) -> torch.Tensor:
"""Compute the objective function value.
This method is used to calculate the :meth:`optimality` if it is not implemented.
diff --git a/torchopt/diff/zero_order/__init__.py b/torchopt/diff/zero_order/__init__.py
index 5b85d03d..b621ffdc 100644
--- a/torchopt/diff/zero_order/__init__.py
+++ b/torchopt/diff/zero_order/__init__.py
@@ -16,6 +16,9 @@
import sys as _sys
from types import ModuleType as _ModuleType
+from typing import Any, Callable
+
+import torch
from torchopt.diff.zero_order import nn
from torchopt.diff.zero_order.decorator import zero_order
@@ -26,7 +29,11 @@
class _CallableModule(_ModuleType): # pylint: disable=too-few-public-methods
- def __call__(self, *args, **kwargs):
+ def __call__(
+ self,
+ *args: Any,
+ **kwargs: Any,
+ ) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
return self.zero_order(*args, **kwargs)
diff --git a/torchopt/diff/zero_order/decorator.py b/torchopt/diff/zero_order/decorator.py
index 43522028..f63f0574 100644
--- a/torchopt/diff/zero_order/decorator.py
+++ b/torchopt/diff/zero_order/decorator.py
@@ -17,8 +17,7 @@
from __future__ import annotations
import functools
-from typing import Any, Callable, Sequence
-from typing_extensions import Literal # Python 3.8+
+from typing import Any, Callable, Literal, Sequence
from typing_extensions import TypeAlias # Python 3.10+
import torch
@@ -35,7 +34,10 @@ def __init__(self, sample_fn: SampleFunc) -> None:
"""Wrap a sample function to make it a :class:`Samplable` object."""
self.sample_fn = sample_fn
- def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor | Sequence[Numeric]:
+ def sample(
+ self,
+ sample_shape: torch.Size = torch.Size(), # noqa: B008
+ ) -> torch.Tensor | Sequence[Numeric]:
# pylint: disable-next=line-too-long
"""Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched."""
return self.sample_fn(sample_shape)
@@ -46,7 +48,7 @@ def _zero_order_naive( # pylint: disable=too-many-statements
distribution: Samplable,
argnums: tuple[int, ...],
num_samples: int,
- sigma: Numeric,
+ sigma: float,
) -> Callable[..., torch.Tensor]:
@functools.wraps(fn)
def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements
@@ -89,7 +91,8 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor:
@staticmethod
def backward( # pylint: disable=too-many-locals
- ctx: Any, *grad_outputs: Any
+ ctx: Any,
+ *grad_outputs: Any,
) -> TupleOfOptionalTensors:
saved_tensors = ctx.saved_tensors
flat_diff_params = saved_tensors[: ctx.len_params]
@@ -109,18 +112,22 @@ def backward( # pylint: disable=too-many-locals
args: list[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment]
- def add_perturbation(tensor, noises):
- return tensor.add(noises, alpha=sigma)
+ def add_perturbation(
+ tensor: torch.Tensor,
+ noise: torch.Tensor | Numeric,
+ ) -> torch.Tensor:
+ return tensor.add(noise, alpha=sigma)
param_grads: ListOfTensors = [0.0 for _ in range(len(flat_diff_params))] # type: ignore[misc]
for _ in range(num_samples):
noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params]
flat_noisy_params = [
- add_perturbation(t, n) for t, n in zip(flat_diff_params, noises)
+ add_perturbation(t, n) for t, n in zip(flat_diff_params, noises) # type: ignore[arg-type]
]
noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment]
- diff_params_treespec, flat_noisy_params
+ diff_params_treespec,
+ flat_noisy_params,
)
for argnum, noisy_param in zip(argnums, noisy_params):
@@ -147,7 +154,7 @@ def _zero_order_forward( # pylint: disable=too-many-statements
distribution: Samplable,
argnums: tuple[int, ...],
num_samples: int,
- sigma: Numeric,
+ sigma: float,
) -> Callable[..., torch.Tensor]:
@functools.wraps(fn)
def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements
@@ -190,7 +197,8 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor:
@staticmethod
def backward( # pylint: disable=too-many-locals
- ctx: Any, *grad_outputs: Any
+ ctx: Any,
+ *grad_outputs: Any,
) -> TupleOfOptionalTensors:
saved_tensors = ctx.saved_tensors
flat_diff_params = saved_tensors[: ctx.len_params]
@@ -211,18 +219,19 @@ def backward( # pylint: disable=too-many-locals
args: list[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment]
- def add_perturbation(tensor, noises):
- return tensor.add(noises, alpha=sigma)
+ def add_perturbation(tensor: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
+ return tensor.add(noise, alpha=sigma)
param_grads: ListOfTensors = [0.0 for _ in range(len(flat_diff_params))] # type: ignore[misc]
for _ in range(num_samples):
noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params]
flat_noisy_params = [
- add_perturbation(t, n) for t, n in zip(flat_diff_params, noises)
+ add_perturbation(t, n) for t, n in zip(flat_diff_params, noises) # type: ignore[arg-type]
]
noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment]
- diff_params_treespec, flat_noisy_params
+ diff_params_treespec,
+ flat_noisy_params,
)
for argnum, noisy_param in zip(argnums, noisy_params):
@@ -250,7 +259,7 @@ def _zero_order_antithetic( # pylint: disable=too-many-statements
distribution: Samplable,
argnums: tuple[int, ...],
num_samples: int,
- sigma: Numeric,
+ sigma: float,
) -> Callable[..., torch.Tensor]:
@functools.wraps(fn)
def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements
@@ -292,7 +301,10 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor:
return output
@staticmethod
- def backward(ctx: Any, *grad_outputs: Any): # pylint: disable=too-many-locals
+ def backward( # pylint: disable=too-many-locals
+ ctx: Any,
+ *grad_outputs: Any,
+ ) -> TupleOfOptionalTensors:
saved_tensors = ctx.saved_tensors
flat_diff_params = saved_tensors[: ctx.len_params]
tensors = saved_tensors[ctx.len_params :]
@@ -313,13 +325,17 @@ def backward(ctx: Any, *grad_outputs: Any): # pylint: disable=too-many-locals
param_grads: ListOfTensors = [0.0 for _ in range(len(flat_diff_params))] # type: ignore[misc]
- def get_output(add_perturbation_fn, noises) -> torch.Tensor:
+ def get_output(
+ add_perturbation_fn: Callable,
+ noises: Sequence[torch.Tensor | Numeric],
+ ) -> torch.Tensor:
flat_noisy_params = [
add_perturbation_fn(t, n, alpha=sigma)
for t, n in zip(flat_diff_params, noises)
]
noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment]
- diff_params_treespec, flat_noisy_params
+ diff_params_treespec,
+ flat_noisy_params,
)
for argnum, noisy_param in zip(argnums, noisy_params):
@@ -329,7 +345,7 @@ def get_output(add_perturbation_fn, noises) -> torch.Tensor:
for _ in range(num_samples):
noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params]
- output = get_output(torch.add, noises) - get_output(torch.sub, noises)
+ output = get_output(torch.add, noises) - get_output(torch.sub, noises) # type: ignore[arg-type]
weighted_grad = grad_outputs[0].mul(output).mul_(0.5 / sigma)
for i, noise in enumerate(noises):
@@ -353,7 +369,7 @@ def zero_order(
method: Method = 'naive',
argnums: int | tuple[int, ...] = (0,),
num_samples: int = 1,
- sigma: Numeric = 1.0,
+ sigma: float = 1.0,
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
"""Return a decorator for applying zero-order differentiation.
@@ -369,7 +385,7 @@ def zero_order(
respect to. (default: :const:`0`)
num_samples (int, optional): The number of sample to get the averaged estimated gradient.
(default: :const:`1`)
- sigma (float or Tensor, optional): The standard deviation of the perturbation.
+ sigma (float, optional): The standard deviation of the perturbation.
(default: :const:`1.0`)
Returns:
diff --git a/torchopt/diff/zero_order/nn/module.py b/torchopt/diff/zero_order/nn/module.py
index 65014fb9..75da28f9 100644
--- a/torchopt/diff/zero_order/nn/module.py
+++ b/torchopt/diff/zero_order/nn/module.py
@@ -20,7 +20,7 @@
import abc
import functools
-from typing import Sequence
+from typing import Any, Sequence
import torch
import torch.nn as nn
@@ -37,25 +37,28 @@ def enable_zero_order_gradients(
cls: type[ZeroOrderGradientModule],
method: Method = 'naive',
num_samples: int = 1,
- sigma: Numeric = 1.0,
+ sigma: float = 1.0,
) -> type[ZeroOrderGradientModule]:
"""Enable zero-order gradient estimation for the :func:`forward` method."""
cls_forward = cls.forward
if getattr(cls_forward, '__zero_order_gradients_enabled__', False):
raise TypeError(
- 'Zero-order gradient estimation is already enabled for the `forward` method.'
+ 'Zero-order gradient estimation is already enabled for the `forward` method.',
)
@functools.wraps(cls_forward)
- def wrapped(self: ZeroOrderGradientModule, *input, **kwargs) -> torch.Tensor:
+ def wrapped(self: ZeroOrderGradientModule, *input: Any, **kwargs: Any) -> torch.Tensor:
"""Do the forward pass calculation."""
- params_names, flat_params = tuple(zip(*self.named_parameters()))
+ named_params = tuple(self.named_parameters())
+ if len(named_params) == 0:
+ raise RuntimeError('The module has no parameters.')
+ params_names, flat_params = tuple(zip(*named_params))
@zero_order(self.sample, argnums=0, method=method, num_samples=num_samples, sigma=sigma)
def forward_fn(
__flat_params: TupleOfTensors,
- *input,
- **kwargs,
+ *input: Any,
+ **kwargs: Any,
) -> torch.Tensor:
with reparametrize(self, zip(params_names, __flat_params)):
return cls_forward(self, *input, **kwargs)
@@ -63,7 +66,7 @@ def forward_fn(
return forward_fn(flat_params, *input, **kwargs)
wrapped.__zero_order_gradients_enabled__ = True # type: ignore[attr-defined]
- cls.forward = wrapped # type: ignore[assignment]
+ cls.forward = wrapped # type: ignore[method-assign]
return cls
@@ -74,7 +77,7 @@ def __init_subclass__( # pylint: disable=arguments-differ
cls,
method: Method = 'naive',
num_samples: int = 1,
- sigma: Numeric = 1.0,
+ sigma: float = 1.0,
) -> None:
"""Validate and initialize the subclass."""
super().__init_subclass__()
@@ -86,13 +89,14 @@ def __init_subclass__( # pylint: disable=arguments-differ
)
@abc.abstractmethod
- def forward(self, *args, **kwargs) -> torch.Tensor:
+ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
"""Do the forward pass of the model."""
raise NotImplementedError
@abc.abstractmethod
def sample(
- self, sample_shape: torch.Size = torch.Size() # pylint: disable=unused-argument
+ self,
+ sample_shape: torch.Size = torch.Size(), # noqa: B008 # pylint: disable=unused-argument
) -> torch.Tensor | Sequence[Numeric]:
# pylint: disable-next=line-too-long
"""Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched."""
diff --git a/torchopt/distributed/__init__.py b/torchopt/distributed/__init__.py
index 4272e37a..534b2dea 100644
--- a/torchopt/distributed/__init__.py
+++ b/torchopt/distributed/__init__.py
@@ -18,8 +18,8 @@
import torch.distributed.rpc as rpc
from torchopt.distributed import api, autograd, world
-from torchopt.distributed.api import *
-from torchopt.distributed.world import *
+from torchopt.distributed.api import * # noqa: F403
+from torchopt.distributed.world import * # noqa: F403
__all__ = ['is_available', *api.__all__, *world.__all__]
diff --git a/torchopt/distributed/api.py b/torchopt/distributed/api.py
index b46ad67e..3a6f0526 100644
--- a/torchopt/distributed/api.py
+++ b/torchopt/distributed/api.py
@@ -54,10 +54,7 @@
]
-if rpc.is_available():
- UNSET_RPC_TIMEOUT = rpc.api.UNSET_RPC_TIMEOUT
-else:
- UNSET_RPC_TIMEOUT = -1.0
+UNSET_RPC_TIMEOUT = rpc.api.UNSET_RPC_TIMEOUT if rpc.is_available() else -1.0
T = TypeVar('T')
@@ -134,7 +131,7 @@ def __call__(
elif batch_size != arg.shape[self.dim]: # type: ignore[unreachable]
raise ValueError(
f'Batch size mismatch on dim={self.dim}. '
- f'Expected {batch_size}, got {arg.shape[self.dim]} (shape: {arg.shape}).'
+ f'Expected {batch_size}, got {arg.shape[self.dim]} (shape: {arg.shape}).',
)
if batch_size is None:
diff --git a/torchopt/distributed/autograd.py b/torchopt/distributed/autograd.py
index 17fa9463..c2a4b3e2 100644
--- a/torchopt/distributed/autograd.py
+++ b/torchopt/distributed/autograd.py
@@ -31,14 +31,14 @@
LOCK = Lock()
-def is_available():
+def is_available() -> bool:
"""Check if distributed autograd module is available."""
return autograd.is_available()
if is_available():
# pylint: disable-next=unused-import,ungrouped-imports
- from torch.distributed.autograd import DistAutogradContext, get_gradients
+ from torch.distributed.autograd import DistAutogradContext, get_gradients # noqa: F401
def backward(
autograd_ctx_id: int,
@@ -69,7 +69,7 @@ def backward(
raise RuntimeError("'inputs' argument to backward() cannot be empty.")
else:
inputs = tuple(inputs)
- if not all(map(lambda t: t.requires_grad, inputs)):
+ if not all(t.requires_grad for t in inputs):
raise RuntimeError('One of the differentiated Tensors does not require grad')
roots = [tensors] if isinstance(tensors, torch.Tensor) else list(tensors)
@@ -111,7 +111,7 @@ def grad(
"""
outputs = [outputs] if isinstance(outputs, torch.Tensor) else list(outputs)
inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs)
- if not all(map(lambda t: t.requires_grad, inputs)):
+ if not all(t.requires_grad for t in inputs):
raise RuntimeError('One of the differentiated Tensors does not require grad')
autograd.backward(autograd_ctx_id, roots=outputs, retain_graph=retain_graph)
@@ -125,7 +125,7 @@ def grad(
if not allow_unused:
raise RuntimeError(
'One of the differentiated Tensors appears to not have been used in the '
- 'graph. Set allow_unused=True if this is the desired behavior.'
+ 'graph. Set allow_unused=True if this is the desired behavior.',
) from ex
grads.append(None) # type: ignore[arg-type]
diff --git a/torchopt/distributed/world.py b/torchopt/distributed/world.py
index 804d4b9d..a9821ee0 100644
--- a/torchopt/distributed/world.py
+++ b/torchopt/distributed/world.py
@@ -166,7 +166,7 @@ def wrapper(func: F) -> F:
@record
@functools.wraps(func)
- def wrapped(*args, **kwargs):
+ def wrapped(*args: Any, **kwargs: Any) -> Any:
rpc.init_rpc(
name=world_info.worker_name,
rank=world_info.rank,
@@ -193,7 +193,7 @@ def wrapper(func: F) -> F:
world_rank = get_world_info().world_rank
@functools.wraps(func)
- def wrapped(*args, **kwargs):
+ def wrapped(*args: Any, **kwargs: Any) -> Any:
if inverse:
if world_rank not in ranks:
return func(*args, **kwargs)
@@ -211,7 +211,7 @@ def on_rank(*ranks: int) -> Callable[[F], F]:
return __on_ranks(ranks=ranks, inverse=False)
-def not_on_rank(*ranks) -> Callable[[F], F]:
+def not_on_rank(*ranks: int) -> Callable[[F], F]:
"""Return a decorator to mark a function to be executed only on non given ranks."""
return __on_ranks(ranks=ranks, inverse=True)
diff --git a/torchopt/hook.py b/torchopt/hook.py
index f188415c..13ed6abf 100644
--- a/torchopt/hook.py
+++ b/torchopt/hook.py
@@ -34,7 +34,9 @@ def zero_nan_hook(g: torch.Tensor) -> torch.Tensor:
def nan_to_num_hook(
- nan: float = 0.0, posinf: float | None = None, neginf: float | None = None
+ nan: float = 0.0,
+ posinf: float | None = None,
+ neginf: float | None = None,
) -> Callable[[torch.Tensor], torch.Tensor]:
"""Return a ``nan`` to num hook to replace ``nan`` / ``+inf`` / ``-inf`` with the given numbers."""
@@ -45,7 +47,7 @@ def hook(g: torch.Tensor) -> torch.Tensor:
return hook
-def register_hook(hook) -> GradientTransformation:
+def register_hook(hook: Callable[[torch.Tensor], torch.Tensor | None]) -> GradientTransformation:
"""Stateless identity transformation that leaves input gradients untouched.
This function passes through the *gradient updates* unchanged.
@@ -64,7 +66,7 @@ def update_fn(
params: Params | None = None, # pylint: disable=unused-argument
inplace: bool = True, # pylint: disable=unused-argument
) -> tuple[Updates, OptState]:
- def f(g):
+ def f(g: torch.Tensor) -> torch.utils.hooks.RemovableHandle:
return g.register_hook(hook)
pytree.tree_map_(f, updates)
diff --git a/torchopt/linalg/cg.py b/torchopt/linalg/cg.py
index 5456f076..9cd57cd8 100644
--- a/torchopt/linalg/cg.py
+++ b/torchopt/linalg/cg.py
@@ -70,12 +70,14 @@ def _cg_solve(
b2 = tree_vdot_real(b, b)
atol2 = max(rtol**2 * b2, atol**2)
- def cond_fn(value):
+ def cond_fn(value: tuple[TensorTree, TensorTree, float, TensorTree, int]) -> bool:
_, r, gamma, _, k = value
rs = gamma if M is _identity else tree_vdot_real(r, r)
return rs > atol2 and k < maxiter
- def body_fn(value):
+ def body_fn(
+ value: tuple[TensorTree, TensorTree, float, TensorTree, int],
+ ) -> tuple[TensorTree, TensorTree, float, TensorTree, int]:
x, r, gamma, p, k = value
Ap = A(p)
alpha = gamma / tree_vdot_real(p, Ap)
@@ -125,13 +127,11 @@ def _isolve(
if cat_shapes(x0) != cat_shapes(b):
raise ValueError(
- f'Tensors in x0 and b must have matching shapes: {cat_shapes(x0)} vs. {cat_shapes(b)}.'
+ f'Tensors in x0 and b must have matching shapes: {cat_shapes(x0)} vs. {cat_shapes(b)}.',
)
isolve_solve = partial(_isolve_solve, x0=x0, rtol=rtol, atol=atol, maxiter=maxiter, M=M)
-
- x = isolve_solve(A, b)
- return x
+ return isolve_solve(A, b)
def cg(
diff --git a/torchopt/linalg/ns.py b/torchopt/linalg/ns.py
index c1975203..747ad3cf 100644
--- a/torchopt/linalg/ns.py
+++ b/torchopt/linalg/ns.py
@@ -112,12 +112,12 @@ def ns(
return inv_A_hat_b
-def _ns_inv(A: torch.Tensor, maxiter: int, alpha: float | None = None):
+def _ns_inv(A: torch.Tensor, maxiter: int, alpha: float | None = None) -> torch.Tensor:
"""Use Neumann Series iteration to solve ``A^{-1}``."""
if A.ndim != 2 or A.shape[0] != A.shape[1]:
raise ValueError(f'`A` must be a square matrix, but has shape: {A.shape}')
- I = torch.eye(*A.shape, out=torch.empty_like(A))
+ I = torch.eye(*A.shape, out=torch.empty_like(A)) # noqa: E741
inv_A_hat = torch.zeros_like(A)
if alpha is not None:
# A^{-1} = a [I - (I - a A)]^{-1} = a [I + (I - a A) + (I - a A)^2 + (I - a A)^3 + ...]
diff --git a/torchopt/linalg/utils.py b/torchopt/linalg/utils.py
index f301a624..e3cd197e 100644
--- a/torchopt/linalg/utils.py
+++ b/torchopt/linalg/utils.py
@@ -32,7 +32,7 @@ def cat_shapes(tree: TensorTree) -> tuple[int, ...]:
def normalize_matvec(
- matvec: TensorTree | Callable[[TensorTree], TensorTree]
+ matvec: TensorTree | Callable[[TensorTree], TensorTree],
) -> Callable[[TensorTree], TensorTree]:
"""Normalize an argument for computing matrix-vector product."""
if callable(matvec):
@@ -48,7 +48,7 @@ def _matvec(x: TensorTree) -> TensorTree:
if len(x_flat) != len(mat_flat):
raise ValueError(
f'`x` must have the same number of leaves as `matvec`, '
- f'but has {len(x_flat)} leaves and `matvec` has {len(mat_flat)} leaves'
+ f'but has {len(x_flat)} leaves and `matvec` has {len(mat_flat)} leaves',
)
y_flat = map(torch.matmul, mat_flat, x_flat)
diff --git a/torchopt/linear_solve/cg.py b/torchopt/linear_solve/cg.py
index 844c9407..e8f9fb77 100644
--- a/torchopt/linear_solve/cg.py
+++ b/torchopt/linear_solve/cg.py
@@ -36,11 +36,11 @@
from __future__ import annotations
import functools
-from typing import Callable
+from typing import Any, Callable
from torchopt import linalg
from torchopt.linear_solve.utils import make_ridge_matvec
-from torchopt.typing import TensorTree
+from torchopt.typing import LinearSolver, TensorTree
__all__ = ['solve_cg']
@@ -51,7 +51,7 @@ def _solve_cg(
b: TensorTree,
ridge: float | None = None,
init: TensorTree | None = None,
- **kwargs,
+ **kwargs: Any,
) -> TensorTree:
"""Solve ``A x = b`` using conjugate gradient.
@@ -78,7 +78,7 @@ def _solve_cg(
return linalg.cg(matvec, b, x0=init, **kwargs)
-def solve_cg(**kwargs):
+def solve_cg(**kwargs: Any) -> LinearSolver:
"""Return a solver function to solve ``A x = b`` using conjugate gradient.
This assumes that ``A`` is a hermitian, positive definite matrix.
@@ -98,8 +98,7 @@ def solve_cg(**kwargs):
See Also:
Conjugate gradient iteration :func:`torchopt.linalg.cg`.
- Example::
-
+ Examples:
>>> A = {'a': torch.eye(5, 5), 'b': torch.eye(3, 3)}
>>> x = {'a': torch.randn(5), 'b': torch.randn(3)}
>>> def matvec(x: TensorTree) -> TensorTree:
@@ -108,6 +107,5 @@ def solve_cg(**kwargs):
>>> solver = solve_cg(init={'a': torch.zeros(5), 'b': torch.zeros(3)})
>>> x_hat = solver(matvec, b)
>>> assert torch.allclose(x_hat['a'], x['a']) and torch.allclose(x_hat['b'], x['b'])
-
"""
return functools.partial(_solve_cg, **kwargs)
diff --git a/torchopt/linear_solve/inv.py b/torchopt/linear_solve/inv.py
index 399a0ef9..e2a377d5 100644
--- a/torchopt/linear_solve/inv.py
+++ b/torchopt/linear_solve/inv.py
@@ -36,13 +36,13 @@
from __future__ import annotations
import functools
-from typing import Callable
+from typing import Any, Callable
import torch
from torchopt import linalg, pytree
from torchopt.linear_solve.utils import make_ridge_matvec, materialize_matvec
-from torchopt.typing import TensorTree
+from torchopt.typing import LinearSolver, TensorTree
__all__ = ['solve_inv']
@@ -53,7 +53,7 @@ def _solve_inv(
b: TensorTree,
ridge: float | None = None,
ns: bool = False,
- **kwargs,
+ **kwargs: Any,
) -> TensorTree:
"""Solve ``A x = b`` using matrix inversion.
@@ -91,7 +91,7 @@ def _solve_inv(
return tree_unravel(pytree.tree_map(torch.linalg.solve, A, tree_ravel(b)))
-def solve_inv(**kwargs):
+def solve_inv(**kwargs: Any) -> LinearSolver:
"""Return a solver function to solve ``A x = b`` using matrix inversion.
If ``ns = False``, this assumes the matrix ``A`` is a constant matrix and will materialize it
@@ -113,8 +113,7 @@ def solve_inv(**kwargs):
See Also:
Neumann Series matrix inversion approximation :func:`torchopt.linalg.ns`.
- Example::
-
+ Examples:
>>> A = {'a': torch.eye(5, 5), 'b': torch.eye(3, 3)}
>>> x = {'a': torch.randn(5), 'b': torch.randn(3)}
>>> def matvec(x: TensorTree) -> TensorTree:
@@ -123,6 +122,5 @@ def solve_inv(**kwargs):
>>> solver = solve_inv(ns=True, maxiter=10)
>>> x_hat = solver(matvec, b)
>>> assert torch.allclose(x_hat['a'], x['a']) and torch.allclose(x_hat['b'], x['b'])
-
"""
return functools.partial(_solve_inv, **kwargs)
diff --git a/torchopt/linear_solve/normal_cg.py b/torchopt/linear_solve/normal_cg.py
index 8d38f77a..78813ecb 100644
--- a/torchopt/linear_solve/normal_cg.py
+++ b/torchopt/linear_solve/normal_cg.py
@@ -36,11 +36,11 @@
from __future__ import annotations
import functools
-from typing import Callable
+from typing import Any, Callable
from torchopt import linalg
from torchopt.linear_solve.utils import make_normal_matvec, make_ridge_matvec, make_rmatvec
-from torchopt.typing import TensorTree
+from torchopt.typing import LinearSolver, TensorTree
__all__ = ['solve_normal_cg']
@@ -51,7 +51,7 @@ def _solve_normal_cg(
b: TensorTree,
ridge: float | None = None,
init: TensorTree | None = None,
- **kwargs,
+ **kwargs: Any,
) -> TensorTree:
"""Solve the normal equation ``A^T A x = A^T b`` using conjugate gradient.
@@ -71,10 +71,7 @@ def _solve_normal_cg(
Returns:
The solution with the same structure as ``b``.
"""
- if init is None:
- example_x = b # This assumes that matvec is a square linear operator.
- else:
- example_x = init
+ example_x = b if init is None else init
rmatvec = make_rmatvec(matvec, example_x) # (x) -> A.T @ x
normal_matvec = make_normal_matvec(matvec) # (x) -> A.T @ A @ x
@@ -90,7 +87,7 @@ def _solve_normal_cg(
return linalg.cg(normal_matvec, rhs, x0=init, **kwargs)
-def solve_normal_cg(**kwargs):
+def solve_normal_cg(**kwargs: Any) -> LinearSolver:
"""Return a solver function to solve ``A^T A x = A^T b`` using conjugate gradient.
This can be used to solve ``A x = b`` using conjugate gradient when ``A`` is not hermitian,
@@ -111,8 +108,7 @@ def solve_normal_cg(**kwargs):
See Also:
Conjugate gradient iteration :func:`torchopt.linalg.cg`.
- Example::
-
+ Examples:
>>> A = {'a': torch.randn(5, 5), 'b': torch.randn(3, 3)}
>>> x = {'a': torch.randn(5), 'b': torch.randn(3)}
>>> def matvec(x: TensorTree) -> TensorTree:
@@ -121,6 +117,5 @@ def solve_normal_cg(**kwargs):
>>> solver = solve_normal_cg(init={'a': torch.zeros(5), 'b': torch.zeros(3)})
>>> x_hat = solver(matvec, b)
>>> assert torch.allclose(x_hat['a'], x['a']) and torch.allclose(x_hat['b'], x['b'])
-
"""
return functools.partial(_solve_normal_cg, **kwargs)
diff --git a/torchopt/linear_solve/utils.py b/torchopt/linear_solve/utils.py
index f4f34e2a..22dcec6f 100644
--- a/torchopt/linear_solve/utils.py
+++ b/torchopt/linear_solve/utils.py
@@ -42,7 +42,8 @@
def make_rmatvec(
- matvec: Callable[[TensorTree], TensorTree], example_x: TensorTree
+ matvec: Callable[[TensorTree], TensorTree],
+ example_x: TensorTree,
) -> Callable[[TensorTree], TensorTree]:
"""Return a function that computes ``rmatvec(y) = A.T @ y`` from ``matvec(x) = A @ x``."""
_, vjp, *_ = functorch.vjp(matvec, example_x)
@@ -51,7 +52,7 @@ def make_rmatvec(
def make_normal_matvec(
- matvec: Callable[[TensorTree], TensorTree]
+ matvec: Callable[[TensorTree], TensorTree],
) -> Callable[[TensorTree], TensorTree]:
"""Return a function that computes ``normal_matvec(y) = A.T @ A @ y`` from ``matvec(x) = A @ x``."""
@@ -64,7 +65,8 @@ def normal_matvec(y: TensorTree) -> TensorTree:
def make_ridge_matvec(
- matvec: Callable[[TensorTree], TensorTree], ridge: float = 0.0
+ matvec: Callable[[TensorTree], TensorTree],
+ ridge: float = 0.0,
) -> Callable[[TensorTree], TensorTree]:
"""Return a function that computes ``ridge_matvec(y) = A.T @ A @ y + ridge * y`` from ``matvec(x) = A @ x``."""
@@ -76,7 +78,8 @@ def ridge_matvec(y: TensorTree) -> TensorTree:
def materialize_matvec(
- matvec: Callable[[TensorTree], TensorTree], x: TensorTree
+ matvec: Callable[[TensorTree], TensorTree],
+ x: TensorTree,
) -> tuple[
TensorTree,
Callable[[TensorTree], TensorTree],
diff --git a/torchopt/nn/module.py b/torchopt/nn/module.py
index f8804864..09ab359e 100644
--- a/torchopt/nn/module.py
+++ b/torchopt/nn/module.py
@@ -40,7 +40,7 @@ class MetaGradientModule(nn.Module): # pylint: disable=abstract-method
_meta_parameters: TensorContainer
_meta_modules: dict[str, nn.Module | None]
- def __new__(cls, *args, **kwargs) -> MetaGradientModule:
+ def __new__(cls, *args: Any, **kwargs: Any) -> MetaGradientModule:
"""Create a new module instance."""
instance = super().__new__(cls)
flat_args: list[Any]
@@ -56,7 +56,7 @@ def __new__(cls, *args, **kwargs) -> MetaGradientModule:
instance._meta_modules: dict[str, nn.Module | None] = OrderedDict() # type: ignore[misc]
return instance
- def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused-argument
+ def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=unused-argument
"""Initialize a new module instance."""
super().__init__()
@@ -88,7 +88,7 @@ def __getattr__(self, name: str) -> torch.Tensor | nn.Module:
def __setattr__(self, name: str, value: torch.Tensor | nn.Module) -> None:
"""Set an attribute of the module."""
- def remove_from(*dicts_or_sets):
+ def remove_from(*dicts_or_sets: dict[str, Any] | set[str]) -> None:
for dict_or_set in dicts_or_sets:
if name in dict_or_set:
if isinstance(dict_or_set, dict):
@@ -103,7 +103,7 @@ def remove_from(*dicts_or_sets):
raise AttributeError('cannot assign parameters before Module.__init__() call')
if meta_params is None:
raise AttributeError(
- 'cannot assign meta-parameters before MetaGradientModule.__init__() call'
+ 'cannot assign meta-parameters before MetaGradientModule.__init__() call',
)
remove_from(
self.__dict__,
@@ -121,14 +121,14 @@ def remove_from(*dicts_or_sets):
if value is not None:
raise TypeError(
f"cannot assign '{torch.typename(value)}' as parameter '{name}' "
- f'(torch.Tensor or None expected)'
+ f'(torch.Tensor or None expected)',
)
self.register_parameter(name, value) # type: ignore[unreachable]
elif meta_params is not None and name in meta_params:
if value is not None:
raise TypeError(
f"cannot assign '{torch.typename(value)}' as meta-parameter '{name}' "
- f'(torch.Tensor or None expected)'
+ f'(torch.Tensor or None expected)',
)
self.register_meta_parameter(name, value) # type: ignore[unreachable]
else:
@@ -139,7 +139,7 @@ def remove_from(*dicts_or_sets):
raise AttributeError('cannot assign module before Module.__init__() call')
if meta_modules is None:
raise AttributeError(
- 'cannot assign module before MetaGradientModule.__init__() call'
+ 'cannot assign module before MetaGradientModule.__init__() call',
)
remove_from(
self.__dict__,
@@ -157,7 +157,7 @@ def remove_from(*dicts_or_sets):
if value is not None:
raise TypeError(
f"cannot assign '{torch.typename(value)}' as child module '{name}' "
- f'(torch.nn.Module or None expected)'
+ f'(torch.nn.Module or None expected)',
)
modules[name] = value # type: ignore[unreachable]
else:
@@ -166,7 +166,7 @@ def remove_from(*dicts_or_sets):
if value is not None and not isinstance(value, torch.Tensor):
raise TypeError(
f"cannot assign '{torch.typename(value)}' as buffer '{name}' "
- f'(torch.Tensor or None expected)'
+ f'(torch.Tensor or None expected)',
)
buffers[name] = value
else:
@@ -218,16 +218,16 @@ def register_parameter(self, name: str, param: torch.Tensor | None) -> None:
if not isinstance(param, torch.Tensor):
raise TypeError(
f"cannot assign '{torch.typename(param)}' object to parameter '{name}' "
- f'(torch.Tensor or None required)'
+ f'(torch.Tensor or None required)',
)
if not param.requires_grad:
raise ValueError(
- f"cannot assign Tensor that `requires_grad=False` to parameter '{name}'"
+ f"cannot assign Tensor that `requires_grad=False` to parameter '{name}'",
)
if param in self._meta_inputs.meta_parameters:
raise ValueError(
f"cannot assign Tensor that is a meta-parameter to parameter '{name}'. "
- f'Use self.register_meta_parameter() instead.'
+ f'Use self.register_meta_parameter() instead.',
)
self._parameters[name] = param # type: ignore
@@ -246,7 +246,7 @@ def register_meta_parameter(self, name: str, param: torch.Tensor | None) -> None
"""
if '_meta_parameters' not in self.__dict__:
raise AttributeError(
- 'cannot assign meta-parameter before MetaGradientModule.__init__() call'
+ 'cannot assign meta-parameter before MetaGradientModule.__init__() call',
)
if not isinstance(name, str):
raise TypeError(f'meta-parameter name should be a string. Got {torch.typename(name)}')
@@ -264,11 +264,11 @@ def register_meta_parameter(self, name: str, param: torch.Tensor | None) -> None
if not isinstance(param, torch.Tensor):
raise TypeError(
f"cannot assign '{torch.typename(param)}' object to meta-parameter '{name}' "
- f'(torch.Tensor or None required)'
+ f'(torch.Tensor or None required)',
)
if not param.requires_grad:
raise ValueError(
- f"cannot assign Tensor that `requires_grad=False` to meta-parameter '{name}'"
+ f"cannot assign Tensor that `requires_grad=False` to meta-parameter '{name}'",
)
self._meta_parameters[name] = param
@@ -296,7 +296,7 @@ def add_module(self, name: str, module: nn.Module | None) -> None:
if module in self._meta_inputs.meta_modules:
raise ValueError(
f"cannot add module that is a meta-module to module '{name}'. "
- f'Use self.add_meta_module() instead.'
+ f'Use self.add_meta_module() instead.',
)
self._modules[name] = module
@@ -345,19 +345,19 @@ def meta_parameters(self, recurse: bool = True) -> Iterator[torch.Tensor]:
Yields:
Parameter: module meta-parameter
- Example::
-
+ Examples:
>>> for param in model.meta_parameters():
>>> print(type(param), param.size())
(20L,)
(20L, 1L, 5L, 5L)
-
"""
for _, meta_param in self.named_meta_parameters(recurse=recurse):
yield meta_param
def named_meta_parameters(
- self, prefix: str = '', recurse: bool = True
+ self,
+ prefix: str = '',
+ recurse: bool = True,
) -> Iterator[tuple[str, torch.Tensor]]:
r"""Return an iterator over module meta-parameters, yielding both the name of the meta-parameter as well as the meta-parameter itself.
@@ -371,12 +371,10 @@ def named_meta_parameters(
Yields:
(string, Parameter): Tuple containing the name and parameter
- Example::
-
+ Examples:
>>> for name, meta_param in self.named_meta_parameters():
>>> if name in ['bias']:
>>> print(meta_param.size())
-
""" # pylint: disable=line-too-long
memo = set()
for name, param in getattr(self, '_meta_parameters', {}).items():
@@ -405,12 +403,10 @@ def named_meta_children(self) -> Iterator[tuple[str, nn.Module]]:
Yields:
(string, Module): Tuple containing a name and child meta-module
- Example::
-
+ Examples:
>>> for name, meta_module in model.named_meta_children():
>>> if name in ['conv4', 'conv5']:
>>> print(meta_module)
-
""" # pylint: disable=line-too-long
memo = set()
for name, meta_module in self._meta_modules.items():
@@ -431,7 +427,10 @@ def meta_modules(self) -> Iterator[nn.Module]:
yield meta_module
def named_meta_modules(
- self, memo: set[nn.Module] | None = None, prefix: str = '', remove_duplicate: bool = True
+ self,
+ memo: set[nn.Module] | None = None,
+ prefix: str = '',
+ remove_duplicate: bool = True,
) -> Iterator[tuple[str, nn.Module]]:
r"""Return an iterator over all meta-modules in the network, yielding both the name of the meta-module as well as the meta-module itself.
diff --git a/torchopt/nn/stateless.py b/torchopt/nn/stateless.py
index 9391352f..e547b5cb 100644
--- a/torchopt/nn/stateless.py
+++ b/torchopt/nn/stateless.py
@@ -57,10 +57,7 @@ def recursive_setattr(path: str, value: torch.Tensor) -> torch.Tensor:
prefix, _, attr = path.rpartition('.')
mod = get_submodule(prefix)
- if allow_missing:
- orig = getattr(mod, attr, MISSING)
- else:
- orig = getattr(mod, attr)
+ orig = getattr(mod, attr, MISSING) if allow_missing else getattr(mod, attr)
# pylint: disable=protected-access
if value is MISSING:
@@ -77,10 +74,7 @@ def recursive_setattr(path: str, value: torch.Tensor) -> torch.Tensor:
return orig
- orig_named_tensors = {
- name: recursive_setattr(name, tensor) for name, tensor in named_tensors.items()
- }
- return orig_named_tensors
+ return {name: recursive_setattr(name, tensor) for name, tensor in named_tensors.items()}
@contextlib.contextmanager
diff --git a/torchopt/optim/__init__.py b/torchopt/optim/__init__.py
index b75da23c..8e390a5c 100644
--- a/torchopt/optim/__init__.py
+++ b/torchopt/optim/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,6 +15,7 @@
"""object oriented optimizer implementations."""
from torchopt.optim import meta
+from torchopt.optim.adagrad import AdaGrad, Adagrad
from torchopt.optim.adam import Adam
from torchopt.optim.adamw import AdamW
from torchopt.optim.base import Optimizer
diff --git a/torchopt/optim/adagrad.py b/torchopt/optim/adagrad.py
new file mode 100644
index 00000000..055e0ad5
--- /dev/null
+++ b/torchopt/optim/adagrad.py
@@ -0,0 +1,82 @@
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""AdaGrad optimizer."""
+
+from __future__ import annotations
+
+from typing import Iterable
+
+import torch
+
+from torchopt import alias
+from torchopt.optim.base import Optimizer
+from torchopt.typing import ScalarOrSchedule
+
+
+__all__ = ['AdaGrad', 'Adagrad']
+
+
+class AdaGrad(Optimizer):
+ """The classic AdaGrad optimizer.
+
+ See Also:
+ - The functional AdaGrad optimizer: :func:`torchopt.adagrad`.
+ - The differentiable meta AdaGrad optimizer: :class:`torchopt.MetaAdaGrad`.
+ """
+
+ # pylint: disable-next=too-many-arguments
+ def __init__(
+ self,
+ params: Iterable[torch.Tensor],
+ lr: ScalarOrSchedule = 1e-2,
+ lr_decay: float = 0.0,
+ weight_decay: float = 0.0,
+ initial_accumulator_value: float = 0.0,
+ eps: float = 1e-10,
+ *,
+ maximize: bool = False,
+ ) -> None:
+ r"""Initialize the AdaGrad optimizer.
+
+ Args:
+ params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what
+ tensors should be optimized.
+ lr (float or callable, optional): This is a fixed global scaling factor or a learning
+ rate scheduler. (default: :const:`1e-2`)
+ lr_decay (float, optional): Learning rate decay. (default: :const:`0.0`)
+ weight_decay (float, optional): Weight decay, add L2 penalty to parameters.
+ (default: :const:`0.0`)
+ initial_accumulator_value (float, optional): Initial value for the accumulator.
+ (default: :const:`0.0`)
+ eps (float, optional): A small constant applied to denominator outside of the square
+ root (as in the Adam paper) to avoid dividing by zero when rescaling.
+ (default: :const:`1e-10`)
+ maximize (bool, optional): Maximize the params based on the objective, instead of
+ minimizing. (default: :data:`False`)
+ """
+ super().__init__(
+ params,
+ alias.adagrad(
+ lr=lr,
+ lr_decay=lr_decay,
+ weight_decay=weight_decay,
+ initial_accumulator_value=initial_accumulator_value,
+ eps=eps,
+ maximize=maximize,
+ ),
+ )
+
+
+Adagrad = AdaGrad # alias for PyTorch compatibility
diff --git a/torchopt/optim/adam.py b/torchopt/optim/adam.py
index 640eea1d..5d85cbdc 100644
--- a/torchopt/optim/adam.py
+++ b/torchopt/optim/adam.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -40,7 +40,7 @@ class Adam(Optimizer):
def __init__(
self,
params: Iterable[torch.Tensor],
- lr: ScalarOrSchedule,
+ lr: ScalarOrSchedule = 1e-3,
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0.0,
diff --git a/torchopt/optim/adamw.py b/torchopt/optim/adamw.py
index 7db5e750..be8c6727 100644
--- a/torchopt/optim/adamw.py
+++ b/torchopt/optim/adamw.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -64,7 +64,7 @@ def __init__(
(default: :const:`1e-8`)
weight_decay (float, optional): Strength of the weight decay regularization. Note that
this weight decay is multiplied with the learning rate. This is consistent with
- other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where
+ other frameworks such as PyTorch, but different from (Loshchilov et al., 2019) where
the weight decay is only multiplied with the "schedule multiplier", but not the base
learning rate. (default: :const:`1e-2`)
eps_root (float, optional): A small constant applied to denominator inside the square
diff --git a/torchopt/optim/base.py b/torchopt/optim/base.py
index aac3a782..d0be2fd1 100644
--- a/torchopt/optim/base.py
+++ b/torchopt/optim/base.py
@@ -67,12 +67,12 @@ def zero_grad(self, set_to_none: bool = False) -> None:
"""
if set_to_none:
- def f(p):
+ def f(p: torch.Tensor) -> None:
p.grad = None
else:
- def f(p):
+ def f(p: torch.Tensor) -> None:
if p.grad is None:
return
if p.grad.grad_fn is not None:
@@ -110,7 +110,7 @@ def step(self, closure: Callable[[], torch.Tensor] | None = None) -> torch.Tenso
with torch.enable_grad():
loss = closure()
- def f(p):
+ def f(p: torch.Tensor) -> torch.Tensor | None:
return p.grad
for i, (params, state) in enumerate(zip(self.param_groups, self.state_groups)):
diff --git a/torchopt/optim/func/base.py b/torchopt/optim/func/base.py
index 9dce3412..94038464 100644
--- a/torchopt/optim/func/base.py
+++ b/torchopt/optim/func/base.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -34,6 +34,7 @@ class FuncOptimizer: # pylint: disable=too-few-public-methods
and update the parameters.
See Also:
+ - The functional AdaGrad optimizer: :func:`torchopt.adagrad`.
- The functional Adam optimizer: :func:`torchopt.adam`.
- The functional AdamW optimizer: :func:`torchopt.adamw`.
- The functional RMSprop optimizer: :func:`torchopt.rmsprop`.
@@ -86,10 +87,12 @@ def step(
# Step parameter only
grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
updates, self.optim_state = self.impl.update(
- grads, self.optim_state, params=params, inplace=inplace
+ grads,
+ self.optim_state,
+ params=params,
+ inplace=inplace,
)
- new_params = apply_updates(params, updates, inplace=inplace)
- return new_params
+ return apply_updates(params, updates, inplace=inplace)
def state_dict(self) -> OptState:
"""Extract the references of the optimizer states.
diff --git a/torchopt/optim/meta/__init__.py b/torchopt/optim/meta/__init__.py
index ba486d6d..28f374cc 100644
--- a/torchopt/optim/meta/__init__.py
+++ b/torchopt/optim/meta/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
# ==============================================================================
"""Differentiable Meta-Optimizers."""
+from torchopt.optim.meta.adagrad import MetaAdaGrad, MetaAdagrad
from torchopt.optim.meta.adam import MetaAdam
from torchopt.optim.meta.adamw import MetaAdamW
from torchopt.optim.meta.base import MetaOptimizer
diff --git a/torchopt/optim/meta/adagrad.py b/torchopt/optim/meta/adagrad.py
new file mode 100644
index 00000000..079d76db
--- /dev/null
+++ b/torchopt/optim/meta/adagrad.py
@@ -0,0 +1,79 @@
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Differentiable AdaGrad optimizer."""
+
+from __future__ import annotations
+
+import torch.nn as nn
+
+from torchopt import alias
+from torchopt.optim.meta.base import MetaOptimizer
+from torchopt.typing import ScalarOrSchedule
+
+
+__all__ = ['MetaAdaGrad', 'MetaAdagrad']
+
+
+class MetaAdaGrad(MetaOptimizer):
+ """The differentiable AdaGrad optimizer.
+
+ See Also:
+ - The functional AdaGrad optimizer: :func:`torchopt.adagrad`.
+ - The classic AdaGrad optimizer: :class:`torchopt.AdaGrad`.
+ """
+
+ # pylint: disable-next=too-many-arguments
+ def __init__(
+ self,
+ module: nn.Module,
+ lr: ScalarOrSchedule = 1e-2,
+ lr_decay: float = 0.0,
+ weight_decay: float = 0.0,
+ initial_accumulator_value: float = 0.0,
+ eps: float = 1e-10,
+ *,
+ maximize: bool = False,
+ ) -> None:
+ """Initialize the meta AdaGrad optimizer.
+
+ Args:
+ module (nn.Module): A network whose parameters should be optimized.
+ lr (float or callable, optional): This is a fixed global scaling factor or a learning
+ rate scheduler. (default: :const:`1e-2`)
+ lr_decay (float, optional): Learning rate decay. (default: :const:`0.0`)
+ weight_decay (float, optional): Weight decay, add L2 penalty to parameters.
+ (default: :const:`0.0`)
+ initial_accumulator_value (float, optional): Initial value for the accumulator.
+ (default: :const:`0.0`)
+ eps (float, optional): A small constant applied to denominator outside of the square
+ root (as in the Adam paper) to avoid dividing by zero when rescaling.
+ (default: :const:`1e-10`)
+ maximize (bool, optional): Maximize the params based on the objective, instead of
+ minimizing. (default: :data:`False`)
+ """
+ super().__init__(
+ module,
+ alias.adagrad(
+ lr=lr,
+ lr_decay=lr_decay,
+ weight_decay=weight_decay,
+ initial_accumulator_value=initial_accumulator_value,
+ eps=eps,
+ maximize=maximize,
+ ),
+ )
+
+
+MetaAdagrad = MetaAdaGrad # alias for PyTorch compatibility
diff --git a/torchopt/optim/meta/adamw.py b/torchopt/optim/meta/adamw.py
index c8a8ef9c..204a5428 100644
--- a/torchopt/optim/meta/adamw.py
+++ b/torchopt/optim/meta/adamw.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -64,7 +64,7 @@ def __init__(
(default: :const:`1e-8`)
weight_decay (float, optional): Strength of the weight decay regularization. Note that
this weight decay is multiplied with the learning rate. This is consistent with
- other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where
+ other frameworks such as PyTorch, but different from (Loshchilov et al., 2019) where
the weight decay is only multiplied with the "schedule multiplier", but not the base
learning rate. (default: :const:`1e-2`)
eps_root (float, optional): A small constant applied to denominator inside the square
diff --git a/torchopt/optim/meta/base.py b/torchopt/optim/meta/base.py
index c5c9ad73..54327f3b 100644
--- a/torchopt/optim/meta/base.py
+++ b/torchopt/optim/meta/base.py
@@ -68,7 +68,7 @@ def step(self, loss: torch.Tensor) -> None: # pylint: disable=too-many-locals
"""
# Step parameter only
for i, (param_container, state) in enumerate(
- zip(self.param_containers_groups, self.state_groups)
+ zip(self.param_containers_groups, self.state_groups),
):
flat_params: TupleOfTensors
flat_params, container_treespec = pytree.tree_flatten_as_tuple(param_container) # type: ignore[arg-type]
@@ -89,7 +89,8 @@ def step(self, loss: torch.Tensor) -> None: # pylint: disable=too-many-locals
self.state_groups[i] = new_state
flat_new_params = apply_updates(flat_params, updates, inplace=False)
new_params: ModuleTensorContainers = pytree.tree_unflatten( # type: ignore[assignment]
- container_treespec, flat_new_params
+ container_treespec,
+ flat_new_params,
)
for container, new_param in zip(param_container, new_params):
container.update(new_param)
diff --git a/torchopt/pytree.py b/torchopt/pytree.py
index d3b2d181..253cb154 100644
--- a/torchopt/pytree.py
+++ b/torchopt/pytree.py
@@ -102,7 +102,9 @@ def tree_add(*trees: PyTree[T]) -> PyTree[T]:
def tree_add_scalar_mul(
- tree_x: TensorTree, tree_y: TensorTree, alpha: Scalar | None = None
+ tree_x: TensorTree,
+ tree_y: TensorTree,
+ alpha: Scalar | None = None,
) -> TensorTree:
"""Compute ``tree_x + alpha * tree_y``."""
if alpha is None:
@@ -116,7 +118,9 @@ def tree_sub(minuend_tree: PyTree[T], subtrahend_tree: PyTree[T]) -> PyTree[T]:
def tree_sub_scalar_mul(
- tree_x: TensorTree, tree_y: TensorTree, alpha: Scalar | None = None
+ tree_x: TensorTree,
+ tree_y: TensorTree,
+ alpha: Scalar | None = None,
) -> TensorTree:
"""Compute ``tree_x - alpha * tree_y``."""
if alpha is None:
diff --git a/torchopt/schedule/__init__.py b/torchopt/schedule/__init__.py
index 46f59550..b9916783 100644
--- a/torchopt/schedule/__init__.py
+++ b/torchopt/schedule/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,7 +31,8 @@
# ==============================================================================
"""Learning rate schedules."""
+from torchopt.schedule.exponential_decay import exponential_decay
from torchopt.schedule.polynomial import linear_schedule, polynomial_schedule
-__all__ = ['polynomial_schedule', 'linear_schedule']
+__all__ = ['exponential_decay', 'polynomial_schedule', 'linear_schedule']
diff --git a/torchopt/schedule/exponential_decay.py b/torchopt/schedule/exponential_decay.py
new file mode 100644
index 00000000..8811b353
--- /dev/null
+++ b/torchopt/schedule/exponential_decay.py
@@ -0,0 +1,119 @@
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# This file is modified from:
+# https://github.com/deepmind/optax/blob/master/optax/_src/schedule.py
+# ==============================================================================
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Exponential learning rate decay."""
+
+import logging
+import math
+from typing import Optional
+
+from torchopt.typing import Numeric, Scalar, Schedule
+
+
+__all__ = ['exponential_decay']
+
+
+# pylint: disable-next=too-many-arguments
+def exponential_decay(
+ init_value: Scalar,
+ decay_rate: Scalar,
+ transition_begin: int = 0,
+ transition_steps: int = 1,
+ staircase: bool = False,
+ end_value: Optional[float] = None,
+) -> Schedule:
+ """Construct a schedule with either continuous or discrete exponential decay.
+
+ This function applies an exponential decay function to a provided initial value. The function
+ returns the decayed value as follows:
+
+ .. code-block:: python
+
+ decayed_value = init_value * decay_rate**(count / transition_steps)
+
+ If the argument ``staircase`` is :data:`True`, then ``count / transition_steps`` is an integer
+ division and the decayed value follows a staircase function.
+
+ Args:
+ init_value (float or Tensor): Initial value for the scalar to be annealed.
+ decay_rate (float or Tensor): The decay rate.
+ transition_begin (int, optional): Must be *positive*. After how many steps to start
+ annealing (before this many steps the scalar value is held fixed at ``init_value``).
+ (default: :const:`0`)
+ transition_steps (int, optional): Number of steps over which annealing takes place, the
+ scalar starts changing at ``transition_begin`` steps and completes the transition by
+ ``transition_begin + transition_steps`` steps. If ``transition_steps <= 0``, then the
+ entire annealing process is disabled and the value is held fixed at ``init_value``.
+ (default: :const:`1`)
+ staircase (bool, optional): If :data:`True`, decay the scalar at discrete intervals.
+ (default: :data:`False`)
+ end_value (float or Tensor, optional): End value of the scalar to be annealed.
+ (default: :data:`None`)
+
+ Returns:
+ schedule: A function that maps step counts to values.
+ """
+ if transition_steps is not None and transition_steps <= 0: # pragma: no cover
+ logging.info(
+ 'An exponential schedule was set with a non-positive `transition_steps`'
+ ' value; this will result in a constant schedule with value '
+ '`init_value`.',
+ )
+ return lambda count: init_value
+
+ if decay_rate == 0: # pragma: no cover
+ logging.info(
+ 'An exponential schedule was set with a zero `decay_rate` value; '
+ 'this will result in a constant schedule with value `init_value`.',
+ )
+ return lambda count: init_value
+
+ if transition_begin < 0: # pragma: no cover
+ logging.info(
+ 'An exponential schedule was set with a negative `transition_begin` '
+ 'value; this will result in `transition_begin` falling back to `0`.',
+ )
+ transition_begin = 0
+
+ if end_value is not None: # pragma: no cover
+ clip_fn = max if decay_rate < 1.0 else min
+
+ def schedule(count: Numeric) -> Numeric:
+ decreased_count = count - transition_begin
+ p = decreased_count / transition_steps
+ if staircase:
+ p = math.floor(p)
+ decayed_value = init_value if decreased_count <= 0.0 else init_value * (decay_rate**p)
+ if end_value is not None:
+ return clip_fn(decayed_value, end_value)
+ return decayed_value
+
+ return schedule
diff --git a/torchopt/schedule/polynomial.py b/torchopt/schedule/polynomial.py
index d54dbf17..39629c38 100644
--- a/torchopt/schedule/polynomial.py
+++ b/torchopt/schedule/polynomial.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -68,17 +68,17 @@ def polynomial_schedule(
schedule:
A function that maps step counts to values.
"""
- if transition_steps <= 0:
+ if transition_steps <= 0: # pragma: no cover
logging.info(
'A polynomial schedule was set with a non-positive `transition_steps` value; this '
- 'results in a constant schedule with value `init_value`.'
+ 'results in a constant schedule with value `init_value`.',
)
return lambda count: init_value
- if transition_begin < 0:
+ if transition_begin < 0: # pragma: no cover
logging.info(
'An exponential schedule was set with a negative `transition_begin` value; this will '
- 'result in `transition_begin` falling back to `0`.'
+ 'result in `transition_begin` falling back to `0`.',
)
transition_begin = 0
diff --git a/torchopt/transform/__init__.py b/torchopt/transform/__init__.py
index 7006090f..47c49ea1 100644
--- a/torchopt/transform/__init__.py
+++ b/torchopt/transform/__init__.py
@@ -36,6 +36,7 @@
from torchopt.transform.scale import scale
from torchopt.transform.scale_by_adam import scale_by_accelerated_adam, scale_by_adam
from torchopt.transform.scale_by_rms import scale_by_rms
+from torchopt.transform.scale_by_rss import scale_by_rss
from torchopt.transform.scale_by_schedule import scale_by_schedule
from torchopt.transform.scale_by_stddev import scale_by_stddev
from torchopt.transform.trace import trace
@@ -49,6 +50,7 @@
'masked',
'scale_by_adam',
'scale_by_accelerated_adam',
+ 'scale_by_rss',
'scale_by_rms',
'scale_by_stddev',
'nan_to_num',
diff --git a/torchopt/transform/add_decayed_weights.py b/torchopt/transform/add_decayed_weights.py
index 14745766..04d564d7 100644
--- a/torchopt/transform/add_decayed_weights.py
+++ b/torchopt/transform/add_decayed_weights.py
@@ -36,6 +36,8 @@
from typing import Any, Callable, NamedTuple
+import torch
+
from torchopt import pytree
from torchopt.base import EmptyState, GradientTransformation, identity
from torchopt.transform.utils import tree_map_flat, tree_map_flat_
@@ -103,12 +105,12 @@ def _masked(
*,
already_flattened: bool = False,
) -> GradientTransformation:
- if already_flattened:
+ if already_flattened: # noqa: SIM108
tree_map = tree_map_flat
else:
tree_map = pytree.tree_map # type: ignore[assignment]
- def tree_mask(params, mask_tree):
+ def tree_mask(params: Params, mask_tree: OptState) -> Params:
return tree_map(lambda p, m: p if m else MaskedNode(), params, mask_tree)
def init_fn(params: Params) -> OptState:
@@ -128,11 +130,17 @@ def update_fn(
masked_params = None if params is None else tree_mask(params, mask_tree)
new_masked_updates, new_inner_state = inner.update(
- masked_updates, state.inner_state, params=masked_params, inplace=inplace
+ masked_updates,
+ state.inner_state,
+ params=masked_params,
+ inplace=inplace,
)
new_updates = tree_map(
- lambda old_u, new_u, m: new_u if m else old_u, updates, new_masked_updates, mask_tree
+ lambda old_u, new_u, m: new_u if m else old_u,
+ updates,
+ new_masked_updates,
+ mask_tree,
)
return new_updates, MaskedState(inner_state=new_inner_state)
@@ -188,7 +196,7 @@ def _add_decayed_weights(
already_flattened: bool = False,
) -> GradientTransformation:
# pylint: disable-next=unneeded-not
- if not 0.0 <= weight_decay: # pragma: no cover
+ if not weight_decay >= 0.0: # pragma: no cover
raise ValueError(f'Invalid weight_decay value: {weight_decay}')
if weight_decay == 0.0 and mask is None:
@@ -218,7 +226,7 @@ def update_fn(
if inplace:
- def f(g, p):
+ def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
if g.requires_grad:
return g.add_(p, alpha=weight_decay)
return g.add_(p.data, alpha=weight_decay)
@@ -227,7 +235,7 @@ def f(g, p):
else:
- def f(g, p):
+ def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
return g.add(p, alpha=weight_decay)
updates = tree_map(f, updates, params)
diff --git a/torchopt/transform/nan_to_num.py b/torchopt/transform/nan_to_num.py
index 804f8219..27d87499 100644
--- a/torchopt/transform/nan_to_num.py
+++ b/torchopt/transform/nan_to_num.py
@@ -16,6 +16,8 @@
from __future__ import annotations
+import torch
+
from torchopt import pytree
from torchopt.base import EmptyState, GradientTransformation
from torchopt.typing import OptState, Params, Updates
@@ -44,12 +46,12 @@ def update_fn(
) -> tuple[Updates, OptState]:
if inplace:
- def f(g):
+ def f(g: torch.Tensor) -> torch.Tensor:
return g.nan_to_num_(nan=nan, posinf=posinf, neginf=neginf)
else:
- def f(g):
+ def f(g: torch.Tensor) -> torch.Tensor:
return g.nan_to_num(nan=nan, posinf=posinf, neginf=neginf)
new_updates = pytree.tree_map(f, updates)
diff --git a/torchopt/transform/scale.py b/torchopt/transform/scale.py
index 639c903e..c731003c 100644
--- a/torchopt/transform/scale.py
+++ b/torchopt/transform/scale.py
@@ -33,6 +33,8 @@
from __future__ import annotations
+import torch
+
from torchopt import pytree
from torchopt.base import EmptyState, GradientTransformation
from torchopt.transform.utils import tree_map_flat, tree_map_flat_
@@ -85,14 +87,14 @@ def update_fn(
) -> tuple[Updates, OptState]:
if inplace:
- def f(g):
+ def f(g: torch.Tensor) -> torch.Tensor:
return g.mul_(step_size)
updates = tree_map_(f, updates)
else:
- def f(g):
+ def f(g: torch.Tensor) -> torch.Tensor:
return g.mul(step_size)
updates = tree_map(f, updates)
diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py
index 36f30be9..c3c6254e 100644
--- a/torchopt/transform/scale_by_adam.py
+++ b/torchopt/transform/scale_by_adam.py
@@ -69,7 +69,7 @@ def _bias_correction(
) -> Updates:
"""Perform bias correction. This becomes a no-op as count goes to infinity."""
- def f(t, c): # pylint: disable=invalid-name
+ def f(t: torch.Tensor, c: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return t.div(1 - pow(decay, c))
if already_flattened:
@@ -87,7 +87,7 @@ def scale_by_adam(
"""Rescale updates according to the Adam algorithm.
References:
- [Kingma et al, 2014](https://arxiv.org/abs/1412.6980)
+ - Kingma et al., 2014: https://arxiv.org/abs/1412.6980
Args:
b1 (float, optional): Decay rate for the exponentially weighted average of grads.
@@ -142,7 +142,7 @@ def _scale_by_adam(
already_flattened: bool = False,
) -> GradientTransformation:
# pylint: disable=unneeded-not
- if not 0.0 <= eps: # pragma: no cover
+ if not eps >= 0.0: # pragma: no cover
raise ValueError(f'Invalid epsilon value: {eps}')
if not 0.0 <= b1 < 1.0: # pragma: no cover
raise ValueError(f'Invalid beta parameter at index 0: {b1}')
@@ -150,20 +150,23 @@ def _scale_by_adam(
raise ValueError(f'Invalid beta parameter at index 1: {b2}')
# pylint: enable=unneeded-not
- if already_flattened:
+ if already_flattened: # noqa: SIM108
tree_map = tree_map_flat
else:
tree_map = pytree.tree_map # type: ignore[assignment]
def init_fn(params: Params) -> OptState:
zero = tree_map( # count init
- lambda t: torch.zeros(1, dtype=torch.int64, device=t.device).squeeze_(), params
+ lambda t: torch.zeros(1, dtype=torch.int64, device=t.device).squeeze_(),
+ params,
)
mu = tree_map( # first moment
- lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params
+ lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad),
+ params,
)
nu = tree_map( # second moment
- lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params
+ lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad),
+ params,
)
return ScaleByAdamState(mu=mu, nu=nu, count=zero)
@@ -175,10 +178,20 @@ def update_fn(
inplace: bool = True,
) -> tuple[Updates, OptState]:
mu = update_moment.impl( # type: ignore[attr-defined]
- updates, state.mu, b1, order=1, inplace=inplace, already_flattened=already_flattened
+ updates,
+ state.mu,
+ b1,
+ order=1,
+ inplace=inplace,
+ already_flattened=already_flattened,
)
nu = update_moment.impl( # type: ignore[attr-defined]
- updates, state.nu, b2, order=2, inplace=inplace, already_flattened=already_flattened
+ updates,
+ state.nu,
+ b2,
+ order=2,
+ inplace=inplace,
+ already_flattened=already_flattened,
)
# pylint: disable=line-too-long
count_inc = inc_count.impl(updates, state.count, already_flattened=already_flattened) # type: ignore[attr-defined]
@@ -187,12 +200,20 @@ def update_fn(
if inplace:
- def f(g, m, v): # pylint: disable=unused-argument
+ def f(
+ g: torch.Tensor, # pylint: disable=unused-argument
+ m: torch.Tensor,
+ v: torch.Tensor,
+ ) -> torch.Tensor:
return m.div_(v.add_(eps_root).sqrt_().add(eps))
else:
- def f(g, m, v): # pylint: disable=unused-argument
+ def f(
+ g: torch.Tensor, # pylint: disable=unused-argument
+ m: torch.Tensor,
+ v: torch.Tensor,
+ ) -> torch.Tensor:
return m.div(v.add(eps_root).sqrt_().add(eps))
updates = tree_map(f, updates, mu_hat, nu_hat)
@@ -217,7 +238,7 @@ def scale_by_accelerated_adam(
This function is accelerated by using some fused accelerated operators.
References:
- [Kingma et al, 2014](https://arxiv.org/abs/1412.6980)
+ - Kingma et al., 2014: https://arxiv.org/abs/1412.6980
Args:
b1 (float, optional): Decay rate for the exponentially weighted average of grads.
@@ -272,7 +293,7 @@ def _scale_by_accelerated_adam(
already_flattened: bool = False,
) -> GradientTransformation:
# pylint: disable=unneeded-not
- if not 0.0 <= eps: # pragma: no cover
+ if not eps >= 0.0: # pragma: no cover
raise ValueError(f'Invalid epsilon value: {eps}')
if not 0.0 <= b1 < 1.0: # pragma: no cover
raise ValueError(f'Invalid beta parameter at index 0: {b1}')
@@ -293,9 +314,40 @@ def update_fn(
count_inc = inc_count.impl(updates, state.count, already_flattened=True) # type: ignore[attr-defined]
op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace)
- out = tree_map_flat(op, state.mu, state.nu, updates, count_inc)
- new_mu, new_nu, new_updates = tuple(zip(*out)) # transpose
+ def op_fn(
+ mu: torch.Tensor | None,
+ nu: torch.Tensor | None,
+ update: torch.Tensor | None,
+ count: torch.Tensor | None,
+ ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
+ if mu is None:
+ return (None, None, None)
+ return op(mu, nu, update, count) # type: ignore[arg-type]
+
+ out = tree_map_flat(
+ op_fn,
+ state.mu,
+ state.nu,
+ updates,
+ count_inc,
+ none_is_leaf=True,
+ )
+
+ if len(out) == 0:
+ new_mu, new_nu, new_updates = (), (), ()
+ else:
+ new_mu, new_nu, new_updates = tuple(zip(*out)) # transpose
+
+ new_mu, new_nu, new_updates = (
+ new if type(new) is type(old) else type(old)(new)
+ for new, old in (
+ (new_mu, state.mu),
+ (new_nu, state.nu),
+ (new_updates, updates),
+ )
+ )
+
return new_updates, ScaleByAdamState(mu=new_mu, nu=new_nu, count=count_inc)
else:
@@ -310,26 +362,57 @@ def update_fn(
) -> tuple[Updates, OptState]:
count_inc = inc_count.impl(updates, state.count, already_flattened=False) # type: ignore[attr-defined]
- treespec = pytree.tree_structure(updates, none_is_leaf=True)
-
- op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace)
- out = pytree.tree_map(op, state.mu, state.nu, updates, count_inc)
-
new_mu: Updates
new_nu: Updates
new_updates: Updates
- new_mu, new_nu, new_updates = pytree.tree_transpose(treespec, TRIPLE_PYTREE_SPEC, out) # type: ignore[misc]
+
+ treespec = pytree.tree_structure(updates, none_is_leaf=True)
+ if treespec.num_leaves > 0:
+ op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace)
+
+ def op_fn(
+ mu: torch.Tensor | None,
+ nu: torch.Tensor | None,
+ update: torch.Tensor | None,
+ count: torch.Tensor | None,
+ ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
+ if mu is None:
+ return (None, None, None)
+ return op(mu, nu, update, count) # type: ignore[arg-type]
+
+ out = pytree.tree_map(
+ op_fn,
+ state.mu,
+ state.nu,
+ updates,
+ count_inc,
+ none_is_leaf=True,
+ )
+
+ new_mu, new_nu, new_updates = pytree.tree_transpose( # type: ignore[misc]
+ treespec,
+ TRIPLE_PYTREE_SPEC,
+ out,
+ )
+ else:
+ new_mu = pytree.tree_unflatten(treespec, ())
+ new_nu = pytree.tree_unflatten(treespec, ())
+ new_updates = pytree.tree_unflatten(treespec, ())
+
return new_updates, ScaleByAdamState(mu=new_mu, nu=new_nu, count=count_inc)
def init_fn(params: Params) -> OptState:
zero = tree_map( # count init
- lambda t: torch.zeros(1, dtype=torch.int64, device=t.device).squeeze_(), params
+ lambda t: torch.zeros(1, dtype=torch.int64, device=t.device).squeeze_(),
+ params,
)
mu = tree_map( # first moment
- lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params
+ lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad),
+ params,
)
nu = tree_map( # second moment
- lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params
+ lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad),
+ params,
)
return ScaleByAdamState(mu=mu, nu=nu, count=zero)
diff --git a/torchopt/transform/scale_by_rms.py b/torchopt/transform/scale_by_rms.py
index 7a0c8c20..ac2fef16 100644
--- a/torchopt/transform/scale_by_rms.py
+++ b/torchopt/transform/scale_by_rms.py
@@ -60,7 +60,7 @@ def scale_by_rms(
"""Rescale updates by the root of the exp. moving avg of the square.
References:
- [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
+ - Tieleman and Hinton, 2012: http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf
Args:
alpha (float, optional): Decay rate for the exponentially weighted average of squared grads.
@@ -101,9 +101,9 @@ def _scale_by_rms(
already_flattened: bool = False,
) -> GradientTransformation:
# pylint: disable=unneeded-not
- if not 0.0 <= alpha: # pragma: no cover
+ if not alpha >= 0.0: # pragma: no cover
raise ValueError(f'Invalid alpha value: {alpha}')
- if not 0.0 <= eps: # pragma: no cover
+ if not eps >= 0.0: # pragma: no cover
raise ValueError(f'Invalid epsilon value: {eps}')
# pylint: enable=unneeded-not
@@ -126,19 +126,24 @@ def update_fn(
inplace: bool = True,
) -> tuple[Updates, OptState]:
nu = update_moment.impl( # type: ignore[attr-defined]
- updates, state.nu, alpha, order=2, inplace=inplace, already_flattened=already_flattened
+ updates,
+ state.nu,
+ alpha,
+ order=2,
+ inplace=inplace,
+ already_flattened=already_flattened,
)
if inplace:
- def f(g, n): # pylint: disable=invalid-name
+ def f(g: torch.Tensor, n: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return g.div_(n.sqrt().add_(eps))
updates = tree_map_(f, updates, nu)
else:
- def f(g, n): # pylint: disable=invalid-name
+ def f(g: torch.Tensor, n: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return g.div(n.sqrt().add(eps))
updates = tree_map(f, updates, nu)
diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py
new file mode 100644
index 00000000..68021e5e
--- /dev/null
+++ b/torchopt/transform/scale_by_rss.py
@@ -0,0 +1,154 @@
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# This file is modified from:
+# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py
+# ==============================================================================
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Preset transformations for scaling updates by the root of the sum of all squared gradients."""
+
+from __future__ import annotations
+
+from typing import NamedTuple
+
+import torch
+
+from torchopt import pytree
+from torchopt.base import GradientTransformation
+from torchopt.transform.utils import tree_map_flat, update_moment
+from torchopt.typing import OptState, Params, Updates
+
+
+__all__ = ['scale_by_rss']
+
+
+class ScaleByRssState(NamedTuple):
+ """State holding the sum of gradient squares to date."""
+
+ sum_of_squares: Updates
+
+
+def scale_by_rss(
+ initial_accumulator_value: float = 0.0,
+ eps: float = 1e-10,
+) -> GradientTransformation:
+ """Rescale updates by the root of the sum of all squared gradients to date.
+
+ References:
+ - Duchi et al., 2011: https://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf
+ - McMahan et al., 2010: https://arxiv.org/abs/1002.4908
+
+ Args:
+ initial_accumulator_value (float, optional): Starting value for accumulators, must be
+ ``>= 0``. (default: :const:`0.0`)
+ eps (float, optional): A small floating point value to avoid zero denominator.
+ (default: :const:`1e-10`)
+
+ Returns:
+ An (init_fn, update_fn) tuple.
+ """
+ return _scale_by_rss(
+ initial_accumulator_value=initial_accumulator_value,
+ eps=eps,
+ already_flattened=False,
+ )
+
+
+def _scale_by_rss_flat(
+ initial_accumulator_value: float = 0.0,
+ eps: float = 1e-10,
+) -> GradientTransformation:
+ return _scale_by_rss(
+ initial_accumulator_value=initial_accumulator_value,
+ eps=eps,
+ already_flattened=True,
+ )
+
+
+def _scale_by_rss(
+ initial_accumulator_value: float = 0.0,
+ eps: float = 1e-10,
+ *,
+ already_flattened: bool = False,
+) -> GradientTransformation:
+ if already_flattened: # noqa: SIM108
+ tree_map = tree_map_flat
+ else:
+ tree_map = pytree.tree_map # type: ignore[assignment]
+
+ def init_fn(params: Params) -> OptState:
+ sum_of_squares = tree_map(
+ lambda t: torch.full_like(
+ t,
+ initial_accumulator_value,
+ memory_format=torch.preserve_format,
+ ),
+ params,
+ )
+ return ScaleByRssState(sum_of_squares=sum_of_squares)
+
+ def update_fn(
+ updates: Updates,
+ state: OptState,
+ params: Params | None = None, # pylint: disable=unused-argument
+ inplace: bool = True,
+ ) -> tuple[Updates, OptState]:
+ sum_of_squares = update_moment.impl( # type: ignore[attr-defined]
+ updates,
+ state.sum_of_squares,
+ decay=1.0,
+ order=2,
+ inplace=inplace,
+ already_flattened=already_flattened,
+ )
+
+ if inplace:
+
+ def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor:
+ return torch.where(
+ sos > 0.0,
+ g.div_(sos.sqrt().add_(eps)),
+ 0.0,
+ )
+
+ else:
+
+ def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor:
+ return torch.where(
+ sos > 0.0,
+ g.div(sos.sqrt().add(eps)),
+ 0.0,
+ )
+
+ updates = tree_map(f, updates, sum_of_squares)
+ return updates, ScaleByRssState(sum_of_squares=sum_of_squares)
+
+ return GradientTransformation(init_fn, update_fn)
+
+
+scale_by_rss.flat = _scale_by_rss_flat # type: ignore[attr-defined]
+scale_by_rss.impl = _scale_by_rss # type: ignore[attr-defined]
diff --git a/torchopt/transform/scale_by_schedule.py b/torchopt/transform/scale_by_schedule.py
index d6e3b0fa..f27fb7e8 100644
--- a/torchopt/transform/scale_by_schedule.py
+++ b/torchopt/transform/scale_by_schedule.py
@@ -40,7 +40,7 @@
from torchopt import pytree
from torchopt.base import GradientTransformation
from torchopt.transform.utils import inc_count, tree_map_flat, tree_map_flat_
-from torchopt.typing import OptState, Params, Schedule, SequenceOfTensors, Updates
+from torchopt.typing import Numeric, OptState, Params, Schedule, SequenceOfTensors, Updates
__all__ = ['scale_by_schedule']
@@ -83,7 +83,8 @@ def _scale_by_schedule(
def init_fn(params: Params) -> OptState:
zero = tree_map( # count init
- lambda t: torch.zeros(1, dtype=torch.int64, device=t.device).squeeze_(), params
+ lambda t: torch.zeros(1, dtype=torch.int64, device=t.device).squeeze_(),
+ params,
)
return ScaleByScheduleState(count=zero)
@@ -96,7 +97,7 @@ def update_fn(
) -> tuple[Updates, OptState]:
if inplace:
- def f(g, c): # pylint: disable=invalid-name
+ def f(g: torch.Tensor, c: Numeric) -> torch.Tensor: # pylint: disable=invalid-name
step_size = step_size_fn(c)
return g.mul_(step_size)
@@ -104,7 +105,7 @@ def f(g, c): # pylint: disable=invalid-name
else:
- def f(g, c): # pylint: disable=invalid-name
+ def f(g: torch.Tensor, c: Numeric) -> torch.Tensor: # pylint: disable=invalid-name
step_size = step_size_fn(c)
return g.mul(step_size)
@@ -117,7 +118,7 @@ def f(g, c): # pylint: disable=invalid-name
updates,
state.count,
already_flattened=already_flattened,
- )
+ ),
),
)
diff --git a/torchopt/transform/scale_by_stddev.py b/torchopt/transform/scale_by_stddev.py
index 228ed707..bbbfb384 100644
--- a/torchopt/transform/scale_by_stddev.py
+++ b/torchopt/transform/scale_by_stddev.py
@@ -63,7 +63,7 @@ def scale_by_stddev(
"""Rescale updates by the root of the centered exponential moving average of squares.
References:
- [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
+ - Tieleman and Hinton, 2012: http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf
Args:
alpha (float, optional): Decay rate for the exponentially weighted average of squared grads.
@@ -104,9 +104,9 @@ def _scale_by_stddev(
already_flattened: bool = False,
) -> GradientTransformation:
# pylint: disable=unneeded-not
- if not 0.0 <= alpha: # pragma: no cover
+ if not alpha >= 0.0: # pragma: no cover
raise ValueError(f'Invalid alpha value: {alpha}')
- if not 0.0 <= eps: # pragma: no cover
+ if not eps >= 0.0: # pragma: no cover
raise ValueError(f'Invalid epsilon value: {eps}')
# pylint: enable=unneeded-not
@@ -130,22 +130,32 @@ def update_fn(
inplace: bool = True,
) -> tuple[Updates, OptState]:
mu = update_moment.impl( # type: ignore[attr-defined]
- updates, state.mu, alpha, order=1, inplace=inplace, already_flattened=already_flattened
+ updates,
+ state.mu,
+ alpha,
+ order=1,
+ inplace=inplace,
+ already_flattened=already_flattened,
)
nu = update_moment.impl( # type: ignore[attr-defined]
- updates, state.nu, alpha, order=2, inplace=inplace, already_flattened=already_flattened
+ updates,
+ state.nu,
+ alpha,
+ order=2,
+ inplace=inplace,
+ already_flattened=already_flattened,
)
if inplace:
- def f(g, m, n):
+ def f(g: torch.Tensor, m: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
return g.div_(n.addcmul(m, m, value=-1.0).sqrt_().add(eps))
updates = tree_map_(f, updates, mu, nu)
else:
- def f(g, m, n):
+ def f(g: torch.Tensor, m: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
return g.div(n.addcmul(m, m, value=-1.0).sqrt_().add(eps))
updates = tree_map(f, updates, mu, nu)
diff --git a/torchopt/transform/trace.py b/torchopt/transform/trace.py
index 03d2441d..7a1e1971 100644
--- a/torchopt/transform/trace.py
+++ b/torchopt/transform/trace.py
@@ -110,7 +110,7 @@ def _trace(
already_flattened: bool = False,
) -> GradientTransformation:
# pylint: disable=unneeded-not
- if not 0.0 <= momentum: # pragma: no cover
+ if not momentum >= 0.0: # pragma: no cover
raise ValueError(f'Invalid momentum value: {momentum}')
if nesterov and (momentum <= 0.0 or dampening != 0.0): # pragma: no cover
raise ValueError('Nesterov momentum requires a momentum and zero dampening')
@@ -129,8 +129,9 @@ def _trace(
def init_fn(params: Params) -> OptState:
return TraceState(
trace=tree_map(
- lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params
- )
+ lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad),
+ params,
+ ),
)
first_call = True
@@ -147,12 +148,12 @@ def update_fn(
if nesterov:
if inplace:
- def f1(g, t):
+ def f1(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
if first_call:
return t.add_(g)
return t.mul_(momentum).add_(g)
- def f2(g, t):
+ def f2(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
return g.add_(t, alpha=momentum)
new_trace = tree_map(f1, updates, state.trace)
@@ -160,12 +161,12 @@ def f2(g, t):
else:
- def f1(g, t):
+ def f1(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
if first_call:
return t.add(g)
return t.mul(momentum).add_(g)
- def f2(g, t):
+ def f2(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
return g.add(t, alpha=momentum)
new_trace = tree_map(f1, updates, state.trace)
@@ -174,12 +175,12 @@ def f2(g, t):
else:
if inplace:
- def f(g, t):
+ def f(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
if first_call:
return t.add_(g)
return t.mul_(momentum).add_(g, alpha=1.0 - dampening)
- def copy_(g, t):
+ def copy_(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
return g.copy_(t)
new_trace = tree_map(f, updates, state.trace)
@@ -187,7 +188,7 @@ def copy_(g, t):
else:
- def f(g, t):
+ def f(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
if first_call:
return t.add(g)
return t.mul(momentum).add_(g, alpha=1.0 - dampening)
diff --git a/torchopt/transform/utils.py b/torchopt/transform/utils.py
index 77ba58ca..8c67fd7e 100644
--- a/torchopt/transform/utils.py
+++ b/torchopt/transform/utils.py
@@ -59,7 +59,7 @@ def tree_map_flat(
fn = func
else:
- def fn(x, *xs):
+ def fn(x: Any | None, *xs: Any) -> Any | None:
return func(x, *xs) if x is not None else None
return flat_arg.__class__(map(fn, flat_arg, *flat_args)) # type: ignore[call-arg]
@@ -76,7 +76,7 @@ def tree_map_flat_(
fn = func
else:
- def fn(x, *xs):
+ def fn(x: Any | None, *xs: Any) -> Any | None:
return func(x, *xs) if x is not None else None
flat_results = map(fn, flat_arg, *flat_args)
@@ -111,7 +111,7 @@ def _inc_count(
*,
already_flattened: bool = False,
) -> TensorTree:
- def f(c, g): # pylint: disable=invalid-name
+ def f(c: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor: # pylint: disable=invalid-name
return c + (c != INT64_MAX).to(torch.int64) if g is not None else c
if already_flattened:
@@ -167,31 +167,55 @@ def _update_moment(
*,
order: int,
inplace: bool = True,
- already_flattened=False,
+ already_flattened: bool = False,
) -> TensorTree:
assert order in (1, 2)
if inplace:
if order == 2:
+ if decay != 1.0:
+
+ def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+ return t.mul_(decay).addcmul_(g, g, value=1 - decay) if g is not None else t
+
+ else:
- def f(g, t):
- return t.mul_(decay).addcmul_(g, g, value=1 - decay) if g is not None else t
+ def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+ return t.addcmul_(g, g) if g is not None else t
else:
+ if decay != 1.0:
+
+ def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+ return t.mul_(decay).add_(g, alpha=1 - decay) if g is not None else t
+
+ else:
- def f(g, t):
- return t.mul_(decay).add_(g, alpha=1 - decay) if g is not None else t
+ def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+ return t.add_(g) if g is not None else t
else:
if order == 2:
+ if decay != 1.0:
- def f(g, t):
- return t.mul(decay).addcmul_(g, g, value=1 - decay) if g is not None else t
+ def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+ return t.mul(decay).addcmul_(g, g, value=1 - decay) if g is not None else t
+
+ else:
+
+ def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+ return t.addcmul(g, g) if g is not None else t
else:
+ if decay != 1.0:
+
+ def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+ return t.mul(decay).add_(g, alpha=1 - decay) if g is not None else t
+
+ else:
- def f(g, t):
- return t.mul(decay).add_(g, alpha=1 - decay) if g is not None else t
+ def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+ return t.add(g) if g is not None else t
if already_flattened:
return tree_map_flat(f, updates, moments, none_is_leaf=True)
diff --git a/torchopt/typing.py b/torchopt/typing.py
index 2075dc62..510cb693 100644
--- a/torchopt/typing.py
+++ b/torchopt/typing.py
@@ -15,9 +15,19 @@
"""Typing utilities."""
import abc
-from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
+from typing import (
+ Callable,
+ Dict,
+ List,
+ Optional,
+ Protocol,
+ Sequence,
+ Tuple,
+ TypeVar,
+ Union,
+ runtime_checkable,
+)
from typing_extensions import TypeAlias # Python 3.10+
-from typing_extensions import Protocol, runtime_checkable # Python 3.8+
import torch
import torch.distributed.rpc as rpc
@@ -126,11 +136,12 @@ class Samplable(Protocol): # pylint: disable=too-few-public-methods
@abc.abstractmethod
def sample(
- self, sample_shape: Size = Size() # pylint: disable=unused-argument
+ self,
+ sample_shape: Size = Size(), # noqa: B008 # pylint: disable=unused-argument
) -> Union[Tensor, Sequence[Numeric]]:
# pylint: disable-next=line-too-long
"""Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched."""
- raise NotImplementedError # pragma: no cover
+ raise NotImplementedError
Samplable.register(Distribution)
diff --git a/torchopt/update.py b/torchopt/update.py
index 9485896b..3a2a6984 100644
--- a/torchopt/update.py
+++ b/torchopt/update.py
@@ -31,6 +31,10 @@
# ==============================================================================
"""Helper functions for applying updates."""
+from __future__ import annotations
+
+import torch
+
from torchopt import pytree
from torchopt.typing import Params, Updates
@@ -59,14 +63,14 @@ def apply_updates(params: Params, updates: Updates, *, inplace: bool = True) ->
"""
if inplace:
- def f(p, u):
+ def f(p: torch.Tensor, u: torch.Tensor | None) -> torch.Tensor:
if u is not None:
p.data.add_(u)
return p
else:
- def f(p, u):
+ def f(p: torch.Tensor, u: torch.Tensor | None) -> torch.Tensor:
return p.add(u) if u is not None else p
return pytree.tree_map(f, params, updates)
diff --git a/torchopt/utils.py b/torchopt/utils.py
index 12adb214..69bda9ac 100644
--- a/torchopt/utils.py
+++ b/torchopt/utils.py
@@ -18,8 +18,7 @@
import copy
import itertools
-from typing import TYPE_CHECKING, NamedTuple, Sequence, cast, overload
-from typing_extensions import Literal # Python 3.8+
+from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Sequence, cast, overload
from typing_extensions import TypeAlias # Python 3.10+
import torch
@@ -29,7 +28,7 @@
from torchopt.typing import Device, ModuleTensorContainers, OptState, TensorContainer, TensorTree
-if TYPE_CHECKING: # pragma: no cover
+if TYPE_CHECKING:
from torchopt.optim.meta.base import MetaOptimizer
@@ -46,8 +45,8 @@
class ModuleState(NamedTuple):
"""Container for module state."""
- params: tuple[dict[str, torch.Tensor], ...]
- buffers: tuple[dict[str, torch.Tensor], ...]
+ params: tuple[TensorContainer, ...]
+ buffers: tuple[TensorContainer, ...]
visual_contents: dict | None = None
detach_buffers: bool = False
@@ -74,7 +73,7 @@ def stop_gradient(target: ModuleState | nn.Module | MetaOptimizer | TensorTree)
# pylint: disable-next=import-outside-toplevel
from torchopt.optim.meta.base import MetaOptimizer
- def fn_(obj):
+ def fn_(obj: Any) -> None:
if isinstance(obj, torch.Tensor):
requires_grad = obj.requires_grad
obj.detach_().requires_grad_(requires_grad)
@@ -98,6 +97,7 @@ def extract_state_dict(
by: CopyMode = 'reference',
device: Device | None = None,
with_buffers: bool = True,
+ detach_buffers: bool = False,
enable_visual: bool = False,
visual_prefix: str = '',
) -> ModuleState: # pragma: no cover
@@ -110,9 +110,6 @@ def extract_state_dict(
*,
by: CopyMode = 'reference',
device: Device | None = None,
- with_buffers: bool = True,
- enable_visual: bool = False,
- visual_prefix: str = '',
) -> tuple[OptState, ...]: # pragma: no cover
...
@@ -185,7 +182,8 @@ def clone(t: torch.Tensor) -> torch.Tensor:
def clone_detach_(t: torch.Tensor) -> torch.Tensor:
if isinstance(t, nn.Parameter):
return nn.Parameter(
- t.clone().to(device=target_device).detach_(), requires_grad=t.requires_grad
+ t.clone().to(device=target_device).detach_(),
+ requires_grad=t.requires_grad,
)
return t.clone().to(device=target_device).detach_().requires_grad_(t.requires_grad)
@@ -221,38 +219,38 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor:
else:
visual_contents = None
- params: list[dict[str, torch.Tensor]] = []
- buffers: list[dict[str, torch.Tensor]] = []
+ params: list[TensorContainer] = []
+ buffers: list[TensorContainer] = []
memo: set[nn.Module] = set()
- def update_params(container):
+ def update_params(container: TensorContainer) -> None:
if len(container) > 0:
params.append(
type(container)(
(k, replicate(v))
for k, v in container.items()
if isinstance(v, torch.Tensor)
- )
+ ),
)
- def update_buffers(container):
+ def update_buffers(container: TensorContainer) -> None:
if len(container) > 0:
fn = clone_detach_ if detach_buffers else replicate
buffers.append(
type(container)(
(k, fn(v)) for k, v in container.items() if isinstance(v, torch.Tensor)
- )
+ ),
)
# pylint: disable=protected-access
- update_params(target._parameters)
+ update_params(target._parameters) # type: ignore[arg-type]
if with_buffers:
update_buffers(target._buffers)
memo.add(target)
for submodule in target.modules():
if submodule in memo:
continue
- update_params(submodule._parameters)
+ update_params(submodule._parameters) # type: ignore[arg-type]
if with_buffers:
update_buffers(submodule._buffers)
memo.add(submodule)
@@ -264,10 +262,10 @@ def update_buffers(container):
detach_buffers=detach_buffers,
)
- elif isinstance(target, MetaOptimizer):
+ if isinstance(target, MetaOptimizer):
state = target.state_dict()
- def get_variable(t):
+ def get_variable(t: torch.Tensor | None) -> torch.Tensor | None:
if isinstance(t, torch.Tensor):
return replicate(t)
return t
@@ -279,7 +277,8 @@ def get_variable(t):
def extract_module_containers(
- module: nn.Module, with_buffers: bool = True
+ module: nn.Module,
+ with_buffers: bool = True,
) -> tuple[ModuleTensorContainers, ModuleTensorContainers]:
"""Extract the references to the containers of parameters and buffers from a module."""
if isinstance(module, nn.Module):
@@ -287,19 +286,19 @@ def extract_module_containers(
buffers: list[TensorContainer] = []
memo: set[nn.Module] = set()
- def update_container(container, items):
+ def update_container(container: list[TensorContainer], items: TensorContainer) -> None:
if len(items) > 0:
container.append(items) # we need references to original dictionaries
# pylint: disable=protected-access
- update_container(params, module._parameters)
+ update_container(params, module._parameters) # type: ignore[arg-type]
if with_buffers:
update_container(buffers, module._buffers)
memo.add(module)
for submodule in module.modules():
if submodule in memo:
continue
- update_container(params, submodule._parameters)
+ update_container(params, submodule._parameters) # type: ignore[arg-type]
if with_buffers:
update_container(buffers, submodule._buffers)
memo.add(submodule)
@@ -453,7 +452,8 @@ def clone(t: torch.Tensor) -> torch.Tensor:
def clone_detach_(t: torch.Tensor) -> torch.Tensor:
if isinstance(t, nn.Parameter):
return nn.Parameter(
- t.clone().to(device=target_device).detach_(), requires_grad=t.requires_grad
+ t.clone().to(device=target_device).detach_(),
+ requires_grad=t.requires_grad,
)
return t.clone().to(device=target_device).detach_().requires_grad_(t.requires_grad)
diff --git a/torchopt/version.py b/torchopt/version.py
index b8136a22..4b091d8a 100644
--- a/torchopt/version.py
+++ b/torchopt/version.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""TorchOpt: a high-performance optimizer library built upon PyTorch."""
-__version__ = '0.7.0'
+__version__ = '0.7.1'
__license__ = 'Apache License, Version 2.0'
__author__ = 'TorchOpt Contributors'
__release__ = False
@@ -26,7 +26,7 @@
try:
prefix, sep, suffix = (
subprocess.check_output(
- ['git', 'describe', '--abbrev=7'],
+ ['git', 'describe', '--abbrev=7'], # noqa: S603,S607
cwd=os.path.dirname(os.path.abspath(__file__)),
stderr=subprocess.DEVNULL,
text=True,
diff --git a/torchopt/visual.py b/torchopt/visual.py
index 7afe65a4..493ffbab 100644
--- a/torchopt/visual.py
+++ b/torchopt/visual.py
@@ -20,12 +20,13 @@
from __future__ import annotations
from collections import namedtuple
-from typing import Generator, Iterable, Mapping, cast
+from typing import Any, Generator, Iterable, Mapping, cast
import torch
from graphviz import Digraph
-from torchopt.typing import TensorOrTensors
+from torchopt import pytree
+from torchopt.typing import TensorTree
from torchopt.utils import ModuleState
@@ -38,7 +39,7 @@
SAVED_PREFIX = '_saved_'
-def get_fn_name(fn, show_attrs, max_attr_chars):
+def get_fn_name(fn: Any, show_attrs: bool, max_attr_chars: int) -> str:
"""Return function name."""
name = str(type(fn).__name__)
if not show_attrs:
@@ -63,7 +64,7 @@ def get_fn_name(fn, show_attrs, max_attr_chars):
sep = '-' * max(col1width + col2width + 2, len(name))
attrstr = '%-' + str(col1width) + 's: %' + str(col2width) + 's'
- def truncate(s): # pylint: disable=invalid-name
+ def truncate(s: str) -> str: # pylint: disable=invalid-name
return s[: col2width - 3] + '...' if len(s) > col2width else s
params = '\n'.join(attrstr % (k, truncate(str(v))) for (k, v) in attrs.items())
@@ -72,7 +73,7 @@ def truncate(s): # pylint: disable=invalid-name
# pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
def make_dot(
- var: TensorOrTensors,
+ var: TensorTree,
params: (
Mapping[str, torch.Tensor]
| ModuleState
@@ -142,20 +143,20 @@ def make_dot(
dot = Digraph(node_attr=node_attr, graph_attr={'size': '12,12'})
seen = set()
- def size_to_str(size):
+ def size_to_str(size: tuple[int, ...]) -> str:
return '(' + (', ').join(map(str, size)) + ')'
- def get_var_name(var, name=None):
+ def get_var_name(var: torch.Tensor, name: str | None = None) -> str:
if not name:
name = param_map[var] if var in param_map else ''
return f'{name}\n{size_to_str(var.size())}'
- def get_var_name_with_flag(var):
+ def get_var_name_with_flag(var: torch.Tensor) -> str | None:
if var in param_map:
return f'{param_map[var][0]}\n{size_to_str(param_map[var][1].size())}'
return None
- def add_nodes(fn): # pylint: disable=too-many-branches
+ def add_nodes(fn: Any) -> None: # pylint: disable=too-many-branches
assert not isinstance(fn, torch.Tensor)
if fn in seen:
return
@@ -210,7 +211,10 @@ def add_nodes(fn): # pylint: disable=too-many-branches
dot.edge(str(id(t)), str(id(fn)))
dot.node(str(id(t)), get_var_name(t), fillcolor='orange')
- def add_base_tensor(v, color='darkolivegreen1'): # pylint: disable=invalid-name
+ def add_base_tensor(
+ v: torch.Tensor, # pylint: disable=invalid-name
+ color: str = 'darkolivegreen1',
+ ) -> None:
if v in seen:
return
seen.add(v)
@@ -220,15 +224,11 @@ def add_base_tensor(v, color='darkolivegreen1'): # pylint: disable=invalid-name
dot.edge(str(id(v.grad_fn)), str(id(v)))
# pylint: disable=protected-access
if v._is_view():
- add_base_tensor(v._base, color='darkolivegreen3')
+ add_base_tensor(v._base, color='darkolivegreen3') # type: ignore[arg-type]
dot.edge(str(id(v._base)), str(id(v)), style='dotted')
# handle multiple outputs
- if isinstance(var, (tuple, list)):
- for v in var: # pylint: disable=invalid-name
- add_base_tensor(v)
- else:
- add_base_tensor(var)
+ pytree.tree_map_(add_base_tensor, var)
resize_graph(dot)
pFad - Phonifier reborn
Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies:
Alternative Proxy
pFad Proxy
pFad v3 Proxy
pFad v4 Proxy