diff --git a/.editorconfig b/.editorconfig index 96ef7342..3ae9f69a 100644 --- a/.editorconfig +++ b/.editorconfig @@ -14,7 +14,7 @@ insert_final_newline = true indent_size = 4 src_paths=torchopt,tests,examples -[*.{yaml,yml}] +[*.{yaml,yml,json}] indent_size = 2 [*.md] @@ -25,8 +25,18 @@ x-soft-wrap-text = true indent_size = 4 x-soft-wrap-text = true +[*.{bib,tex}] +indent_size = 2 + [Makefile] indent_style = tab +[*.sh] +indent_style = tab + +[*.bat] +end_of_line = crlf +indent_style = tab + [*.{cpp,h,cu,cuh}] indent_size = 2 diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..82919783 --- /dev/null +++ b/.flake8 @@ -0,0 +1,41 @@ +[flake8] +max-line-length = 120 +max-doc-length = 100 +select = B,C,E,F,W,Y,SIM +ignore = + # E203: whitespace before ':' + # W503: line break before binary operator + # W504: line break after binary operator + # format by black + E203,W503,W504, + # E501: line too long + # W505: doc line too long + # too long docstring due to long example blocks + E501,W505, +per-file-ignores = + # F401: module imported but unused + # intentionally unused imports + __init__.py: F401 + # F401: module imported but unused + # F403: unable to detect undefined names + # F405: member mey be undefined, or defined from star imports + # members populated from optree + torchopt/pytree.py: F401,F403,F405 + # E301: expected 1 blank line + # E302: expected 2 blank lines + # E305: expected 2 blank lines after class or function definition + # E701: multiple statements on one line (colon) + # E704: multiple statements on one line (def) + # format by black + *.pyi: E301,E302,E305,E701,E704 +exclude = + .git, + .vscode, + venv, + third-party, + __pycache__, + docs/source/conf.py, + build, + dist, + examples, + tests diff --git a/.gitattributes b/.gitattributes index a894e29e..1d0afc65 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1,8 @@ +* text eol=lf *.ipynb linguist-detectable=false + +*.png binary +*.jpg binary +*.jpeg binary +*.gif binary +*.pdf binary diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml new file mode 100644 index 00000000..6d381b28 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -0,0 +1,122 @@ +name: 🐛 Bug Report +description: File an issue about a bug. +title: "[BUG] " +labels: [bug] +assignees: [Benjamin-eecs] +body: + - type: markdown + attributes: + value: | + Please do your best to make the issue as easy to act on as possible, and only submit here if there is clearly a problem with TorchOpt (ask in [Discussions](https://github.com/metaopt/torchopt/discussions) first if unsure). + + - type: checkboxes + id: steps + attributes: + label: Required prerequisites + description: Make sure you've completed the following steps before submitting your issue -- thank you! + options: + - label: I have read the documentation . + required: true + - label: I have searched the [Issue Tracker](https://github.com/metaopt/torchopt/issues) and [Discussions](https://github.com/metaopt/torchopt/discussions) that this hasn't already been reported. (+1 or comment there if it has.) + required: true + - label: Consider asking first in a [Discussion](https://github.com/metaopt/torchopt/discussions/new). + required: false + + - type: input + id: version + attributes: + label: What version of TorchOpt are you using? + description: Run command `python3 -c 'print(__import__("torchopt").__version__)'` in your shell and paste the output here. + placeholder: E.g., 0.6.0 + validations: + required: true + + - type: textarea + id: system-info + attributes: + label: System information + description: | + Describe the characteristic of your environment: + + - Describe how the library was installed (pip, conda, source, ...) + - Python version + - Versions of any other relevant libraries + + ```python + import sys, torch, functorch, torchopt + print(sys.version, sys.platform) + print(torchopt.__version__, torch.__version__, functorch.__version__) + ``` + validations: + required: true + + - type: textarea + id: description + attributes: + label: Problem description + description: >- + Provide a short description, state the expected behavior and what actually happens. Include + relevant information like what version of TorchOpt you are using, what system you are on, + and any useful commands / output. + validations: + required: true + + - type: textarea + id: code + attributes: + label: Reproducible example code + description: >- + The code should be minimal, have minimal external dependencies, and isolate the functions + that cause breakage. Submit matched and complete snippets that can be easily run to diagnose + the issue. + value: | + The Python snippets: + + ```python + + ``` + + Command lines: + + ```bash + + ``` + + Extra dependencies: + + ```text + + ``` + + Steps to reproduce: + + 1. + 2. + 3. + validations: + required: true + + - type: textarea + id: traceback + attributes: + label: Traceback + description: Put the Python traceback information here. + placeholder: | + Traceback (most recent call last): + File ... + render: pytb + + - type: textarea + id: expected + attributes: + label: Expected behavior + description: Provide a clear and concise description of what you expected to happen. + + - type: textarea + id: additional-context + attributes: + label: Additional context + description: >- + Add any other context about the problem here. Screenshots may also be helpful. + + If you know or suspect the reason for this bug, paste the code lines and suggest modifications. diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md deleted file mode 100644 index 86dcfbcb..00000000 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ /dev/null @@ -1,64 +0,0 @@ ---- -name: Bug report -about: Create a report to help us improve -title: "[BUG]" -labels: ["bug"] -assignees: Benjamin-eecs - ---- - -## Describe the bug - -A clear and concise description of what the bug is. - -## To Reproduce - -Steps to reproduce the behavior. - -Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful. - -Please use the markdown code blocks for both code and stack traces. - -```python -import torchopt -``` - -```pytb -Traceback (most recent call last): - File ... -``` - -## Expected behavior - -A clear and concise description of what you expected to happen. - -## Screenshots - -If applicable, add screenshots to help explain your problem. - -## System info - -Describe the characteristic of your environment: - -- Describe how the library was installed (pip, source, ...) -- Python version -- Versions of any other relevant libraries - -```python -import torchopt, numpy, sys -print(torchopt.__version__, numpy.__version__, sys.version, sys.platform) -``` - -## Additional context - -Add any other context about the problem here. - -## Reason and Possible fixes - -If you know or suspect the reason for this bug, paste the code lines and suggest modifications. - -## Checklist - -- [ ] I have checked that there is no similar issue in the repo (**required**) -- [ ] I have read the [documentation](https://torchopt.readthedocs.io/) (**required**) -- [ ] I have provided a minimal working example to reproduce the bug (**required**) diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 00000000..a3b57cdc --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,5 @@ +blank_issues_enabled: false +contact_links: + - name: 💬 Start a discussion + url: https://github.com/metaopt/torchopt/discussions/new + about: Please ask and answer questions here if unsure. diff --git a/.github/ISSUE_TEMPLATE/feature-request.yml b/.github/ISSUE_TEMPLATE/feature-request.yml new file mode 100644 index 00000000..ee76e770 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature-request.yml @@ -0,0 +1,46 @@ +name: ✨ Feature Request +description: Suggest an idea for this project. +title: "[Feature Request] " +labels: [enhancement] +assignees: [Benjamin-eecs] +body: + - type: checkboxes + id: steps + attributes: + label: Required prerequisites + description: Make sure you've completed the following steps before submitting your issue -- thank you! + options: + - label: I have searched the [Issue Tracker](https://github.com/metaopt/torchopt/issues) and [Discussions](https://github.com/metaopt/torchopt/discussions) that this hasn't already been reported. (+1 or comment there if it has.) + required: true + - label: Consider asking first in a [Discussion](https://github.com/metaopt/torchopt/discussions/new). + required: false + + - type: textarea + id: motivation + attributes: + label: Motivation + description: Outline the motivation for the proposal. + value: | + + validations: + required: true + + - type: textarea + id: solution + attributes: + label: Solution + description: Provide a clear and concise description of what you want to happen. + + - type: textarea + id: alternatives + attributes: + label: Alternatives + description: A clear and concise description of any alternative solutions or features you've considered. + + - type: textarea + id: additional-context + attributes: + label: Additional context + description: Add any other context about the problem here. Screenshots may also be helpful. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md deleted file mode 100644 index b61aa154..00000000 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ /dev/null @@ -1,30 +0,0 @@ ---- -name: Feature request -about: Suggest an idea for this project -title: "[Feature Request]" -labels: ["enhancement"] -assignees: Benjamin-eecs - ---- - -## Motivation - -Please outline the motivation for the proposal. -Is your feature request related to a problem? e.g., "I'm always frustrated when [...]". -If this is related to another issue, please link here too. - -## Solution - -A clear and concise description of what you want to happen. - -## Alternatives - -A clear and concise description of any alternative solutions or features you've considered. - -## Additional context - -Add any other context or screenshots about the feature request here. - -## Checklist - -- [ ] I have checked that there is no similar issue in the repo (**required**) diff --git a/.github/ISSUE_TEMPLATE/questions.yml b/.github/ISSUE_TEMPLATE/questions.yml new file mode 100644 index 00000000..a33c9c3b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/questions.yml @@ -0,0 +1,26 @@ +name: 🤔 Questions / Help / Support +description: Do you need support? +title: "[Question] " +labels: [question] +assignees: [Benjamin-eecs] +body: + - type: checkboxes + id: steps + attributes: + label: Required prerequisites + description: Make sure you've completed the following steps before submitting your issue -- thank you! + options: + - label: I have read the documentation . + required: true + - label: I have searched the [Issue Tracker](https://github.com/metaopt/torchopt/issues) and [Discussions](https://github.com/metaopt/torchopt/discussions) that this hasn't already been reported. (+1 or comment there if it has.) + required: true + - label: Consider asking first in a [Discussion](https://github.com/metaopt/torchopt/discussions/new). + required: false + + - type: textarea + id: questions + attributes: + label: Questions + description: Describe your questions with relevant resources such as snippets, links, images, etc. + validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/questions_help_support.md b/.github/ISSUE_TEMPLATE/questions_help_support.md deleted file mode 100644 index 072d2e52..00000000 --- a/.github/ISSUE_TEMPLATE/questions_help_support.md +++ /dev/null @@ -1,17 +0,0 @@ ---- -name: Questions / Help / Support -about: Do you need support? -title: "[Question]" -labels: "question" -assignees: Benjamin-eecs - ---- - -## Questions - - - -## Checklist - -- [ ] I have checked that there is no similar issue in the repo (**required**) -- [ ] I have read the [documentation](https://torchopt.readthedocs.io/) (**required**) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 807bd4bb..2709e055 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -31,10 +31,10 @@ What types of changes does your code introduce? Put an `x` in all the boxes that Go over all the following points, and put an `x` in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help! -- [ ] I have read the [CONTRIBUTION](https://torchopt.readthedocs.io/en/latest/developer/contributing.html) guide (**required**) +- [ ] I have read the [CONTRIBUTION](https://torchopt.readthedocs.io/en/latest/developer/contributing.html) guide. (**required**) - [ ] My change requires a change to the documentation. -- [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*). +- [ ] I have updated the tests accordingly. (*required for a bug fix or a new feature*) - [ ] I have updated the documentation accordingly. -- [ ] I have reformatted the code using `make format` (**required**) -- [ ] I have checked the code using `make lint` (**required**) +- [ ] I have reformatted the code using `make format`. (**required**) +- [ ] I have checked the code using `make lint`. (**required**) - [ ] I have ensured `make test` pass. (**required**) diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..24937aad --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,13 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + labels: + - dependencies + schedule: + interval: "weekly" + day: "monday" + time: "12:00" + timezone: "Asia/Shanghai" + commit-message: + prefix: "deps(workflows)" diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 72dd012a..99b553e4 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -14,7 +14,7 @@ on: - include/** - src/** - torchopt/version.py - - .github/workflow/build.yml + - .github/workflows/build.yml release: types: - published @@ -37,81 +37,238 @@ concurrency: cancel-in-progress: ${{ github.event_name == 'pull_request' }} env: - CUDA_VERSION: "11.6" - TEST_TORCH_SPECS: "cpu cu113 cu116" + CUDA_VERSION: "12.1" + TEST_TORCH_SPECS: "cpu cu118" jobs: - build-sdist: + build: + name: Build sdist and pure-Python wheel runs-on: ubuntu-latest if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/')) - timeout-minutes: 10 + timeout-minutes: 60 steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: submodules: "recursive" - fetch-depth: 1 + fetch-depth: 0 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: - python-version: "3.7 - 3.10" + python-version: "3.8 - 3.12" # sync with requires-python in pyproject.toml update-environment: true - name: Install dependencies run: python -m pip install --upgrade pip setuptools wheel build - - name: Build sdist - run: python -m build --sdist + - name: Set __release__ + if: | + startsWith(github.ref, 'refs/tags/') || + (github.event_name == 'workflow_dispatch' && github.event.inputs.task == 'build-and-publish') + run: | + python .github/workflows/set_release.py + + - name: Print version + run: python setup.py --version + + - name: Build sdist and pure-Python wheel + run: python -m build + env: + TORCHOPT_NO_EXTENSIONS: "true" - name: Upload artifact - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: - name: sdist - path: dist/*.tar.gz + name: build + path: dist/* + if-no-files-found: error + + - name: Install dependencies + run: | + python -m pip install -r tests/requirements.txt + + - name: Install TorchOpt + run: | + python -m pip install -vvv dist/*.whl + + - name: Test with pytest + run: | + make pytest + + build-wheels-py38: + name: Build wheels for Python ${{ matrix.python-version }} on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + needs: [build] + if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/')) + strategy: + matrix: + os: [ubuntu-latest] + python-version: ["3.8"] # sync with requires-python in pyproject.toml + fail-fast: false + timeout-minutes: 60 + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + submodules: "recursive" + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + update-environment: true + + - name: Install dependencies + run: python -m pip install --upgrade pip setuptools wheel build + + - name: Set __release__ + if: | + startsWith(github.ref, 'refs/tags/') || + (github.event_name == 'workflow_dispatch' && github.event.inputs.task == 'build-and-publish') + run: python .github/workflows/set_release.py + + - name: Print version + run: python setup.py --version + + - name: Set CIBW_BUILD + run: python .github/workflows/set_cibw_build.py + + - name: Build wheels + uses: pypa/cibuildwheel@v2.19 + env: + CIBW_BUILD: ${{ env.CIBW_BUILD }} + with: + package-dir: . + output-dir: wheelhouse + config-file: "{package}/pyproject.toml" + + - uses: actions/upload-artifact@v4 + with: + name: wheels-${{ matrix.python-version }}-${{ matrix.os }} + path: wheelhouse/*.whl if-no-files-found: error build-wheels: - runs-on: ubuntu-latest - needs: [build-sdist] + name: Build wheels for Python ${{ matrix.python-version }} on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + needs: [build, build-wheels-py38] if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/')) - timeout-minutes: 90 + strategy: + matrix: + os: [ubuntu-latest] + python-version: ["3.9", "3.10", "3.11", "3.12"] # sync with requires-python in pyproject.toml + fail-fast: false + timeout-minutes: 60 steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: submodules: "recursive" - fetch-depth: 1 + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + update-environment: true + + - name: Install dependencies + run: python -m pip install --upgrade pip setuptools wheel build + + - name: Set __release__ + if: | + startsWith(github.ref, 'refs/tags/') || + (github.event_name == 'workflow_dispatch' && github.event.inputs.task == 'build-and-publish') + run: python .github/workflows/set_release.py + + - name: Print version + run: python setup.py --version + + - name: Set CIBW_BUILD + run: python .github/workflows/set_cibw_build.py - name: Build wheels - uses: pypa/cibuildwheel@v2.8.1 + uses: pypa/cibuildwheel@v2.19 + env: + CIBW_BUILD: ${{ env.CIBW_BUILD }} with: package-dir: . output-dir: wheelhouse config-file: "{package}/pyproject.toml" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: - name: wheels + name: wheels-${{ matrix.python-version }}-${{ matrix.os }} path: wheelhouse/*.whl if-no-files-found: error + list-artifacts: + name: List artifacts + runs-on: ubuntu-latest + needs: [build, build-wheels-py38, build-wheels] + if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/')) + timeout-minutes: 15 + steps: + - name: Download built sdist + uses: actions/download-artifact@v4 + with: + # unpacks default artifact into dist/ + # if `name: artifact` is omitted, the action will create extra parent dir + name: build + path: dist + + - name: Download built wheels + uses: actions/download-artifact@v4 + with: + pattern: wheels-* + path: dist + merge-multiple: true + + - name: List distributions + run: ls -lh dist/* + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: artifacts + path: dist/* + if-no-files-found: error + publish: runs-on: ubuntu-latest - needs: [build-sdist, build-wheels] + needs: [list-artifacts] if: | github.repository == 'metaopt/torchopt' && github.event_name != 'pull_request' && (github.event_name != 'workflow_dispatch' || github.event.inputs.task == 'build-and-publish') && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/')) timeout-minutes: 15 steps: + - name: Checkout + uses: actions/checkout@v4 + with: + submodules: "recursive" + fetch-depth: 0 + - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 if: startsWith(github.ref, 'refs/tags/') with: - python-version: "3.7 - 3.10" + python-version: "3.8 - 3.12" # sync with requires-python in pyproject.toml update-environment: true + - name: Install dependencies + run: python -m pip install --upgrade pip setuptools wheel build + + - name: Set __release__ + if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' + run: | + python .github/workflows/set_release.py + + - name: Print version + run: python setup.py --version + - name: Check consistency between the package version and release tag if: startsWith(github.ref, 'refs/tags/') run: | @@ -122,39 +279,34 @@ jobs: exit 1 fi - - name: Download built sdist - uses: actions/download-artifact@v3 + - name: Download built artifacts + uses: actions/download-artifact@v4 with: # unpacks default artifact into dist/ # if `name: artifact` is omitted, the action will create extra parent dir - name: sdist + name: artifacts path: dist - - name: Download built wheels - uses: actions/download-artifact@v3 - with: - # unpacks default artifact into dist/ - # if `name: artifact` is omitted, the action will create extra parent dir - name: wheels - path: dist + - name: List distributions + run: ls -lh dist/* - name: Publish to TestPyPI if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' - uses: pypa/gh-action-pypi-publish@v1.5.0 + uses: pypa/gh-action-pypi-publish@release/v1 with: user: __token__ password: ${{ secrets.TESTPYPI_UPLOAD_TOKEN }} - repository_url: https://test.pypi.org/legacy/ + repository-url: https://test.pypi.org/legacy/ verbose: true - print_hash: true - skip_existing: true + print-hash: true + skip-existing: true - name: Publish to PyPI if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' - uses: pypa/gh-action-pypi-publish@v1.5.0 + uses: pypa/gh-action-pypi-publish@release/v1 with: user: __token__ password: ${{ secrets.PYPI_UPLOAD_TOKEN }} verbose: true - print_hash: true - skip_existing: true + print-hash: true + skip-existing: true diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 44ece663..472d5967 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -15,39 +15,44 @@ concurrency: group: "${{ github.workflow }}-${{ github.ref }}" cancel-in-progress: ${{ github.event_name == 'pull_request' }} +env: + CUDA_VERSION: "12.1" + jobs: lint: runs-on: ubuntu-latest timeout-minutes: 30 steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: submodules: "recursive" fetch-depth: 1 - - name: Set up Python 3.7 # the lowest version we support - uses: actions/setup-python@v4 + - name: Set up Python 3.9 + uses: actions/setup-python@v5 with: - python-version: "3.7" + python-version: "3.9" update-environment: true - name: Setup CUDA Toolkit - uses: Jimver/cuda-toolkit@v0.2.7 id: cuda-toolkit - with: - cuda: "11.6.2" - method: network - sub-packages: '["nvcc"]' - - run: | - CUDA_VERSION="${{steps.cuda-toolkit.outputs.cuda}}" - echo "CUDA_VERSION=${CUDA_VERSION}" >> "${GITHUB_ENV}" - PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu$(echo "${CUDA_VERSION}" | cut -d'.' -f-2 | tr -d '.')" + run: | + CUDA_PKG_SUFFIX="$(echo "${CUDA_VERSION}" | cut -d'.' -f-2 | tr '.' '-')" + sudo apt-get update && sudo apt-get install wget --yes + ( + source /etc/os-release + wget -O cuda-keyring.deb "https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${VERSION_ID//./}/$(uname -m)/cuda-keyring_1.0-1_all.deb" + sudo dpkg -i cuda-keyring.deb + ) + sudo apt-get update && sudo apt-get install "cuda-minimal-build-${CUDA_PKG_SUFFIX}" --yes + echo "PATH=/usr/local/cuda/bin${PATH:+:${PATH}}" >> "${GITHUB_ENV}" + echo "LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> "${GITHUB_ENV}" + + PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu$(echo "${CUDA_PKG_SUFFIX}" | tr -d '-')" echo "PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}" >> "${GITHUB_ENV}" - echo "Installed CUDA version is: ${CUDA_VERSION}" - echo "CUDA install location: ${{steps.cuda-toolkit.outputs.CUDA_PATH}}" - nvcc -V + /usr/local/cuda/bin/nvcc -V echo "Torch index URL: ${PIP_EXTRA_INDEX_URL}" - name: Upgrade pip @@ -66,6 +71,10 @@ jobs: run: | make pre-commit + - name: ruff + run: | + make ruff + - name: flake8 run: | make flake8 @@ -82,8 +91,19 @@ jobs: run: | make cpplint + - name: clang-tidy + run: | + sudo apt-get update && sudo apt-get install libomp-dev --yes + make clang-tidy + - name: clang-format run: | + ( + source /etc/os-release + wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc + sudo add-apt-repository "deb http://apt.llvm.org/${UBUNTU_CODENAME} llvm-toolchain-${UBUNTU_CODENAME} main" --yes + ) + sudo apt-get update && sudo apt-get install clang-format --yes make clang-format - name: addlicense diff --git a/.github/workflows/set_cibw_build.py b/.github/workflows/set_cibw_build.py new file mode 100755 index 00000000..ec4383f4 --- /dev/null +++ b/.github/workflows/set_cibw_build.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 + +# pylint: disable=missing-module-docstring + +import os +import sys + + +# pylint: disable-next=consider-using-f-string +CIBW_BUILD = 'CIBW_BUILD=*cp%d%d-*manylinux*' % sys.version_info[:2] + +print(CIBW_BUILD) +with open(os.getenv('GITHUB_ENV'), mode='a', encoding='utf-8') as file: + print(CIBW_BUILD, file=file) diff --git a/.github/workflows/set_release.py b/.github/workflows/set_release.py new file mode 100755 index 00000000..6c437f19 --- /dev/null +++ b/.github/workflows/set_release.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 + +# pylint: disable=missing-module-docstring + +import pathlib +import re + + +ROOT = pathlib.Path(__file__).absolute().parent.parent.parent + +VERSION_FILE = ROOT / 'torchopt' / 'version.py' + +VERSION_CONTENT = VERSION_FILE.read_text(encoding='utf-8') + +VERSION_FILE.write_text( + data=re.sub( + r'__release__\s*=.*', + '__release__ = True', + string=VERSION_CONTENT, + ), + encoding='utf-8', +) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c36e78f2..f156ffe3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -26,39 +26,45 @@ concurrency: group: "${{ github.workflow }}-${{ github.ref }}" cancel-in-progress: ${{ github.event_name == 'pull_request' }} +env: + CUDA_VERSION: "12.1" + jobs: test: + name: Test with CXX/CUDA extensions on ubuntu-latest runs-on: ubuntu-latest timeout-minutes: 60 steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: submodules: "recursive" fetch-depth: 1 - - name: Set up Python 3.7 # the lowest version we support - uses: actions/setup-python@v4 + - name: Set up Python 3.8 + uses: actions/setup-python@v5 with: - python-version: "3.7" + python-version: "3.8" # the lowest version we support (sync with requires-python in pyproject.toml) update-environment: true - name: Setup CUDA Toolkit - uses: Jimver/cuda-toolkit@v0.2.7 id: cuda-toolkit - with: - cuda: "11.6.2" - method: network - sub-packages: '["nvcc"]' - - run: | - CUDA_VERSION="${{steps.cuda-toolkit.outputs.cuda}}" - echo "CUDA_VERSION=${CUDA_VERSION}" >> "${GITHUB_ENV}" - PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu$(echo "${CUDA_VERSION}" | cut -d'.' -f-2 | tr -d '.')" - echo "PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}" >> "${GITHUB_ENV}" + run: | + CUDA_PKG_SUFFIX="$(echo "${CUDA_VERSION}" | cut -d'.' -f-2 | tr '.' '-')" + sudo apt-get update && sudo apt-get install wget --yes + ( + source /etc/os-release + wget -O cuda-keyring.deb "https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${VERSION_ID//./}/$(uname -m)/cuda-keyring_1.0-1_all.deb" + sudo dpkg -i cuda-keyring.deb + ) + sudo apt-get update && sudo apt-get install "cuda-minimal-build-${CUDA_PKG_SUFFIX}" --yes + echo "PATH=/usr/local/cuda/bin${PATH:+:${PATH}}" >> "${GITHUB_ENV}" + echo "LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> "${GITHUB_ENV}" + PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu$(echo "${CUDA_PKG_SUFFIX}" | tr -d '-')" + echo "PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}" >> "${GITHUB_ENV}" echo "Installed CUDA version is: ${CUDA_VERSION}" - echo "CUDA install location: ${{steps.cuda-toolkit.outputs.CUDA_PATH}}" - nvcc -V + /usr/local/cuda/bin/nvcc -V echo "Torch index URL: ${PIP_EXTRA_INDEX_URL}" - name: Upgrade pip @@ -74,17 +80,67 @@ jobs: USE_FP16: "ON" TORCH_CUDA_ARCH_LIST: "Common" run: | - python -m pip install -vvv -e . + python -m pip install -vvv --editable . - name: Test with pytest run: | make pytest - name: Upload coverage to Codecov - uses: codecov/codecov-action@v1 + uses: codecov/codecov-action@v4 + if: ${{ matrix.os == 'ubuntu-latest' }} with: - token: ${{ secrets.CODECOV }} + token: ${{ secrets.CODECOV_TOKEN }} file: ./tests/coverage.xml flags: unittests name: codecov-umbrella fail_ci_if_error: false + + test-pure-python: + name: Test for pure-Python on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + timeout-minutes: 60 + strategy: + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + fail-fast: false + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + submodules: "recursive" + fetch-depth: 1 + + - name: Set up Python 3.8 + uses: actions/setup-python@v5 + with: + python-version: "3.8" # the lowest version we support (sync with requires-python in pyproject.toml) + update-environment: true + + - name: Upgrade pip + run: | + python -m pip install --upgrade pip setuptools wheel + + - name: Install dependencies + run: | + python -m pip install -r tests/requirements.txt + + - name: Install TorchOpt + run: | + python -m pip install -vvv --editable . + env: + TORCHOPT_NO_EXTENSIONS: "true" + + - name: Test with pytest + run: | + make pytest + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + if: ${{ matrix.os == 'ubuntu-latest' }} + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: ./tests/coverage.xml + flags: unittests + name: codecov-umbrella-pure-python + fail_ci_if_error: false diff --git a/.gitignore b/.gitignore index a0107f9b..350ddfb2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,5 @@ -##### Project specific ##### -!torchopt/_src/ -!torchopt/_lib/ +##### Project Specific ##### +third-party/ ##### Python.gitignore ##### # Byte-compiled / optimized / DLL files @@ -31,6 +30,7 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +*.whl # PyInstaller # Usually these files are written by a python script from a template @@ -77,6 +77,7 @@ instance/ # Sphinx documentation docs/_build/ docs/source/_build/ +_autosummary/ # PyBuilder .pybuilder/ @@ -145,6 +146,9 @@ venv.bak/ # mkdocs documentation /site +# ruff +.ruff_cache/ + # mypy .mypy_cache/ .dmypy.json diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 21062f0e..7ab860a5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,8 +1,15 @@ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks +ci: + skip: [pylint] + autofix_prs: true + autofix_commit_msg: "fix: [pre-commit.ci] auto fixes [...]" + autoupdate_commit_msg: "chore(pre-commit): [pre-commit.ci] autoupdate" + autoupdate_schedule: monthly +default_stages: [commit, push, manual] repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.6.0 hooks: - id: check-symlinks - id: destroyed-symlinks @@ -18,16 +25,58 @@ repos: - id: detect-private-key - id: debug-statements - id: double-quote-string-fixer + - repo: https://github.com/pre-commit/mirrors-clang-format + rev: v18.1.8 + hooks: + - id: clang-format + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.0 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/PyCQA/isort - rev: 5.10.1 + rev: 5.13.2 hooks: - id: isort - stages: [commit, push, manual] - repo: https://github.com/psf/black - rev: 22.8.0 + rev: 24.4.2 + hooks: + - id: black-jupyter + - repo: https://github.com/asottile/pyupgrade + rev: v3.16.0 + hooks: + - id: pyupgrade + args: [--py38-plus] # sync with requires-python + exclude: | + (?x)( + ^examples/ + ) + - repo: https://github.com/pycqa/flake8 + rev: 7.1.0 hooks: - - id: black - stages: [commit, push, manual] + - id: flake8 + additional_dependencies: + - flake8-bugbear + - flake8-comprehensions + - flake8-docstrings + - flake8-pyi + - flake8-simplify + exclude: | + (?x)( + ^examples/| + ^tests/| + ^docs/source/conf.py$ + ) + - repo: https://github.com/codespell-project/codespell + rev: v2.3.0 + hooks: + - id: codespell + additional_dependencies: [".[toml]"] + exclude: | + (?x)( + ^docs/source/spelling_wordlist.txt$| + ^docs/source/references.bib$ + ) - repo: local hooks: - id: pylint @@ -36,9 +85,22 @@ repos: language: system types: [python] require_serial: true - stages: [commit, push, manual] exclude: | (?x)( + ^docs/| + ^examples/| + ^tests/| + ^setup.py$ + ) + - repo: https://github.com/pycqa/pydocstyle + rev: 6.3.0 + hooks: + - id: pydocstyle + additional_dependencies: [".[toml]"] + exclude: | + (?x)( + ^.github/| + ^docs/| ^examples/| ^tests/| ^setup.py$ diff --git a/.pylintrc b/.pylintrc index e55faae7..a21967ee 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,4 +1,22 @@ -[MASTER] +[MAIN] + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may @@ -16,28 +34,41 @@ extension-pkg-whitelist= # specified are enabled, while categories only check already-enabled messages. fail-on= -# Specify a score threshold to be exceeded before program exits with error. -fail-under=10.0 +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= # Files or directories to be skipped. They should be base names, not paths. ignore=CVS,.vscode,.history -# Add files or directories matching the regex patterns to the ignore-list. The -# regex matches against paths and can be in Posix or Windows format. -ignore-paths=^_C/$,^examples/$,^tests/$ +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\' represents the directory delimiter on Windows systems, it +# can't be used as an escape character. +ignore-paths=^_C/$,^docs/$,^examples/$,^tests/$ -# Files or directories matching the regex patterns are skipped. The regex -# matches against base names, not paths. The default value ignores emacs file -# locks +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks ignore-patterns=^\.# +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + # Python code to execute, usually for sys.path manipulation such as # pygtk.require(). #init-hook= # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the -# number of processors available to use. -jobs=1 +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=0 # Control the amount of potential inferred values when inferring a single # object. This can help the performance when dealing with large functions or @@ -53,7 +84,7 @@ persistent=yes # Minimum Python version to use for version dependent checks. Will default to # the version used to run pylint. -py-version=3.7 +py-version=3.8 # the lowest version we support (sync with requires-python in pyproject.toml) # Discover python modules and packages in the file system subtree. recursive=no @@ -66,115 +97,8 @@ suggestion-mode=yes # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no - -[MESSAGES CONTROL] - -# Only show warnings with the listed confidence levels. Leave empty to show -# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, -# UNDEFINED. -confidence= - -# Disable the message, report, category or checker with the given id(s). You -# can either give multiple identifiers separated by comma (,) or put this -# option multiple times (only on the command line, not in the configuration -# file where it should appear only once). You can also use "--disable=all" to -# disable everything first and then re-enable specific checks. For example, if -# you want to run only the similarities checker, you can use "--disable=all -# --enable=similarities". If you want to run only the classes checker, but have -# no Warning level messages displayed, use "--disable=all --enable=classes -# --disable=W". -disable=missing-module-docstring, - duplicate-code, - consider-using-from-import - -# Enable the message, report, category or checker with the given id(s). You can -# either give multiple identifier separated by comma (,) or put this option -# multiple time (only on the command line, not in the configuration file where -# it should appear only once). See also the "--disable" option for examples. -enable=c-extension-no-member - - -[REPORTS] - -# Python expression which should return a score less than or equal to 10. You -# have access to the variables 'error', 'warning', 'refactor', and 'convention' -# which contain the number of messages in each category, as well as 'statement' -# which is the total number of statements analyzed. This score is used by the -# global evaluation report (RP0004). -evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) - -# Template used to display messages. This is a python new-style format string -# used to format the message information. See doc for all details. -#msg-template= - -# Set the output format. Available formats are text, parseable, colorized, json -# and msvs (visual studio). You can also give a reporter class, e.g. -# mypackage.mymodule.MyReporterClass. -output-format=text - -# Tells whether to display a full report or only the messages. -reports=no - -# Activate the evaluation score. -score=yes - - -[REFACTORING] - -# Maximum number of nested blocks for function / method body -max-nested-blocks=5 - -# Complete name of functions that never returns. When checking for -# inconsistent-return-statements if a never returning function is called then -# it will be considered as an explicit return statement and no message will be -# printed. -never-returning-functions=sys.exit,argparse.parse_error - - -[STRING] - -# This flag controls whether inconsistent-quotes generates a warning when the -# character used as a quote delimiter is used inconsistently within a module. -check-quote-consistency=no - -# This flag controls whether the implicit-str-concat should generate a warning -# on implicit string concatenation in sequences defined over several lines. -check-str-concat-over-line-jumps=no - - -[MISCELLANEOUS] - -# List of note tags to take in consideration, separated by a comma. -notes=FIXME, - XXX, - TODO - -# Regular expression of note tags to take in consideration. -#notes-rgx= - - -[SPELLING] - -# Limits count of emitted suggestions for spelling mistakes. -max-spelling-suggestions=4 - -# Spelling dictionary name. Available dictionaries: none. To make it work, -# install the 'python-enchant' package. -spelling-dict= - -# List of comma separated words that should be considered directives if they -# appear and the beginning of a comment and should not be checked. -spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: - -# List of comma separated words that should not be checked. -spelling-ignore-words= - -# A path to a file that contains the private dictionary; one word per line. -spelling-private-dict-file= - -# Tells whether to store unknown words to the private dictionary (see the -# --spelling-private-dict-file option) instead of raising a message. -spelling-store-unknown-words=no +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= [BASIC] @@ -266,7 +190,9 @@ good-names=i, t, lr, mu, - nu + nu, + x, + y # Good variable names regexes, separated by a comma. If names match any regex, # they will always be accepted @@ -323,158 +249,6 @@ variable-naming-style=snake_case #variable-rgx= -[LOGGING] - -# The type of string formatting that logging methods do. `old` means using % -# formatting, `new` is for `{}` formatting. -logging-format-style=old - -# Logging modules to check that the string format arguments are in logging -# function parameter format. -logging-modules=logging - - -[VARIABLES] - -# List of additional names supposed to be defined in builtins. Remember that -# you should avoid defining new builtins when possible. -additional-builtins= - -# Tells whether unused global variables should be treated as a violation. -allow-global-unused-variables=yes - -# List of names allowed to shadow builtins -allowed-redefined-builtins= - -# List of strings which can identify a callback function by name. A callback -# name must start or end with one of those strings. -callbacks=cb_, - _cb - -# A regular expression matching the name of dummy variables (i.e. expected to -# not be used). -dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ - -# Argument names that match this expression will be ignored. Default to name -# with leading underscore. -ignored-argument-names=_.*|^ignored_|^unused_ - -# Tells whether we should check for unused import in __init__ files. -init-import=no - -# List of qualified module names which can have objects that can redefine -# builtins. -redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io - - -[TYPECHECK] - -# List of decorators that produce context managers, such as -# contextlib.contextmanager. Add to this list to register other decorators that -# produce valid context managers. -contextmanager-decorators=contextlib.contextmanager - -# List of members which are set dynamically and missed by pylint inference -# system, and so shouldn't trigger E1101 when accessed. Python regular -# expressions are accepted. -generated-members=numpy.*, - torch.* - -# Tells whether missing members accessed in mixin class should be ignored. A -# class is considered mixin if its name matches the mixin-class-rgx option. -ignore-mixin-members=yes - -# Tells whether to warn about missing members when the owner of the attribute -# is inferred to be None. -ignore-none=yes - -# This flag controls whether pylint should warn about no-member and similar -# checks whenever an opaque object is returned when inferring. The inference -# can return multiple potential results while evaluating a Python object, but -# some branches might not be evaluated, which results in partial inference. In -# that case, it might be useful to still emit no-member and other checks for -# the rest of the inferred objects. -ignore-on-opaque-inference=yes - -# List of class names for which member attributes should not be checked (useful -# for classes with dynamically set attributes). This supports the use of -# qualified names. -ignored-classes=optparse.Values,thread._local,_thread._local - -# List of module names for which member attributes should not be checked -# (useful for modules/projects where namespaces are manipulated during runtime -# and thus existing member attributes cannot be deduced by static analysis). It -# supports qualified module names, as well as Unix pattern matching. -ignored-modules= - -# Show a hint with possible names when a member name was not found. The aspect -# of finding the hint is based on edit distance. -missing-member-hint=yes - -# The minimum edit distance a name should have in order to be considered a -# similar match for a missing member name. -missing-member-hint-distance=1 - -# The total number of similar names that should be taken in consideration when -# showing a hint for a missing member. -missing-member-max-choices=1 - -# Regex pattern to define which classes are considered mixins ignore-mixin- -# members is set to 'yes' -mixin-class-rgx=.*[Mm]ixin - -# List of decorators that change the signature of a decorated function. -signature-mutators= - - -[FORMAT] - -# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. -expected-line-ending-format= - -# Regexp for a line that is allowed to be longer than the limit. -ignore-long-lines=^\s*(# )??$ - -# Number of spaces of indent required inside a hanging or continued line. -indent-after-paren=4 - -# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 -# tab). -indent-string=' ' - -# Maximum number of characters on a single line. -max-line-length=100 - -# Maximum number of lines in a module. -max-module-lines=1000 - -# Allow the body of a class to be on the same line as the declaration if body -# contains single statement. -single-line-class-stmt=no - -# Allow the body of an if to be on the same line as the test if there is no -# else. -single-line-if-stmt=no - - -[SIMILARITIES] - -# Comments are removed from the similarity computation -ignore-comments=yes - -# Docstrings are removed from the similarity computation -ignore-docstrings=yes - -# Imports are removed from the similarity computation -ignore-imports=no - -# Signatures are removed from the similarity computation -ignore-signatures=no - -# Minimum lines number of a similarity. -min-similarity-lines=4 - - [CLASSES] # Warn about protected attribute access inside special methods @@ -542,6 +316,43 @@ max-statements=50 min-public-methods=2 +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException, + builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=120 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + [IMPORTS] # List of modules that can be imported at any level, not just the top level @@ -551,11 +362,6 @@ allow-any-import-level= # Allow wildcard imports from modules that define __all__. allow-wildcard-with-all=no -# Analyse import fallback blocks. This can be used to support both Python 2 and -# 3 compatible code, which means that the block might have code that exists -# only in one or another interpreter, leading to false positives when analysed. -analyse-fallback-blocks=no - # Deprecated modules which should not be used, separated by a comma. deprecated-modules= @@ -583,9 +389,241 @@ known-third-party=enchant preferred-modules= -[EXCEPTIONS] +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=duplicate-code, + consider-using-from-import + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= -# Exceptions that will emit a warning when being caught. Defaults to -# "BaseException, Exception". -overgeneral-exceptions=BaseException, - Exception +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +#output-format= + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: en_AU (hunspell), en_CA +# (hunspell), en_GB (hunspell), en_US (hunspell), en_ZA (hunspell). +spelling-dict=en_US + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file=docs/source/spelling_wordlist.txt + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members=numpy.*, + torch.* + +# Tells whether missing members accessed in mixin class should be ignored. A +# class is considered mixin if its name matches the mixin-class-rgx option. +ignore-mixin-members=yes + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 73e1e60f..c014fa97 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -7,9 +7,9 @@ version: 2 # Set the version of Python and other tools you might need build: - os: ubuntu-20.04 + os: ubuntu-lts-latest tools: - python: mambaforge-4.10 + python: mambaforge-latest jobs: post_install: - python -m pip install --upgrade pip setuptools @@ -19,10 +19,6 @@ build: conda: environment: docs/conda-recipe.yaml -# If using Sphinx, optionally build your docs in additional formats such as PDF -formats: - - pdf - # Build documentation in the docs/ directory with Sphinx sphinx: builder: html diff --git a/CHANGELOG.md b/CHANGELOG.md index 5334e26a..62234c25 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,22 +13,115 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - +- Enable CI workflow to build CXX/CUDA extension for Python 3.12 by [@XuehaiPan](https://github.com/XuehaiPan) in [#216](https://github.com/metaopt/torchopt/pull/216). ### Changed +- Refactor the raw import statement in `setup.py` with `importlib` utilities by [@XuehaiPan](https://github.com/XuehaiPan) in [#214](https://github.com/metaopt/torchopt/pull/214). + +### Fixed + +- + +### Removed + +- Drop PyTorch 1.x support by [@XuehaiPan](https://github.com/XuehaiPan) in [#215](https://github.com/metaopt/torchopt/pull/215). + +------ + +## [0.7.3] - 2023-11-10 + +### Changed +- Set minimal C++ standard to C++17 by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195). ### Fixed +- Fix `optree` compatibility for multi-tree-map with `None` values by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195). +------ + +## [0.7.2] - 2023-08-18 + +### Added + +- Implement `Adadelta`, `RAdam`, `Adamax` optimizer by [@JieRen98](https://github.com/JieRen98) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#171](https://github.com/metaopt/torchopt/pull/171). + +------ + +## [0.7.1] - 2023-05-12 + +### Added + +- Enable CI workflow to build CXX/CUDA extension for Python 3.11 by [@XuehaiPan](https://github.com/XuehaiPan) in [#152](https://github.com/metaopt/torchopt/pull/152). +- Implement AdaGrad optimizer and exponential learning rate decay schedule by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#80](https://github.com/metaopt/torchopt/pull/80). +- Enable tests on Windows by [@XuehaiPan](https://github.com/XuehaiPan) in [#140](https://github.com/metaopt/torchopt/pull/140). +- Add `ruff` and `flake8` plugins integration by [@XuehaiPan](https://github.com/XuehaiPan) in [#138](https://github.com/metaopt/torchopt/pull/138) and [#139](https://github.com/metaopt/torchopt/pull/139). +- Add more documentation on implicit differentiation by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#143](https://github.com/metaopt/torchopt/pull/143). + +### Fixed + +- Fix overloaded annotations of `extract_state_dict` by [@StefanoWoerner](https://github.com/StefanoWoerner) in [#162](https://github.com/metaopt/torchopt/pull/162). +- Fix transpose empty iterable with `zip(*nested)` in transformations by [@XuehaiPan](https://github.com/XuehaiPan) in [#145](https://github.com/metaopt/torchopt/pull/145). ### Removed +- Drop Python 3.7 support by [@XuehaiPan](https://github.com/XuehaiPan) in [#136](https://github.com/metaopt/torchopt/pull/136). + +------ + +## [0.7.0] - 2023-02-16 + +### Added + +- Update Sphinx documentation by [@XuehaiPan](https://github.com/XuehaiPan) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@waterhorse1](https://github.com/waterhorse1) and [@JieRen98](https://github.com/JieRen98) in [#127](https://github.com/metaopt/torchopt/pull/127). +- Add object-oriented modules support for zero-order differentiation by [@XuehaiPan](https://github.com/XuehaiPan) in [#125](https://github.com/metaopt/torchopt/pull/125). + +### Changed + +- Use postponed evaluation of annotations and update doctring style by [@XuehaiPan](https://github.com/XuehaiPan) in [#135](https://github.com/metaopt/torchopt/pull/135). +- Rewrite setup CUDA Toolkit logic by [@XuehaiPan](https://github.com/XuehaiPan) in [#133](https://github.com/metaopt/torchopt/pull/133). +### Fixed + +- Update tests and fix corresponding bugs by [@XuehaiPan](https://github.com/XuehaiPan) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@JieRen98](https://github.com/JieRen98) in [#78](https://github.com/metaopt/torchopt/pull/78). +- Fix memory leak in implicit MAML omniglot few-shot classification example with OOP APIs by [@XuehaiPan](https://github.com/XuehaiPan) in [#113](https://github.com/metaopt/torchopt/pull/113). ------ +## [0.6.0] - 2022-12-07 + +### Added + +- Add unroll pragma for CUDA OPs by [@JieRen98](https://github.com/JieRen98) and [@XuehaiPan](https://github.com/XuehaiPan) in [#112](https://github.com/metaopt/torchopt/pull/112). +- Add Python implementation of accelerated OP and pure-Python wheels by [@XuehaiPan](https://github.com/XuehaiPan) in [#67](https://github.com/metaopt/torchopt/pull/67). +- Add `nan_to_num` hook and gradient transformation by [@XuehaiPan](https://github.com/XuehaiPan) in [#119](https://github.com/metaopt/torchopt/pull/119). +- Add matrix inversion linear solver with neumann series approximation by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#98](https://github.com/metaopt/torchopt/pull/98). +- Add if condition of number of threads for CPU OPs by [@JieRen98](https://github.com/JieRen98) in [#105](https://github.com/metaopt/torchopt/pull/105). +- Add implicit MAML omniglot few-shot classification example with OOP APIs by [@XuehaiPan](https://github.com/XuehaiPan) in [#107](https://github.com/metaopt/torchopt/pull/107). +- Add implicit MAML omniglot few-shot classification example by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#48](https://github.com/metaopt/torchopt/pull/48). +- Add object-oriented modules support for implicit meta-gradient by [@XuehaiPan](https://github.com/XuehaiPan) in [#101](https://github.com/metaopt/torchopt/pull/101). +- Bump PyTorch version to 1.13.0 by [@XuehaiPan](https://github.com/XuehaiPan) in [#104](https://github.com/metaopt/torchopt/pull/104). +- Add zero-order gradient estimation by [@JieRen98](https://github.com/JieRen98) in [#93](https://github.com/metaopt/torchopt/pull/93). +- Add RPC-based distributed training support and add distributed MAML example by [@XuehaiPan](https://github.com/XuehaiPan) in [#83](https://github.com/metaopt/torchopt/pull/83). +- Add full type hints by [@XuehaiPan](https://github.com/XuehaiPan) in [#92](https://github.com/metaopt/torchopt/pull/92). +- Add API documentation and tutorial for implicit gradients by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@JieRen98](https://github.com/JieRen98) and [@XuehaiPan](https://github.com/XuehaiPan) in [#73](https://github.com/metaopt/torchopt/pull/73). +- Add wrapper class for functional optimizers and examples of `functorch` integration by [@vmoens](https://github.com/vmoens) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#6](https://github.com/metaopt/torchopt/pull/6). +- Implicit differentiation support by [@JieRen98](https://github.com/JieRen98) and [@waterhorse1](https://github.com/waterhorse1) and [@XuehaiPan](https://github.com/XuehaiPan) in [#41](https://github.com/metaopt/torchopt/pull/41). + +### Changed + +- Refactor code organization by [@XuehaiPan](https://github.com/XuehaiPan) in [#92](https://github.com/metaopt/torchopt/pull/92) and [#100](https://github/metaopt/torchopt/pull/100). + +### Fixed + +- Fix implicit MAML omniglot few-shot classification example by [@XuehaiPan](https://github.com/XuehaiPan) in [#108](https://github.com/metaopt/torchopt/pull/108). +- Align results of distributed examples by [@XuehaiPan](https://github.com/XuehaiPan) in [#95](https://github.com/metaopt/torchopt/pull/95). +- Fix `None` in module containers by [@XuehaiPan](https://github.com/XuehaiPan). +- Fix backward errors when using inplace `sqrt_` and `add_` by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@JieRen98](https://github.com/JieRen98) and [@XuehaiPan](https://github.com/XuehaiPan). +- Fix LR scheduling by [@XuehaiPan](https://github.com/XuehaiPan) in [#76](https://github.com/metaopt/torchopt/pull/76). +- Fix the step count tensor (`shape=(1,)`) can change the shape of the scalar updates (`shape=()`) by [@XuehaiPan](https://github.com/XuehaiPan) in [#71](https://github.com/metaopt/torchopt/pull/71). + ## [0.5.0] - 2022-09-05 ### Added @@ -114,9 +207,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ------ -[Unreleased]: https://github.com/olivierlacan/keep-a-changelog/compare/v0.5.0...HEAD -[0.5.0]: https://github.com/olivierlacan/keep-a-changelog/compare/v0.4.3...v0.5.0 -[0.4.3]: https://github.com/olivierlacan/keep-a-changelog/compare/v0.4.2...v0.4.3 -[0.4.2]: https://github.com/olivierlacan/keep-a-changelog/compare/v0.4.1...v0.4.2 -[0.4.1]: https://github.com/olivierlacan/keep-a-changelog/compare/v0.4.0...v0.4.1 -[0.4.0]: https://github.com/olivierlacan/keep-a-changelog/releases/tag/v0.4.0 +[Unreleased]: https://github.com/metaopt/torchopt/compare/v0.7.3...HEAD +[0.7.3]: https://github.com/metaopt/torchopt/compare/v0.7.2...v0.7.3 +[0.7.2]: https://github.com/metaopt/torchopt/compare/v0.7.1...v0.7.2 +[0.7.1]: https://github.com/metaopt/torchopt/compare/v0.7.0...v0.7.1 +[0.7.0]: https://github.com/metaopt/torchopt/compare/v0.6.0...v0.7.0 +[0.6.0]: https://github.com/metaopt/torchopt/compare/v0.5.0...v0.6.0 +[0.5.0]: https://github.com/metaopt/torchopt/compare/v0.4.3...v0.5.0 +[0.4.3]: https://github.com/metaopt/torchopt/compare/v0.4.2...v0.4.3 +[0.4.2]: https://github.com/metaopt/torchopt/compare/v0.4.1...v0.4.2 +[0.4.1]: https://github.com/metaopt/torchopt/compare/v0.4.0...v0.4.1 +[0.4.0]: https://github.com/metaopt/torchopt/releases/tag/v0.4.0 diff --git a/CITATION.cff b/CITATION.cff index b738a26c..3c6098bf 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -20,6 +20,10 @@ authors: family-names: Pan email: xuehaipan@pku.edu.cn affiliation: Peking University + - given-names: Yao + family-names: Fu + email: f.yu@ed.ac.uk + affiliation: University of Edinburgh - given-names: Luo family-names: Mai email: luo.mai@ed.ac.uk @@ -28,7 +32,7 @@ authors: family-names: Yang affiliation: Peking University email: yaodong.yang@pku.edu.cn -version: 0.5.0 -date-released: "2022-09-05" +version: 0.7.3 +date-released: "2023-11-10" license: Apache-2.0 repository-code: "https://github.com/metaopt/torchopt" diff --git a/CMakeLists.txt b/CMakeLists.txt index 26786756..101ba3ec 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,19 +13,30 @@ # limitations under the License. # ============================================================================== -cmake_minimum_required(VERSION 3.8) +cmake_minimum_required(VERSION 3.11) # for FetchContent project(torchopt LANGUAGES CXX) +include(FetchContent) + +set(THIRD_PARTY_DIR "${CMAKE_SOURCE_DIR}/third-party") +if(NOT DEFINED PYBIND11_VERSION AND NOT "$ENV{PYBIND11_VERSION}" STREQUAL "") + set(PYBIND11_VERSION "$ENV{PYBIND11_VERSION}") +endif() +if(NOT PYBIND11_VERSION) + set(PYBIND11_VERSION stable) +endif() + if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) endif() -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) find_package(Threads REQUIRED) # -pthread find_package(OpenMP REQUIRED) # -Xpreprocessor -fopenmp set(CMAKE_POSITION_INDEPENDENT_CODE ON) # -fPIC +set(CMAKE_CXX_VISIBILITY_PRESET hidden) # -fvisibility=hidden if(MSVC) string(APPEND CMAKE_CXX_FLAGS " /Wall") @@ -168,7 +179,7 @@ endif() system( STRIP OUTPUT_VARIABLE PYTHON_VERSION - COMMAND "${PYTHON_EXECUTABLE}" -c "print(__import__('platform').python_version())" + COMMAND "${PYTHON_EXECUTABLE}" -c "print('.'.join(map(str, __import__('sys').version_info[:3])))" ) message(STATUS "Use Python version: ${PYTHON_VERSION}") @@ -178,7 +189,7 @@ if(NOT DEFINED PYTHON_INCLUDE_DIR) message(STATUS "Auto detecting Python include directory...") system( STRIP OUTPUT_VARIABLE PYTHON_INCLUDE_DIR - COMMAND "${PYTHON_EXECUTABLE}" -c "print(__import__('sysconfig').get_path('include'))" + COMMAND "${PYTHON_EXECUTABLE}" -c "print(__import__('sysconfig').get_path('platinclude'))" ) endif() @@ -186,15 +197,16 @@ if("${PYTHON_INCLUDE_DIR}" STREQUAL "") message(FATAL_ERROR "Python include directory not found") else() message(STATUS "Detected Python include directory: \"${PYTHON_INCLUDE_DIR}\"") - include_directories(${PYTHON_INCLUDE_DIR}) + include_directories("${PYTHON_INCLUDE_DIR}") endif() system( STRIP OUTPUT_VARIABLE PYTHON_SITE_PACKAGES - COMMAND "${PYTHON_EXECUTABLE}" -c "print(__import__('sysconfig') .get_path('purelib'))" + COMMAND "${PYTHON_EXECUTABLE}" -c "print(__import__('sysconfig').get_path('purelib'))" ) message(STATUS "Detected Python site packages: \"${PYTHON_SITE_PACKAGES}\"") +# Include pybind11 set(PYBIND11_PYTHON_VERSION "${PYTHON_VERSION}") if(NOT DEFINED PYBIND11_CMAKE_DIR) @@ -206,14 +218,28 @@ if(NOT DEFINED PYBIND11_CMAKE_DIR) endif() if("${PYBIND11_CMAKE_DIR}" STREQUAL "") - message(FATAL_ERROR "Pybind11 CMake directory not found") + FetchContent_Declare( + pybind11 + GIT_REPOSITORY https://github.com/pybind/pybind11.git + GIT_TAG "${PYBIND11_VERSION}" + GIT_SHALLOW TRUE + SOURCE_DIR "${THIRD_PARTY_DIR}/pybind11" + BINARY_DIR "${THIRD_PARTY_DIR}/.cmake/pybind11/build" + STAMP_DIR "${THIRD_PARTY_DIR}/.cmake/pybind11/stamp" + ) + FetchContent_GetProperties(pybind11) + + if(NOT pybind11_POPULATED) + message(STATUS "Populating Git repository pybind11@${PYBIND11_VERSION} to third-party/pybind11...") + FetchContent_MakeAvailable(pybind11) + endif() else() message(STATUS "Detected Pybind11 CMake directory: \"${PYBIND11_CMAKE_DIR}\"") find_package(pybind11 CONFIG PATHS "${PYBIND11_CMAKE_DIR}") endif() if(NOT DEFINED TORCH_INCLUDE_PATH) - message(STATUS "Auto detecting PyTorch include directory...") + message(STATUS "Auto detecting Torch include directory...") system( STRIP OUTPUT_VARIABLE TORCH_INCLUDE_PATH COMMAND "${PYTHON_EXECUTABLE}" -c "print('\\\;'.join(__import__('torch.utils.cpp_extension', fromlist=[None]).include_paths()))" @@ -232,7 +258,7 @@ else() endif() if(NOT DEFINED TORCH_LIBRARY_PATH) - message(STATUS "Auto detecting PyTorch library directory...") + message(STATUS "Auto detecting Torch library directory...") system( STRIP OUTPUT_VARIABLE TORCH_LIBRARY_PATH COMMAND "${PYTHON_EXECUTABLE}" -c "print('\\\;'.join(__import__('torch.utils.cpp_extension', fromlist=[None]).library_paths()))" @@ -251,19 +277,23 @@ endif() unset(TORCH_LIBRARIES) +foreach(VAR_PATH ${TORCH_LIBRARY_PATH}) + file(GLOB TORCH_LIBRARY "${VAR_PATH}/*") + message(STATUS "Detected Torch libraries: \"${TORCH_LIBRARY}\"") +endforeach() + foreach(VAR_PATH ${TORCH_LIBRARY_PATH}) if(WIN32) file(GLOB TORCH_LIBRARY "${VAR_PATH}/*.lib") else() file(GLOB TORCH_LIBRARY "${VAR_PATH}/libtorch_python.*") endif() - list(APPEND TORCH_LIBRARIES "${TORCH_LIBRARY}") endforeach() -message(STATUS "Detected Torch libraries: \"${TORCH_LIBRARIES}\"") +message(STATUS "Detected Torch Python libraries: \"${TORCH_LIBRARIES}\"") add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) -include_directories(${CMAKE_SOURCE_DIR}) +include_directories("${CMAKE_SOURCE_DIR}") add_subdirectory(src) diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index cc4a4e9a..4c96e694 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -5,7 +5,7 @@ We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender -identity and expression, level of experience, education, socio-economic status, +identity and expression, level of experience, education, socioeconomic status, nationality, personal appearance, race, religion, or sexual identity and orientation. diff --git a/CPPLINT.cfg b/CPPLINT.cfg index 41265bb6..dd346401 100644 --- a/CPPLINT.cfg +++ b/CPPLINT.cfg @@ -1 +1,4 @@ linelength=100 +filter=-readability/nolint +filter=-readability/braces +filter=-whitespace/newline diff --git a/Dockerfile b/Dockerfile index 82434eed..246a81e9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,7 @@ # $ docker build --target devel --tag torchopt-devel:latest . # -ARG cuda_docker_tag="11.6.2-cudnn8-devel-ubuntu20.04" +ARG cuda_docker_tag="12.1.0-cudnn8-devel-ubuntu22.04" FROM nvidia/cuda:"${cuda_docker_tag}" AS builder ENV DEBIAN_FRONTEND=noninteractive @@ -16,12 +16,12 @@ SHELL ["/bin/bash", "-c"] # Install packages RUN apt-get update && \ apt-get install -y sudo ca-certificates openssl \ - git ssh build-essential gcc-10 g++-10 cmake make \ - python3.9-dev python3.9-venv graphviz && \ + git ssh build-essential gcc g++ cmake make \ + python3-dev python3-venv graphviz && \ rm -rf /var/lib/apt/lists/* ENV LANG C.UTF-8 -ENV CC=gcc-10 CXX=g++-10 +ENV CC=gcc CXX=g++ # Add a new user RUN useradd -m -s /bin/bash torchopt && \ @@ -30,7 +30,7 @@ USER torchopt RUN echo "export PS1='[\[\e[1;33m\]\u\[\e[0m\]:\[\e[1;35m\]\w\[\e[0m\]]\$ '" >> ~/.bashrc # Setup virtual environment -RUN /usr/bin/python3.9 -m venv --upgrade-deps ~/venv && rm -rf ~/.pip/cache +RUN /usr/bin/python3 -m venv --upgrade-deps ~/venv && rm -rf ~/.pip/cache RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu$(echo "${CUDA_VERSION}" | cut -d'.' -f-2 | tr -d '.')" && \ echo "export PIP_EXTRA_INDEX_URL='${PIP_EXTRA_INDEX_URL}'" >> ~/venv/bin/activate && \ echo "source /home/torchopt/venv/bin/activate" >> ~/.bashrc @@ -48,14 +48,13 @@ FROM builder AS devel-builder # Install extra dependencies RUN sudo apt-get update && \ - sudo apt-get install -y golang-1.16 clang-format clang-tidy && \ - sudo chown -R "$(whoami):$(whoami)" /usr/lib/go-1.16 && \ + sudo apt-get install -y golang clang-format clang-tidy && \ + sudo chown -R "$(whoami):$(whoami)" "$(realpath /usr/lib/go)" && \ sudo rm -rf /var/lib/apt/lists/* # Install addlicense -ENV GOPATH="/usr/lib/go-1.16" -ENV GOBIN="${GOPATH}/bin" -ENV GOROOT="${GOPATH}" +ENV GOROOT="/usr/lib/go" +ENV GOBIN="${GOROOT}/bin" ENV PATH="${GOBIN}:${PATH}" RUN go install github.com/google/addlicense@latest @@ -74,7 +73,7 @@ COPY --chown=torchopt . . # Install TorchOpt RUN source ~/venv/bin/activate && \ - python -m pip install -e . && \ + make install-editable && \ rm -rf .eggs *.egg-info ~/.pip/cache ~/.cache/pip ENTRYPOINT [ "/bin/bash", "--login" ] diff --git a/LICENSE b/LICENSE index 710ed864..185e0144 100644 --- a/LICENSE +++ b/LICENSE @@ -187,7 +187,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright [2022] [MetaOPT Team. All Rights Reserved.] + Copyright [2022-2024] [MetaOPT Team. All Rights Reserved.] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/MANIFEST.in b/MANIFEST.in index 08cf6257..09403999 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,5 @@ recursive-include torchopt *.pyi +recursive-include torchopt *.typed include LICENSE # Include source files in sdist diff --git a/Makefile b/Makefile index ac67d4b8..e9099f0c 100644 --- a/Makefile +++ b/Makefile @@ -1,29 +1,36 @@ -print-% : ; @echo $* = $($*) +print-%: ; @echo $* = $($*) PROJECT_NAME = torchopt COPYRIGHT = "MetaOPT Team. All Rights Reserved." PROJECT_PATH = $(PROJECT_NAME) SHELL = /bin/bash SOURCE_FOLDERS = $(PROJECT_PATH) examples include src tests docs PYTHON_FILES = $(shell find $(SOURCE_FOLDERS) -type f -name "*.py" -o -name "*.pyi") -CXX_FILES = $(shell find $(SOURCE_FOLDERS) -type f -name "*.h" -o -name "*.cpp" -o -name "*.cuh" -o -name "*.cu") +CXX_FILES = $(shell find $(SOURCE_FOLDERS) -type f -name "*.h" -o -name "*.cpp") +CUDA_FILES = $(shell find $(SOURCE_FOLDERS) -type f -name "*.cuh" -o -name "*.cu") COMMIT_HASH = $(shell git log -1 --format=%h) PATH := $(HOME)/go/bin:$(PATH) PYTHON ?= $(shell command -v python3 || command -v python) +CLANG_FORMAT ?= $(shell command -v clang-format-17 || command -v clang-format) +PYTESTOPTS ?= .PHONY: default default: install install: - $(PYTHON) -m pip install . + $(PYTHON) -m pip install -vvv . install-editable: $(PYTHON) -m pip install --upgrade pip $(PYTHON) -m pip install --upgrade setuptools wheel - $(PYTHON) -m pip install torch numpy pybind11 + $(PYTHON) -m pip install --upgrade pybind11 cmake + $(PYTHON) -m pip install torch numpy USE_FP16=ON TORCH_CUDA_ARCH_LIST=Auto $(PYTHON) -m pip install -vvv --no-build-isolation --editable . install-e: install-editable # alias +uninstall: + $(PYTHON) -m pip uninstall -y $(PROJECT_NAME) + build: $(PYTHON) -m pip install --upgrade pip $(PYTHON) -m pip install --upgrade setuptools wheel build @@ -35,15 +42,23 @@ check_pip_install = $(PYTHON) -m pip show $(1) &>/dev/null || (cd && $(PYTHON) - check_pip_install_extra = $(PYTHON) -m pip show $(1) &>/dev/null || (cd && $(PYTHON) -m pip install $(2) --upgrade) pylint-install: - $(call check_pip_install,pylint) + $(call check_pip_install_extra,pylint,pylint[spelling]) + $(call check_pip_install,pyenchant) flake8-install: $(call check_pip_install,flake8) - $(call check_pip_install_extra,bugbear,flake8_bugbear) + $(call check_pip_install,flake8-bugbear) + $(call check_pip_install,flake8-comprehensions) + $(call check_pip_install,flake8-docstrings) + $(call check_pip_install,flake8-pyi) + $(call check_pip_install,flake8-simplify) py-format-install: $(call check_pip_install,isort) - $(call check_pip_install,black) + $(call check_pip_install_extra,black,black[jupyter]) + +ruff-install: + $(call check_pip_install,ruff) mypy-install: $(call check_pip_install,mypy) @@ -53,7 +68,7 @@ pre-commit-install: $(PYTHON) -m pre_commit install --install-hooks docs-install: - $(call check_pip_install,pydocstyle) + $(call check_pip_install_extra,pydocstyle,pydocstyle[toml]) $(call check_pip_install,doc8) $(call check_pip_install,sphinx) $(call check_pip_install,sphinx-rtd-theme) @@ -63,78 +78,121 @@ docs-install: $(call check_pip_install,sphinxcontrib-katex) $(call check_pip_install,sphinxcontrib-bibtex) $(call check_pip_install,sphinx-autodoc-typehints) - $(call check_pip_install,myst_nb) - $(call check_pip_install_extra,sphinxcontrib.spelling,sphinxcontrib.spelling pyenchant) + $(call check_pip_install,myst-nb) + $(call check_pip_install_extra,sphinxcontrib-spelling,sphinxcontrib-spelling pyenchant) pytest-install: $(call check_pip_install,pytest) - $(call check_pip_install,pytest_cov) - $(call check_pip_install,pytest_xdist) + $(call check_pip_install,pytest-cov) + $(call check_pip_install,pytest-xdist) + +test-install: pytest-install + $(PYTHON) -m pip install --requirement tests/requirements.txt + +cmake-install: + command -v cmake || $(call check_pip_install,cmake) cpplint-install: $(call check_pip_install,cpplint) clang-format-install: - command -v clang-format || sudo apt-get install -y clang-format + command -v clang-format-17 || command -v clang-format || \ + sudo apt-get install -y clang-format-17 || \ + sudo apt-get install -y clang-format clang-tidy-install: command -v clang-tidy || sudo apt-get install -y clang-tidy go-install: # requires go >= 1.16 - command -v go || (sudo apt-get install -y golang-1.16 && sudo ln -sf /usr/lib/go-1.16/bin/go /usr/bin/go) + command -v go || (sudo apt-get install -y golang && sudo ln -sf /usr/lib/go/bin/go /usr/bin/go) addlicense-install: go-install command -v addlicense || go install github.com/google/addlicense@latest # Tests -pytest: pytest-install - cd tests && \ - $(PYTHON) -m pytest --verbose --color=yes --durations=0 \ - --cov="$(PROJECT_NAME)" --cov-report=xml --cov-report=term-missing \ - . +pytest: test-install + $(PYTHON) -m pytest --version + cd tests && $(PYTHON) -c 'import $(PROJECT_PATH)' && \ + $(PYTHON) -m pytest --verbose --color=yes \ + --cov="$(PROJECT_PATH)" --cov-config=.coveragerc --cov-report=xml --cov-report=term-missing \ + $(PYTESTOPTS) . test: pytest # Python linters pylint: pylint-install + $(PYTHON) -m pylint --version $(PYTHON) -m pylint $(PROJECT_PATH) flake8: flake8-install - $(PYTHON) -m flake8 $(PYTHON_FILES) --count --select=E9,F63,F7,F82,E225,E251 --show-source --statistics + $(PYTHON) -m flake8 --version + $(PYTHON) -m flake8 --count --show-source --statistics py-format: py-format-install - $(PYTHON) -m isort --project torchopt --check $(PYTHON_FILES) && \ - $(PYTHON) -m black --check $(PYTHON_FILES) + $(PYTHON) -m isort --version + $(PYTHON) -m black --version + $(PYTHON) -m isort --project $(PROJECT_PATH) --check $(PYTHON_FILES) && \ + $(PYTHON) -m black --check $(PYTHON_FILES) tutorials + +ruff: ruff-install + $(PYTHON) -m ruff --version + $(PYTHON) -m ruff check . + +ruff-fix: ruff-install + $(PYTHON) -m ruff --version + $(PYTHON) -m ruff check . --fix --exit-non-zero-on-fix mypy: mypy-install - $(PYTHON) -m mypy $(PROJECT_PATH) + $(PYTHON) -m mypy --version + $(PYTHON) -m mypy $(PROJECT_PATH) --install-types --non-interactive pre-commit: pre-commit-install + $(PYTHON) -m pre_commit --version $(PYTHON) -m pre_commit run --all-files # C++ linters +cmake-configure: cmake-install + cmake --version + cmake -S . -B cmake-build-debug \ + -DCMAKE_BUILD_TYPE=Debug \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ + -DPYTHON_EXECUTABLE="$(PYTHON)" + +cmake-build: cmake-configure + cmake --build cmake-build-debug --parallel + +cmake: cmake-build + cpplint: cpplint-install - $(PYTHON) -m cpplint $(CXX_FILES) + $(PYTHON) -m cpplint --version + $(PYTHON) -m cpplint $(CXX_FILES) $(CUDA_FILES) clang-format: clang-format-install - clang-format --style=file -i $(CXX_FILES) -n --Werror + $(CLANG_FORMAT) --version + $(CLANG_FORMAT) --style=file -i $(CXX_FILES) $(CUDA_FILES) -n --Werror + +clang-tidy: clang-tidy-install cmake-configure + clang-tidy --version + clang-tidy --extra-arg="-v" -p=cmake-build-debug $(CXX_FILES) # Documentation addlicense: addlicense-install - addlicense -c $(COPYRIGHT) -l apache -y 2022 -check $(SOURCE_FOLDERS) + addlicense -c $(COPYRIGHT) -ignore tests/coverage.xml -l apache -y 2022-$(shell date +"%Y") -check $(SOURCE_FOLDERS) docstyle: docs-install + make -C docs clean $(PYTHON) -m pydocstyle $(PROJECT_PATH) && doc8 docs && make -C docs html SPHINXOPTS="-W" docs: docs-install $(PYTHON) -m sphinx_autobuild --watch $(PROJECT_PATH) --open-browser docs/source docs/build spelling: docs-install + make -C docs clean make -C docs spelling SPHINXOPTS="-W" clean-docs: @@ -142,18 +200,23 @@ clean-docs: # Utility functions -lint: flake8 py-format mypy clang-format cpplint docstyle spelling +lint: ruff flake8 py-format mypy pylint clang-format clang-tidy cpplint addlicense docstyle spelling -format: py-format-install clang-format-install addlicense-install - $(PYTHON) -m isort --project torchopt $(PYTHON_FILES) - $(PYTHON) -m black $(PYTHON_FILES) - clang-format -style=file -i $(CXX_FILES) - addlicense -c $(COPYRIGHT) -l apache -y 2022 $(SOURCE_FOLDERS) +format: py-format-install ruff-install clang-format-install addlicense-install + $(PYTHON) -m isort --project $(PROJECT_PATH) $(PYTHON_FILES) + $(PYTHON) -m black $(PYTHON_FILES) tutorials + $(PYTHON) -m ruff check . --fix --exit-zero + $(CLANG_FORMAT) -style=file -i $(CXX_FILES) $(CUDA_FILES) + addlicense -c $(COPYRIGHT) -ignore tests/coverage.xml -l apache -y 2022-$(shell date +"%Y") $(SOURCE_FOLDERS) clean-py: find . -type f -name '*.py[co]' -delete + find . -depth -type d -name "__pycache__" -exec rm -r "{}" + + find . -depth -type d -name ".ruff_cache" -exec rm -r "{}" + find . -depth -type d -name ".mypy_cache" -exec rm -r "{}" + find . -depth -type d -name ".pytest_cache" -exec rm -r "{}" + + rm tests/.coverage + rm tests/coverage.xml clean-build: rm -rf build/ dist/ @@ -173,5 +236,12 @@ docker-devel: docker: docker-base docker-devel +docker-run-base: docker-base + docker run -it --network=host --gpus=all --shm-size=4gb -v /:/host -v "${PWD}:/workspace" \ + -h ubuntu -w /workspace $(PROJECT_NAME):$(COMMIT_HASH) + docker-run-devel: docker-devel - docker run --network=host --gpus=all -v /:/host -h ubuntu -it $(PROJECT_NAME)-devel:$(COMMIT_HASH) + docker run -it --network=host --gpus=all --shm-size=4gb -v /:/host -v "${PWD}:/workspace" \ + -h ubuntu -w /workspace $(PROJECT_NAME)-devel:$(COMMIT_HASH) + +docker-run: docker-run-base diff --git a/README.md b/README.md index 13d005f5..91d44a25 100644 --- a/README.md +++ b/README.md @@ -1,23 +1,40 @@ +
-![Python 3.7+](https://img.shields.io/badge/Python-3.7%2B-brightgreen.svg) -[![PyPI](https://img.shields.io/pypi/v/torchopt?label=PyPI)](https://pypi.org/project/torchopt) -![Status](https://img.shields.io/pypi/status/torchopt?label=Status) -![GitHub Workflow Status](https://img.shields.io/github/workflow/status/metaopt/torchopt/Tests?label=tests&logo=github) -[![Documentation Status](https://readthedocs.org/projects/torchopt/badge/?version=latest)](https://torchopt.readthedocs.io/en/latest/?badge=latest) -[![Downloads](https://static.pepy.tech/personalized-badge/torchopt?period=month&left_color=grey&right_color=blue&left_text=Downloads/month)](https://pepy.tech/project/torchopt) -[![GitHub Repo Stars](https://img.shields.io/github/stars/metaopt/torchopt?label=Stars&logo=github&color=brightgreen)](https://github.com/metaopt/torchopt/stargazers) -[![License](https://img.shields.io/github/license/metaopt/torchopt?label=License)](#license) +
+ + ![Python 3.8+](https://img.shields.io/badge/Python-3.8%2B-brightgreen.svg) + ![PyPI](https://img.shields.io/pypi/v/torchopt?logo=pypi) + ![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/metaopt/torchopt/tests.yml?label=tests&logo=github) + ![CodeCov](https://img.shields.io/codecov/c/github/metaopt/torchopt/main?logo=codecov) + ![Documentation Status](https://img.shields.io/readthedocs/torchopt?logo=readthedocs) + ![Downloads](https://static.pepy.tech/personalized-badge/torchopt?period=total&left_color=grey&right_color=blue&left_text=downloads) + ![License](https://img.shields.io/github/license/metaopt/torchopt?label=license&logo=) +
+ +

+ Installation | + Documentation | + Tutorials | + Examples | + Paper | + Citation +

+ +**TorchOpt** is an efficient library for differentiable optimization built upon [PyTorch](https://pytorch.org). +TorchOpt is: -**TorchOpt** is a high-performance optimizer library built upon [PyTorch](https://pytorch.org/) for easy implementation of functional optimization and gradient-based meta-learning. It consists of two main features: +- **Comprehensive**: TorchOpt provides three differentiation modes - explicit differentiation, implicit differentiation, and zero-order differentiation for handling different differentiable optimization situations. +- **Flexible**: TorchOpt provides both functional and objective-oriented API for users' different preferences. Users can implement differentiable optimization in JAX-like or PyTorch-like style. +- **Efficient**: TorchOpt provides (1) CPU/GPU acceleration differentiable optimizer (2) RPC-based distributed training framework (3) Fast Tree Operations, to largely increase the training efficiency for bi-level optimization problems. -- TorchOpt provides functional optimizer which enables [JAX-like](https://github.com/google/jax) composable functional optimizer for PyTorch. With TorchOpt, one can easily conduct neural network optimization in PyTorch with functional style optimizer, similar to [Optax](https://github.com/deepmind/optax) in JAX. -- With the design of functional programing, TorchOpt provides efficient, flexible, and easy-to-implement differentiable optimizer for gradient-based meta-learning research. It largely reduces the efforts required to implement sophisticated meta-learning algorithms. +Beyond differentiable optimization, TorchOpt can also be regarded as a functional optimizer that enables [JAX-like](https://github.com/google/jax) composable functional optimizer for PyTorch. +With TorchOpt, users can easily conduct neural network optimization in PyTorch with a functional style optimizer, similar to [Optax](https://github.com/deepmind/optax) in JAX. -------------------------------------------------------------------------------- @@ -27,36 +44,37 @@ The README is organized as follows: - [Optax-Like API](#optax-like-api) - [PyTorch-Like API](#pytorch-like-api) - [Differentiable](#differentiable) -- [TorchOpt as Differentiable Optimizer for Meta-Learning](#torchopt-as-differentiable-optimizer-for-meta-learning) - - [Meta-Learning API](#meta-learning-api) -- [Examples](#examples) -- [High-Performance](#high-performance) +- [TorchOpt for Differentiable Optimization](#torchopt-for-differentiable-optimization) + - [Explicit Gradient (EG)](#explicit-gradient-eg) + - [Implicit Gradient (IG)](#implicit-gradient-ig) + - [Zero-order Differentiation (ZD)](#zero-order-differentiation-zd) +- [High-Performance and Distributed Training](#high-performance-and-distributed-training) + - [CPU/GPU accelerated differentiable optimizer](#cpugpu-accelerated-differentiable-optimizer) + - [Distributed Training](#distributed-training) + - [OpTree](#optree) - [Visualization](#visualization) +- [Examples](#examples) - [Installation](#installation) -- [Future Plan](#future-plan) - [Changelog](#changelog) -- [The Team](#the-team) - [Citing TorchOpt](#citing-torchopt) +- [The Team](#the-team) +- [License](#license) -------------------------------------------------------------------------------- ## TorchOpt as Functional Optimizer -The design of TorchOpt follows the philosophy of functional programming. Aligned with [`functorch`](https://github.com/pytorch/functorch), users can conduct functional style programing with models, optimizers and training in PyTorch. We use the Adam optimizer as an example in the following illustration. You can also check out the tutorial notebook [Functional Optimizer](tutorials/1_Functional_Optimizer.ipynb) for more details. +The design of TorchOpt follows the philosophy of functional programming. +Aligned with [`functorch`](https://github.com/pytorch/functorch), users can conduct functional style programming with models, optimizers and training in PyTorch. +We use the Adam optimizer as an example in the following illustration. +You can also check out the tutorial notebook [Functional Optimizer](tutorials/1_Functional_Optimizer.ipynb) for more details. ### Optax-Like API -For those users who prefer fully functional programing, we offer Optax-Like API by passing gradients and optimizers states to the optimizer function. We design base class `torchopt.Optimizer` that has the same interface as `torch.optim.Optimizer`. Here is an example coupled with `functorch`: +For those users who prefer fully functional programming, we offer Optax-Like API by passing gradients and optimizer states to the optimizer function. +Here is an example coupled with `functorch`: ```python -import functorch -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.utils.data import DataLoader - -import torchopt - class Net(nn.Module): ... class Loader(DataLoader): ... @@ -77,9 +95,26 @@ updates, opt_state = optimizer.update(grads, opt_state) # get updates params = torchopt.apply_updates(params, updates) # update network parameters ``` +We also provide a wrapper `torchopt.FuncOptimizer` to make maintaining the optimizer state easier: + +```python +net = Net() # init +loader = Loader() +optimizer = torchopt.FuncOptimizer(torchopt.adam()) # wrap with `torchopt.FuncOptimizer` + +model, params = functorch.make_functional(net) # use functorch extract network parameters + +for xs, ys in loader: # get data + pred = model(params, xs) # forward + loss = F.cross_entropy(pred, ys) # compute loss + + params = optimizer.step(loss, params) # update network parameters +``` + ### PyTorch-Like API -We also offer origin PyTorch APIs (e.g. `zero_grad()` or `step()`) by wrapping our Optax-Like API for traditional PyTorch user: +We also design a base class `torchopt.Optimizer` that has the same interface as `torch.optim.Optimizer`. +We offer origin PyTorch APIs (e.g. `zero_grad()` or `step()`) by wrapping our Optax-Like API for traditional PyTorch users. ```python net = Net() # init @@ -97,137 +132,304 @@ optimizer.step() # step updates ### Differentiable -On top of the same optimization function as `torch.optim`, an important benefit of functional optimizer is that one can implement differentiable optimization easily. This is particularly helpful when the algorithm requires to differentiate through optimization update (such as meta learning practices). We take as the inputs the gradients and optimizer states, use non-in-place operators to compute and output the updates. The processes can be automatically implemented, with the only need from users being to pass the argument `inplace=False` to the functions: - -```python -# Get updates -updates, opt_state = optimizer.update(grad, opt_state, inplace=False) -# Update network parameters -params = torchopt.apply_updates(params, updates, inplace=False) -``` +On top of the same optimization function as `torch.optim`, an important benefit of the functional optimizer is that one can implement differentiable optimization easily. +This is particularly helpful when the algorithm requires differentiation through optimization updates (such as meta-learning practices). +We take as the inputs the gradients and optimizer states, and use non-in-place operators to compute and output the updates. +The processes can be automatically implemented, with the only need from users being to pass the argument `inplace=False` to the functions. +Check out the section [Explicit Gradient (EG)](#explicit-gradient-eg) functional API for example. -------------------------------------------------------------------------------- -## TorchOpt as Differentiable Optimizer for Meta-Learning +## TorchOpt for Differentiable Optimization -Meta-Learning has gained enormous attention in both Supervised Learning and Reinforcement Learning. Meta-Learning algorithms often contain a bi-level optimization process with *inner loop* updating the network parameters and *outer loop* updating meta parameters. The figure below illustrates the basic formulation for meta-optimization in Meta-Learning. The main feature is that the gradients of *outer loss* will back-propagate through all `inner.step` operations. +We design a bilevel-optimization updating scheme, which can be easily extended to realize various differentiable optimization processes.
- +
-Since network parameters become a node of computation graph, a flexible Meta-Learning library should enable users manually control the gradient graph connection which means that users should have access to the network parameters and optimizer states for manually detaching or connecting the computation graph. In PyTorch designing, the network parameters or optimizer states are members of network (a.k.a. `torch.nn.Module`) or optimizer (a.k.a. `torch.optim.Optimizer`), this design significantly introducing difficulty for user control network parameters or optimizer states. Previous differentiable optimizer Repo [`higher`](https://github.com/facebookresearch/higher), [`learn2learn`](https://github.com/learnables/learn2learn) follows the PyTorch designing which leads to inflexible API. +As shown above, the scheme contains an outer level that has parameters $\phi$ that can be learned end-to-end through the inner level parameters solution $\theta^{\prime}(\phi)$ by using the best-response derivatives $\partial \theta^{\prime}(\phi) / \partial \phi$. +TorchOpt supports three differentiation modes. +It can be seen that the key component of this algorithm is to calculate the best-response (BR) Jacobian. +From the BR-based perspective, existing gradient methods can be categorized into three groups: explicit gradient over unrolled optimization, implicit differentiation, and zero-order gradient differentiation. -In contrast to them, TorchOpt realizes differentiable optimizer with functional programing, where Meta-Learning researchers could control the network parameters or optimizer states as normal variables (a.k.a. `torch.Tensor`). This functional optimizer design of TorchOpt is beneficial for implementing complex gradient flow Meta-Learning algorithms and allow us to improve computational efficiency by using techniques like operator fusion. +### Explicit Gradient (EG) -### Meta-Learning API +The idea of the explicit gradient is to treat the gradient step as a differentiable function and try to backpropagate through the unrolled optimization path. +This differentiation mode is suitable for algorithms when the inner-level optimization solution is obtained by a few gradient steps, such as [MAML](https://arxiv.org/abs/1703.03400) and [MGRL](https://arxiv.org/abs/1805.09801). +TorchOpt offers both functional and object-oriented API for EG to fit different user applications. -- We design a base class `torchopt.MetaOptimizer` for managing network updates in Meta-Learning. The constructor of `MetaOptimizer` takes as input the network rather than network parameters. `MetaOptimizer` exposed interface `step(loss)` takes as input the loss for step the network parameter. Refer to the tutorial notebook [Meta Optimizer](tutorials/3_Meta_Optimizer.ipynb) for more details. -- We offer `torchopt.chain` which can apply a list of chainable update transformations. Combined with `MetaOptimizer`, it can help you conduct gradient transformation such as gradient clip before the Meta optimizer steps. Refer to the tutorial notebook [Meta Optimizer](tutorials/3_Meta_Optimizer.ipynb) for more details. -- We observe that different Meta-Learning algorithms vary in inner-loop parameter recovery. TorchOpt provides basic functions for users to extract or recover network parameters and optimizer states anytime anywhere they want. -- Some algorithms such as MGRL ([arXiv:1805.09801](https://arxiv.org/abs/1805.09801)) initialize the inner-loop parameters inherited from previous inner-loop process when conducting a new bi-level process. TorchOpt also provides a finer function `stop_gradient` for manipulating the gradient graph, which is helpful for this kind of algorithms. Refer to the notebook [Stop Gradient](tutorials/4_Stop_Gradient.ipynb) for more details. +#### Functional API -We give an example of MAML ([arXiv:1703.03400](https://arxiv.org/abs/1703.03400)) with inner-loop Adam optimizer to illustrate TorchOpt APIs: +The functional API is to conduct optimization in a functional programming style. +Note that we pass the argument `inplace=False` to the functions to make the optimization differentiable. +Refer to the tutorial notebook [Functional Optimizer](tutorials/1_Functional_Optimizer.ipynb) for more guidance. ```python -net = Net() # init +# Define functional optimizer +optimizer = torchopt.adam() +# Define meta and inner parameters +meta_params = ... +fmodel, params = make_functional(model) +# Initial state +state = optimizer.init(params) + +for iter in range(iter_times): + loss = inner_loss(fmodel, params, meta_params) + grads = torch.autograd.grad(loss, params) + # Apply non-inplace parameter update + updates, state = optimizer.update(grads, state, inplace=False) + params = torchopt.apply_updates(params, updates) + +loss = outer_loss(fmodel, params, meta_params) +meta_grads = torch.autograd.grad(loss, meta_params) +``` + +#### OOP API -# The constructor `MetaOptimizer` takes as input the network -inner_optim = torchopt.MetaAdam(net) -outer_optim = torchopt.Adam(net.parameters()) - -for train_iter in range(train_iters): - outer_loss = 0 - for task in range(tasks): - loader = Loader(tasks) - - # Store states at the initial points - net_state = torchopt.extract_state_dict(net) # extract state - optim_state = torchopt.extract_state_dict(inner_optim) - for inner_iter in range(inner_iters): - # Compute inner loss and perform inner update - xs, ys = next(loader) - pred = net(xs) - inner_loss = F.cross_entropy(pred, ys) - inner_optim.step(inner_loss) - - # Compute outer loss and back-propagate - xs, ys = next(loader) - pred = net(xs) - outer_loss = outer_loss + F.cross_entropy(pred, ys) - - # Recover network and optimizer states at the initial point for the next task - torchopt.recover_state_dict(inner_optim, optim_state) - torchopt.recover_state_dict(net, net_state) - - outer_loss = outer_loss / len(tasks) # task average - outer_optim.zero_grad() - outer_loss.backward() - outer_optim.step() - - # Stop gradient if necessary - torchopt.stop_gradient(net) - torchopt.stop_gradient(inner_optim) +TorchOpt also provides OOP API compatible with the PyTorch programming style. +Refer to the example and the tutorial notebook [Meta-Optimizer](tutorials/3_Meta_Optimizer.ipynb), [Stop Gradient](tutorials/4_Stop_Gradient.ipynb) for more guidance. + +```python +# Define meta and inner parameters +meta_params = ... +model = ... +# Define differentiable optimizer +optimizer = torchopt.MetaAdam(model) # a model instance as the argument instead of model.parameters() + +for iter in range(iter_times): + # Perform inner update + loss = inner_loss(model, meta_params) + optimizer.step(loss) + +loss = outer_loss(model, meta_params) +loss.backward() ``` --------------------------------------------------------------------------------- +### Implicit Gradient (IG) -## Examples +By treating the solution $\theta^{\prime}$ as an implicit function of $\phi$, the idea of IG is to directly get analytical best-response derivatives $\partial \theta^{\prime} (\phi) / \partial \phi$ by [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem). +This is suitable for algorithms when the inner-level optimal solution is achieved ${\left. \frac{\partial F (\theta, \phi)}{\partial \theta} \right\rvert}_{\theta=\theta^{\prime}} = 0$ or reaches some stationary conditions $F (\theta^{\prime}, \phi) = 0$, such as [iMAML](https://arxiv.org/abs/1909.04630) and [DEQ](https://arxiv.org/abs/1909.01377). +TorchOpt offers both functional and OOP APIs for supporting both [conjugate gradient-based](https://arxiv.org/abs/1909.04630) and [Neumann series-based](https://arxiv.org/abs/1911.02590) IG methods. +Refer to the example [iMAML](https://github.com/waterhorse1/torchopt/tree/readme/examples/iMAML) and the notebook [Implicit Gradient](tutorials/5_Implicit_Differentiation.ipynb) for more guidance. -In [`examples`](examples), we offer several examples of functional optimizer and 5 light-weight meta-learning examples with TorchOpt. The meta-learning examples covers 2 Supervised Learning and 3 Reinforcement Learning algorithms. +#### Functional API -- [Model Agnostic Meta Learning (MAML) - Supervised Learning](https://arxiv.org/abs/1703.03400) (ICML 2017) -- [Learning to Reweight Examples for Robust Deep Learning](https://arxiv.org/abs/1803.09050) (ICML 2018) -- [Model Agnostic Meta Learning (MAML) - Reinforcement Learning](https://arxiv.org/abs/1703.03400) (ICML 2017) -- [Meta Gradient Reinforcement Learning (MGRL)](https://arxiv.org/abs/1805.09801) (NeurIPS 2018) -- [Learning through opponent learning process (LOLA)](https://arxiv.org/abs/1709.04326) (AAMAS 2018) +For the implicit gradient, similar to [JAXopt](https://jaxopt.github.io/stable/implicit_diff.html), users need to define the stationary condition and TorchOpt provides the decorator to wrap the solve function for enabling implicit gradient computation. + +```python +# The stationary condition for the inner-loop +def stationary(params, meta_params, data): + # Stationary condition construction + return stationary condition + +# Decorator for wrapping the function +# Optionally specify the linear solver (conjugate gradient or Neumann series) +@torchopt.diff.implicit.custom_root(stationary, solve=linear_solver) +def solve(params, meta_params, data): + # Forward optimization process for params + return output + +# Define params, meta_params and get data +params, meta_prams, data = ..., ..., ... +optimal_params = solve(params, meta_params, data) +loss = outer_loss(optimal_params) + +meta_grads = torch.autograd.grad(loss, meta_params) +``` + +#### OOP API + +TorchOpt also offers an OOP API, which users need to inherit from the class `torchopt.nn.ImplicitMetaGradientModule` to construct the inner-loop network. +Users need to define the stationary condition/objective function and the inner-loop solve function to enable implicit gradient computation. + +```python +# Inherited from the class ImplicitMetaGradientModule +# Optionally specify the linear solver (conjugate gradient or Neumann series) +class InnerNet(ImplicitMetaGradientModule, linear_solve=linear_solver): + def __init__(self, meta_param): + super().__init__() + self.meta_param = meta_param + ... + + def forward(self, batch): + # Forward process + ... + + def optimality(self, batch, labels): + # Stationary condition construction for calculating implicit gradient + # NOTE: If this method is not implemented, it will be automatically + # derived from the gradient of the `objective` function. + ... + + def objective(self, batch, labels): + # Define the inner-loop optimization objective + ... + + def solve(self, batch, labels): + # Conduct the inner-loop optimization + ... + +# Get meta_params and data +meta_params, data = ..., ... +inner_net = InnerNet(meta_params) + +# Solve for inner-loop process related to the meta-parameters +optimal_inner_net = inner_net.solve(data) + +# Get outer loss and solve for meta-gradient +loss = outer_loss(optimal_inner_net) +meta_grads = torch.autograd.grad(loss, meta_params) +``` + +### Zero-order Differentiation (ZD) + +When the inner-loop process is non-differentiable or one wants to eliminate the heavy computation burdens in the previous two modes (brought by Hessian), one can choose Zero-order Differentiation (ZD). +ZD typically gets gradients based on zero-order estimation, such as finite-difference, or [Evolutionary Strategy](https://arxiv.org/abs/1703.03864). +Instead of optimizing the objective $F$, ES optimizes a smoothed objective. +TorchOpt provides both functional and OOP APIs for the ES method. +Refer to the tutorial notebook [Zero-order Differentiation](tutorials/6_Zero_Order_Differentiation.ipynb) for more guidance. + +#### Functional API + +For zero-order differentiation, users need to define the forward pass calculation and the noise sampling procedure. TorchOpt provides the decorator to wrap the forward function for enabling zero-order differentiation. + +```python +# Customize the noise sampling function in ES +def distribution(sample_shape): + # Generate a batch of noise samples + # NOTE: The distribution should be spherical symmetric and with a constant variance of 1. + ... + return noise_batch + +# Distribution can also be an instance of `torch.distributions.Distribution`, e.g., `torch.distributions.Normal(...)` +distribution = torch.distributions.Normal(loc=0, scale=1) + +# Specify method and hyper-parameter of ES +@torchopt.diff.zero_order(distribution, method) +def forward(params, batch, labels): + # Forward process + ... + return objective # the returned tensor should be a scalar tensor +``` + +#### OOP API + +TorchOpt also offers an OOP API, which users need to inherit from the class `torchopt.nn.ZeroOrderGradientModule` to construct the network as an `nn.Module` following a classical PyTorch style. +Users need to define the forward process zero-order gradient procedures `forward()` and a noise sampling function `sample()`. + +```python +# Inherited from the class ZeroOrderGradientModule +# Optionally specify the `method` and/or `num_samples` and/or `sigma` used for sampling +class Net(ZeroOrderGradientModule, method=method, num_samples=num_samples, sigma=sigma): + def __init__(self, ...): + ... + + def forward(self, batch): + # Forward process + ... + return objective # the returned tensor should be a scalar tensor + + def sample(self, sample_shape=torch.Size()): + # Generate a batch of noise samples + # NOTE: The distribution should be spherical symmetric and with a constant variance of 1. + ... + return noise_batch + +# Get model and data +net = Net(...) +data = ... + +# Forward pass +loss = Net(data) +# Backward pass using zero-order differentiation +grads = torch.autograd.grad(loss, net.parameters()) +``` -------------------------------------------------------------------------------- -## High-Performance +## High-Performance and Distributed Training -One can think of the scale procedures on gradients of optimizer algorithms as a combination of several operations. For example, the implementation of the Adam algorithm often includes addition, multiplication, power and square operations, one can fuse these operations into several compound functions. The operator fusion could greatly simplify the computation graph and reduce the GPU function launching stall. In addition, one can also implement the optimizer backward function and manually reuse some intermediate tensors to improve the backward performance. Users can pass argument `use_accelerated_op=True` to `adam`, `Adam` and `MetaAdam` to enable the fused accelerated operator. The arguments are the same between the two kinds of implementations. +### CPU/GPU accelerated differentiable optimizer -Here we evaluate the performance using the MAML-Omniglot code with the inner-loop Adam optimizer on GPU. We comparable the run time of the overall algorithm and the meta-optimization (outer-loop optimization) under different network architecture/inner-step numbers. We choose [`higher`](https://github.com/facebookresearch/higher) as our baseline. The figure below illustrate that our accelerated Adam can achieve at least $1/3$ efficiency improvement over the baseline. +We take the optimizer as a whole instead of separating it into several basic operators (e.g., `sqrt` and `div`). +Therefore, by manually writing the forward and backward functions, we can perform the symbolic reduction. +In addition, we can store some intermediate data that can be reused during the backpropagation. +We write the accelerated functions in C++ OpenMP and CUDA, bind them by [`pybind11`](https://github.com/pybind/pybind11) to allow they can be called by Python, and then define the forward and backward behavior using `torch.autograd.Function`. +Users can use it by simply setting the `use_accelerated_op` flag as `True`. +Refer to the corresponding sections in the tutorials [Functional Optimizer](tutorials/1_Functional_Optimizer.ipynb)](tutorials/1_Functional_Optimizer.ipynb) and [Meta-Optimizer](tutorials/3_Meta_Optimizer.ipynb) -
- -
+```python +optimizer = torchopt.MetaAdam(model, lr, use_accelerated_op=True) +``` + +### Distributed Training + +`TorchOpt` provides distributed training features based on the PyTorch RPC module for better training speed and multi-node multi-GPU support. +Different from the MPI-like parallelization paradigm, which uses multiple homogeneous workers and requires carefully designed communication hooks, the RPC APIs allow users to build their optimization pipeline more flexibly. +Experimental results show that we achieve an approximately linear relationship between the speed-up ratio and the number of workers. +Check out the [Distributed Training Documentation](https://torchopt.readthedocs.io/en/latest/distributed/distributed.html) and [distributed MAML example](https://github.com/metaopt/torchopt/tree/main/examples/distributed/few-shot) for more specific guidance. -Notably, the operator fusion not only increases performance but also help simplify the computation graph, which will be discussed in the next section. +### OpTree + +We implement the *PyTree* to enable fast nested structure flattening using C++. +The tree operations (e.g., flatten and unflatten) are very important in enabling functional and Just-In-Time (JIT) features of deep learning frameworks. +By implementing it in C++, we can use some cache/memory-friendly structures (e.g., `absl::InlinedVector`) to improve the performance. +For more guidance and comparison results, please refer to our open-source project [`OpTree`](https://github.com/metaopt/optree). -------------------------------------------------------------------------------- ## Visualization -Complex gradient flow in meta-learning brings in a great challenge for managing the gradient flow and verifying the correctness of it. TorchOpt provides a visualization tool that draw variable (e.g. network parameters or meta parameters) names on the gradient graph for better analyzing. The visualization tool is modified from [`torchviz`](https://github.com/szagoruyko/pytorchviz). We provide an example using the [visualization code](examples/visualize.py). Also refer to the notebook [Visualization](tutorials/2_Visualization.ipynb) for more details. +Complex gradient flow in meta-learning brings in a great challenge for managing the gradient flow and verifying its correctness of it. +TorchOpt provides a visualization tool that draws variable (e.g., network parameters or meta-parameters) names on the gradient graph for better analysis. +The visualization tool is modified from [`torchviz`](https://github.com/szagoruyko/pytorchviz). +Refer to the example [visualization code](examples/visualize.py) and the tutorial notebook [Visualization](tutorials/2_Visualization.ipynb) for more details. -The figure below show the visualization result. Compared with [`torchviz`](https://github.com/szagoruyko/pytorchviz), TorchOpt fuses the operations within the `Adam` together (orange) to reduce the complexity and provide simpler visualization. +The figure below shows the visualization result. +Compared with [`torchviz`](https://github.com/szagoruyko/pytorchviz), TorchOpt fuses the operations within the `Adam` together (orange) to reduce the complexity and provide simpler visualization.
- +
-------------------------------------------------------------------------------- +## Examples + +In the [`examples`](examples) directory, we offer several examples of functional optimizers and lightweight meta-learning examples with TorchOpt. + +- [Model-Agnostic Meta-Learning (MAML) - Supervised Learning](https://arxiv.org/abs/1703.03400) (ICML 2017) +- [Learning to Reweight Examples for Robust Deep Learning](https://arxiv.org/abs/1803.09050) (ICML 2018) +- [Model-Agnostic Meta-Learning (MAML) - Reinforcement Learning](https://arxiv.org/abs/1703.03400) (ICML 2017) +- [Meta-Gradient Reinforcement Learning (MGRL)](https://arxiv.org/abs/1805.09801) (NeurIPS 2018) +- [Learning through opponent learning process (LOLA)](https://arxiv.org/abs/1709.04326) (AAMAS 2018) +- [Meta-Learning with Implicit Gradients](https://arxiv.org/abs/1909.04630) (NeurIPS 2019) + +Also, check [`examples`](examples) for more distributed/visualization/functorch-compatible examples. + +-------------------------------------------------------------------------------- + ## Installation Requirements - PyTorch - (Optional) For visualizing computation graphs - - [Graphviz](https://graphviz.org/download/) (for Linux users use `apt/yum install graphviz` or `conda install -c anaconda python-graphviz`) + - [Graphviz](https://graphviz.org/download) (for Linux users use `apt/yum install graphviz` or `conda install -c anaconda python-graphviz`) -**Please follow the instructions at to install PyTorch in your Python environment first.** Then run the following command to install TorchOpt from PyPI ([![PyPI](https://img.shields.io/pypi/v/torchopt?label=PyPI)](https://pypi.org/project/torchopt) / ![Status](https://img.shields.io/pypi/status/torchopt?label=Status)): +**Please follow the instructions at to install PyTorch in your Python environment first.** +Then run the following command to install TorchOpt from PyPI ([![PyPI](https://img.shields.io/pypi/v/torchopt?label=pypi&logo=pypi)](https://pypi.org/project/torchopt) / ![Status](https://img.shields.io/pypi/status/torchopt?label=status)): ```bash pip3 install torchopt ``` -If the minimum version of PyTorch is not satisfied, `pip` will install/upgrade it for you. Please be careful about the `torch` build for CPU / CUDA support (e.g. `cpu`, `cu102`, `cu113`). You may need to specify the extra index URL for the `torch` package: +If the minimum version of PyTorch is not satisfied, `pip` will install/upgrade it for you. Please be careful about the `torch` build for CPU / CUDA support (e.g. `cpu`, `cu118`, `cu121`). +You may need to specify the extra index URL for the `torch` package: ```bash -pip3 install torchopt --extra-index-url https://download.pytorch.org/whl/cu116 +pip3 install torchopt --extra-index-url https://download.pytorch.org/whl/cu121 ``` See for more information about installing PyTorch. @@ -240,14 +442,15 @@ cd torchopt pip3 install . ``` -We provide a [conda](https://github.com/conda/conda) environment recipe to install the build toolchain such as `cmake`, `g++`, and `nvcc`: +We provide a [conda](https://github.com/conda/conda) environment recipe to install the build toolchain such as `cmake`, `g++`, and `nvcc`. +You can use the following commands with [`conda`](https://github.com/conda/conda) / [`mamba`](https://github.com/mamba-org/mamba) to create a new isolated environment. ```bash git clone https://github.com/metaopt/torchopt.git cd torchopt # You may need `CONDA_OVERRIDE_CUDA` if conda fails to detect the NVIDIA driver (e.g. in docker or WSL2) -CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe.yaml +CONDA_OVERRIDE_CUDA=12.1 conda env create --file conda-recipe-minimal.yaml conda activate torchopt make install-editable # or run `pip3 install --no-build-isolation --editable .` @@ -255,36 +458,33 @@ make install-editable # or run `pip3 install --no-build-isolation --editable .` -------------------------------------------------------------------------------- -## Future Plan - -- [x] CPU-accelerated optimizer -- [ ] Support general implicit differentiation with functional programing -- [X] Support more optimizers such as AdamW, RMSProp -- [ ] Zero order optimization -- [ ] Distributed optimizers -- [ ] Support `complex` data type - ## Changelog See [CHANGELOG.md](CHANGELOG.md). -------------------------------------------------------------------------------- -## The Team - -TorchOpt is a work by Jie Ren, Xidong Feng, [Bo Liu](https://github.com/Benjamin-eecs), [Xuehai Pan](https://github.com/XuehaiPan), [Luo Mai](https://luomai.github.io/) and [Yaodong Yang](https://www.yangyaodong.com/). - ## Citing TorchOpt If you find TorchOpt useful, please cite it in your publications. ```bibtex -@software{TorchOpt, - author = {Jie Ren and Xidong Feng and Bo Liu and Xuehai Pan and Luo Mai and Yaodong Yang}, - title = {TorchOpt}, - year = {2022}, - publisher = {GitHub}, - journal = {GitHub repository}, - howpublished = {\url{https://github.com/metaopt/torchopt}}, +@article{JMLR:TorchOpt, + author = {Jie Ren* and Xidong Feng* and Bo Liu* and Xuehai Pan* and Yao Fu and Luo Mai and Yaodong Yang}, + title = {TorchOpt: An Efficient Library for Differentiable Optimization}, + journal = {Journal of Machine Learning Research}, + year = {2023}, + volume = {24}, + number = {367}, + pages = {1--14}, + url = {http://jmlr.org/papers/v24/23-0191.html} } ``` + +## The Team + +TorchOpt is a work by [Jie Ren](https://github.com/JieRen98), [Xidong Feng](https://github.com/waterhorse1), [Bo Liu](https://benjamin-eecs.github.io/), [Xuehai Pan](https://github.com/XuehaiPan), [Luo Mai](https://luomai.github.io), and [Yaodong Yang](https://www.yangyaodong.com). + +## License + +TorchOpt is released under the Apache License, Version 2.0. diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 00000000..e1d3aab2 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,12 @@ +coverage: + precision: 2 + round: nearest + status: + project: + default: + target: auto + threshold: 0.05% + patch: + default: + target: 100% + informational: true diff --git a/conda-recipe-minimal-cpu.yaml b/conda-recipe-minimal-cpu.yaml new file mode 100644 index 00000000..dda60369 --- /dev/null +++ b/conda-recipe-minimal-cpu.yaml @@ -0,0 +1,49 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Create virtual environment with command: +# +# $ conda env create --file conda-recipe-minimal-cpu.yaml +# + +name: torchopt + +channels: + - pytorch + - defaults + - conda-forge + +dependencies: + - python = 3.11 + - pip + + # Learning + - pytorch::pytorch >= 2.0 # sync with project.dependencies + - pytorch::torchvision + - pytorch::pytorch-mutex = *=*cpu* + - pip: + - torchviz + + # Build toolchain + - cmake >= 3.11 + - make + - cxx-compiler + - pybind11 >= 2.11.1 + + # Misc + - optree >= 0.4.1 + - typing-extensions + - numpy + - python-graphviz diff --git a/conda-recipe-minimal.yaml b/conda-recipe-minimal.yaml new file mode 100644 index 00000000..7e28d2ef --- /dev/null +++ b/conda-recipe-minimal.yaml @@ -0,0 +1,55 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Create virtual environment with command: +# +# $ CONDA_OVERRIDE_CUDA=12.1 conda env create --file conda-recipe-minimal.yaml +# + +name: torchopt + +channels: + - pytorch + - nvidia/label/cuda-12.1.0 + - defaults + - conda-forge + +dependencies: + - python = 3.11 + - pip + + # Learning + - pytorch::pytorch >= 2.0 # sync with project.dependencies + - pytorch::torchvision + - pytorch::pytorch-mutex = *=*cuda* + - pip: + - torchviz + + # Device select + - nvidia/label/cuda-12.1.0::cuda-toolkit = 12.1 + + # Build toolchain + - cmake >= 3.11 + - make + - cxx-compiler + - nvidia/label/cuda-12.1.0::cuda-nvcc + - nvidia/label/cuda-12.1.0::cuda-cudart-dev + - pybind11 >= 2.11.1 + + # Misc + - optree >= 0.4.1 + - typing-extensions + - numpy + - python-graphviz diff --git a/conda-recipe.yaml b/conda-recipe.yaml index 19229136..9753852b 100644 --- a/conda-recipe.yaml +++ b/conda-recipe.yaml @@ -1,84 +1,103 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# # Create virtual environment with command: # -# $ CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe.yaml +# $ CONDA_OVERRIDE_CUDA=12.1 conda env create --file conda-recipe.yaml # name: torchopt channels: - pytorch + - nvidia/label/cuda-12.1.0 - defaults - - nvidia/label/cuda-11.6.2 - - nvidia - conda-forge dependencies: - - python = 3.8 + - python = 3.11 - pip # Learning - - pytorch::pytorch >= 1.12 + - pytorch::pytorch >= 2.0 # sync with project.dependencies - pytorch::torchvision - pytorch::pytorch-mutex = *=*cuda* - pip: - - functorch >= 0.2 - torchviz - sphinxcontrib-katex # for documentation - - jax # for tutorials - - jaxlib >= 0.3=*cuda* # for tutorials - - optax # for tutorials + - conda-forge::jax # for tutorials + - conda-forge::jaxlib # for tutorials + - conda-forge::optax # for tutorials + - conda-forge::jaxopt # for tests - tensorboard # for examples - - wandb # Device select - - nvidia::cudatoolkit = 11.6 - - cudnn + - nvidia/label/cuda-12.1.0::cuda-toolkit = 12.1 # Build toolchain - - cmake >= 3.4 + - cmake >= 3.11 - make - cxx-compiler - - gxx = 10 - - nvidia/label/cuda-11.6.2::cuda-nvcc - - nvidia/label/cuda-11.6.2::cuda-cudart-dev - - patchelf >= 0.9 - - pybind11 + - nvidia/label/cuda-12.1.0::cuda-nvcc + - nvidia/label/cuda-12.1.0::cuda-cudart-dev + - patchelf >= 0.14 + - pybind11 >= 2.11.1 # Misc + - optree >= 0.4.1 - typing-extensions - numpy - matplotlib-base - seaborn - python-graphviz - pillow + - setproctitle # Documentation - - sphinx - - sphinx_rtd_theme + - sphinx >= 5.2.1 + - sphinx-rtd-theme - sphinx-autobuild - sphinx-copybutton - sphinxcontrib-spelling - sphinxcontrib-bibtex - - sphinx-autodoc-typehints + - sphinx-autodoc-typehints >= 1.19.2 - pyenchant + - hunspell-en - myst-nb - ipykernel - - pandoc - - docutils = 0.16 + - docutils # Testing - pytest - pytest-cov - pytest-xdist - isort - - conda-forge::black >= 22.6.0 + - conda-forge::black-jupyter - pylint - mypy - flake8 - flake8-bugbear - - doc8 < 1.0.0a0 + - flake8-comprehensions + - flake8-docstrings + - flake8-pyi + - flake8-simplify + - ruff + - doc8 - pydocstyle - - clang-format - - clang-tools # clang-tidy - - cpplint - - pre-commit + - conda-forge::clang-format >= 14 + - conda-forge::clang-tools >= 14 # clang-tidy + - conda-forge::cpplint + - conda-forge::pre-commit + - conda-forge::identify diff --git a/docs/conda-recipe.yaml b/docs/conda-recipe.yaml index 7ba50adb..d7d2f288 100644 --- a/docs/conda-recipe.yaml +++ b/docs/conda-recipe.yaml @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,40 +15,39 @@ # # Create virtual environment with command: # -# $ CONDA_OVERRIDE_CUDA=11.7 conda env create --file docs/conda-recipe.yaml +# $ CONDA_OVERRIDE_CUDA=12.1 conda env create --file docs/conda-recipe.yaml # name: torchopt-docs channels: - pytorch + - nvidia/label/cuda-12.1.0 - defaults - conda-forge dependencies: - - python = 3.8 + - python = 3.11 - pip # Learning - - pytorch::pytorch >= 1.12 + - pytorch::pytorch >= 2.0 # sync with project.dependencies + - pytorch::cpuonly - pytorch::pytorch-mutex = *=*cpu* - pip: - - functorch >= 0.2 - torchviz - sphinxcontrib-katex # for documentation - - tensorboard - - wandb # Build toolchain - - cmake >= 3.4 + - cmake >= 3.11 - make - cxx-compiler - - gxx = 10 - - nvidia/label/cuda-11.6.2::cuda-nvcc - - nvidia/label/cuda-11.6.2::cuda-cudart-dev - - pybind11 + - nvidia/label/cuda-12.1.0::cuda-nvcc + - nvidia/label/cuda-12.1.0::cuda-cudart-dev + - pybind11 >= 2.11.1 # Misc + - optree >= 0.4.1 - typing-extensions - numpy - matplotlib-base @@ -57,15 +56,15 @@ dependencies: - pillow # Documentation - - sphinx - - sphinx_rtd_theme + - sphinx >= 5.2.1 + - sphinx-rtd-theme - sphinx-autobuild - sphinx-copybutton - sphinxcontrib-spelling - sphinxcontrib-bibtex - - sphinx-autodoc-typehints + - sphinx-autodoc-typehints >= 1.19.2 - pyenchant + - hunspell-en - myst-nb - ipykernel - - pandoc - - docutils = 0.16 + - docutils diff --git a/docs/requirements.txt b/docs/requirements.txt index cdfc5b18..c9631b75 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,20 +1,20 @@ --extra-index-url https://download.pytorch.org/whl/cpu -torch >= 1.12 -functorch >= 0.2 +# Sync with project.dependencies +torch >= 2.0 --requirement ../requirements.txt -sphinx >= 5.0 +sphinx >= 5.2.1, < 7.0.0a0 +sphinxcontrib-bibtex >= 2.4 +sphinx-autodoc-typehints >= 1.20 +myst-nb >= 0.15 + sphinx-autoapi sphinx-autobuild sphinx-copybutton sphinx-rtd-theme sphinxcontrib-katex -sphinxcontrib-bibtex -sphinx-autodoc-typehints IPython ipykernel -pandoc -myst_nb -docutils == 0.16 +docutils matplotlib diff --git a/docs/source/_static/css/style.css b/docs/source/_static/css/style.css index 15b8d2e2..60f33012 100644 --- a/docs/source/_static/css/style.css +++ b/docs/source/_static/css/style.css @@ -1,5 +1,5 @@ /** - * Copyright 2022 MetaOPT Team. All Rights Reserved. + * Copyright 2022-2024 MetaOPT Team. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/docs/source/_static/images/explicit-gradient.png b/docs/source/_static/images/explicit-gradient.png new file mode 100644 index 00000000..90cf4d4d Binary files /dev/null and b/docs/source/_static/images/explicit-gradient.png differ diff --git a/docs/source/_static/images/implicit-gradient.png b/docs/source/_static/images/implicit-gradient.png new file mode 100644 index 00000000..faf26486 Binary files /dev/null and b/docs/source/_static/images/implicit-gradient.png differ diff --git a/docs/source/_static/images/visualization-fig1.svg b/docs/source/_static/images/visualization-fig1.svg new file mode 100644 index 00000000..281e456b --- /dev/null +++ b/docs/source/_static/images/visualization-fig1.svg @@ -0,0 +1,57 @@ + + + + + + +%3 + + + +140534064715952 + +y +() + + + +140534064838304 + +MulBackward0 + + + +140534064838304->140534064715952 + + + + + +140534064837776 + +AccumulateGrad + + + +140534064837776->140534064838304 + + + + + +140534064714832 + +x +() + + + +140534064714832->140534064837776 + + + + + diff --git a/docs/source/_static/images/visualization-fig2.svg b/docs/source/_static/images/visualization-fig2.svg new file mode 100644 index 00000000..25db4e5a --- /dev/null +++ b/docs/source/_static/images/visualization-fig2.svg @@ -0,0 +1,106 @@ + + + + + + +%3 + + + +140534659780336 + +loss +() + + + +140531595570768 + +MseLossBackward0 + + + +140531595570768->140534659780336 + + + + + +140531595570576 + +AddmmBackward0 + + + +140531595570576->140531595570768 + + + + + +140531595570528 + +AccumulateGrad + + + +140531595570528->140531595570576 + + + + + +140531595583632 + +fc.bias +(1) + + + +140531595583632->140531595570528 + + + + + +140531595571104 + +TBackward0 + + + +140531595571104->140531595570576 + + + + + +140531595570432 + +AccumulateGrad + + + +140531595570432->140531595571104 + + + + + +140531595582816 + +fc.weight +(1, 5) + + + +140531595582816->140531595570432 + + + + + diff --git a/docs/source/_static/images/visualization-fig3.svg b/docs/source/_static/images/visualization-fig3.svg new file mode 100644 index 00000000..c041e0f6 --- /dev/null +++ b/docs/source/_static/images/visualization-fig3.svg @@ -0,0 +1,339 @@ + + + + + + +%3 + + + +140531595614064 + +loss +() + + + +140531595567168 + +MseLossBackward0 + + + +140531595567168->140531595614064 + + + + + +140531595569232 + +AddBackward0 + + + +140531595569232->140531595567168 + + + + + +140531595568800 + +AddmmBackward0 + + + +140531595568800->140531595569232 + + + + + +140534660247264 + +AddBackward0 +step1.fc.bias +(1) + + + +140534660247264->140531595568800 + + + + + +140534553595376 + +AccumulateGrad + + + +140534553595376->140534660247264 + + + + + +140534553592832 + +AddmmBackward0 + + + +140534553595376->140534553592832 + + + + + +140534064448352 + +step0.fc.bias +(1) + + + +140534064448352->140534553595376 + + + + + +140534553595616 + +MulBackward0 + + + +140534553595616->140534660247264 + + + + + +140534553594848 + +ViewBackward0 + + + +140534553594848->140534553595616 + + + + + +140534553594992 + +SumBackward1 + + + +140534553594992->140534553594848 + + + + + +140534553594800 + +MseLossBackwardBackward0 + + + +140534553594800->140534553594992 + + + + + +140531595617904 + +TBackward0 + + + +140534553594800->140531595617904 + + + + + +140534553593072 + +AddBackward0 + + + +140534553593072->140534553594800 + + + + + +140534553592832->140534553593072 + + + + + +140534553593456 + +TBackward0 + + + +140534553593456->140534553592832 + + + + + +140534553593888 + +AccumulateGrad + + + +140534553593888->140534553593456 + + + + + +140531595572368 + +AddBackward0 +step1.fc.weight +(1, 5) + + + +140534553593888->140531595572368 + + + + + +140531595612944 + +step0.fc.weight +(1, 5) + + + +140531595612944->140534553593888 + + + + + +140531595567888 + +AccumulateGrad + + + +140531595567888->140531595569232 + + + + + +140531595567888->140534553593072 + + + + + +140531595613184 + +meta_param +() + + + +140531595613184->140531595567888 + + + + + +140534553594272 + +TBackward0 + + + +140534553594272->140531595568800 + + + + + +140531595572368->140534553594272 + + + + + +140534553593504 + +MulBackward0 + + + +140534553593504->140531595572368 + + + + + +140534553592976 + +TBackward0 + + + +140534553592976->140534553593504 + + + + + +140534553593216 + +TBackward0 + + + +140534553593216->140534553592976 + + + + + +140534553593552 + +MmBackward0 + + + +140534553593552->140534553593216 + + + + + +140531595617904->140534553593552 + + + + + diff --git a/docs/source/_static/images/zero-order.png b/docs/source/_static/images/zero-order.png new file mode 100644 index 00000000..2c94d667 Binary files /dev/null and b/docs/source/_static/images/zero-order.png differ diff --git a/docs/source/api/api.rst b/docs/source/api/api.rst index 545a8d54..0112e877 100644 --- a/docs/source/api/api.rst +++ b/docs/source/api/api.rst @@ -29,10 +29,31 @@ Functional Optimizers .. autosummary:: + FuncOptimizer + adadelta + adagrad adam - sgd - rmsprop adamw + adamax + radam + rmsprop + sgd + +Wrapper for Function Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: FuncOptimizer + :members: + +Functional AdaDelta Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: adadelta + +Functional AdaGrad Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: adagrad Functional Adam Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -44,16 +65,26 @@ Functional AdamW Optimizer .. autofunction:: adamw -Functional SGD Optimizer -~~~~~~~~~~~~~~~~~~~~~~~~ +Functional AdaMax Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: sgd +.. autofunction:: adamax + +Functional RAdam Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: radam Functional RMSProp Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: rmsprop +Functional SGD Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: sgd + ------ Classic Optimizers @@ -63,10 +94,27 @@ Classic Optimizers .. autosummary:: + AdaDelta + Adadelta + AdaGrad + Adagrad Adam - SGD - RMSProp AdamW + AdaMax + Adamax + RAdam + RMSProp + SGD + +Classic AdaDelta Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: AdaDelta + +Classic AdaGrad Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: AdaGrad Classic Adam Optimizer ~~~~~~~~~~~~~~~~~~~~~~ @@ -78,16 +126,26 @@ Classic AdamW Optimizer .. autoclass:: AdamW -Classic SGD Optimizer -~~~~~~~~~~~~~~~~~~~~~ +Classic AdaMax Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: SGD +.. autoclass:: AdaMax + +Classic RAdam Optimizer +~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: RAdam Classic RMSProp Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: RMSProp +Classic SGD Optimizer +~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: SGD + ------ Differentiable Meta-Optimizers @@ -97,10 +155,27 @@ Differentiable Meta-Optimizers .. autosummary:: + MetaAdaDelta + MetaAdadelta + MetaAdaGrad + MetaAdagrad MetaAdam - MetaSGD - MetaRMSProp MetaAdamW + MetaAdaMax + MetaAdamax + MetaRAdam + MetaRMSProp + MetaSGD + +Differentiable Meta-AdaDelta Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: MetaAdaDelta + +Differentiable Meta-AdaGrad Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: MetaAdaGrad Differentiable Meta-Adam Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -112,52 +187,140 @@ Differentiable Meta-AdamW Optimizer .. autoclass:: MetaAdamW -Differentiable Meta-SGD Optimizer -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Differentiable Meta-AdaMax Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: MetaSGD +.. autoclass:: MetaAdaMax + +Differentiable Meta-RAdam Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: MetaRAdam Differentiable Meta-RMSProp Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: MetaRMSProp +Differentiable Meta-SGD Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: MetaSGD + +------ + +Implicit Differentiation +======================== + +.. currentmodule:: torchopt.diff.implicit + +.. autosummary:: + + custom_root + nn.ImplicitMetaGradientModule + +Custom Solvers +~~~~~~~~~~~~~~ + +.. autofunction:: custom_root + + +Implicit Meta-Gradient Module +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. currentmodule:: torchopt.diff.implicit.nn + +.. autoclass:: ImplicitMetaGradientModule + :members: + +------ + +Linear System Solvers +===================== + +.. currentmodule:: torchopt.linear_solve + +.. autosummary:: + + solve_cg + solve_normal_cg + solve_inv + +Indirect Solvers +~~~~~~~~~~~~~~~~ + +.. autofunction:: solve_cg +.. autofunction:: solve_normal_cg +.. autofunction:: solve_inv + +------ + +Zero-Order Differentiation +========================== + +.. currentmodule:: torchopt.diff.zero_order + +.. autosummary:: + + zero_order + nn.ZeroOrderGradientModule + +Decorators +~~~~~~~~~~ + +.. autofunction:: zero_order + + +Zero-order Gradient Module +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. currentmodule:: torchopt.diff.zero_order.nn + +.. autoclass:: ZeroOrderGradientModule + :members: + ------ Optimizer Hooks =============== -.. currentmodule:: torchopt._src.hook +.. currentmodule:: torchopt.hook .. autosummary:: register_hook zero_nan_hook + nan_to_num_hook Hook ~~~~ .. autofunction:: register_hook .. autofunction:: zero_nan_hook +.. autofunction:: nan_to_num_hook + +------ Gradient Transformation ======================= -.. currentmodule:: torchopt._src.clip +.. currentmodule:: torchopt .. autosummary:: clip_grad_norm + nan_to_num Transforms ~~~~~~~~~~ .. autofunction:: clip_grad_norm +.. autofunction:: nan_to_num Optimizer Schedules =================== -.. currentmodule:: torchopt._src.schedule +.. currentmodule:: torchopt.schedule .. autosummary:: @@ -188,7 +351,7 @@ Apply Updates Combining Optimizers ==================== -.. currentmodule:: torchopt._src.combine +.. currentmodule:: torchopt.combine .. autosummary:: @@ -200,6 +363,115 @@ Chain .. autofunction:: chain +Distributed Utilities +===================== + +.. currentmodule:: torchopt.distributed + +Initialization and Synchronization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + auto_init_rpc + barrier + +.. autofunction:: auto_init_rpc +.. autofunction:: barrier + +Process group information +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + get_world_info + get_world_rank + get_rank + get_world_size + get_local_rank + get_local_world_size + get_worker_id + +.. autofunction:: get_world_info +.. autofunction:: get_world_rank +.. autofunction:: get_rank +.. autofunction:: get_world_size +.. autofunction:: get_local_rank +.. autofunction:: get_local_world_size +.. autofunction:: get_worker_id + +Worker selection +~~~~~~~~~~~~~~~~ + +.. autosummary:: + + on_rank + not_on_rank + rank_zero_only + rank_non_zero_only + +.. autofunction:: on_rank +.. autofunction:: not_on_rank +.. autofunction:: rank_zero_only +.. autofunction:: rank_non_zero_only + +Remote Procedure Call (RPC) +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + remote_async_call + remote_sync_call + +.. autofunction:: remote_async_call +.. autofunction:: remote_sync_call + +Predefined partitioners and reducers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + dim_partitioner + batch_partitioner + mean_reducer + sum_reducer + +.. autofunction:: dim_partitioner +.. autofunction:: batch_partitioner +.. autofunction:: mean_reducer +.. autofunction:: sum_reducer + +Function parallelization wrappers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + parallelize + parallelize_async + parallelize_sync + +.. autofunction:: parallelize +.. autofunction:: parallelize_async +.. autofunction:: parallelize_sync + +Distributed Autograd +~~~~~~~~~~~~~~~~~~~~ + +.. currentmodule:: torchopt.distributed.autograd + +.. autosummary:: + + context + get_gradients + backward + grad + +.. autofunction:: context +.. autofunction:: get_gradients +.. autofunction:: backward +.. autofunction:: grad + + General Utilities ================= @@ -230,7 +502,7 @@ Stop Gradient Visualizing Gradient Flow ========================= -.. currentmodule:: torchopt._src.visual +.. currentmodule:: torchopt.visual .. autosummary:: diff --git a/docs/source/basics/basics.rst b/docs/source/basics/basics.rst new file mode 100644 index 00000000..8b5d5acd --- /dev/null +++ b/docs/source/basics/basics.rst @@ -0,0 +1,34 @@ +Basics +====== + +This section describes useful concepts across TorchOpt. + +TorchOpt Types +-------------- + +.. autosummary:: + + torchopt.base.GradientTransformation + torchopt.base.TransformInitFn + torchopt.base.TransformUpdateFn + +PyTrees +------- + +`PyTrees `_ is an essential concept in TorchOpt. +They can be thought as a generalization of vectors. +They are a way to structure parameters or weights using tuples and dictionaries. +Many solvers in TorchOpt have native support for pytrees. + +Floating-Point Precision +------------------------ + +TorchOpt uses single (32-bit) floating precision (``torch.float32``) by default. +However, for some algorithms, this may not be enough. +Double (64-bit) floating precision (``torch.float64``) can be enabled by adding the following lines at the beginning of the file: + +.. code-block:: python + + import torch + + torch.set_default_dtype(torch.float64) diff --git a/docs/source/bibtex.json b/docs/source/bibtex.json index c2aa9165..7abea503 100644 --- a/docs/source/bibtex.json +++ b/docs/source/bibtex.json @@ -1,7 +1,7 @@ { - "cited": { - "examples/MAML": [ - "MAML", - ] - } + "cited": { + "examples/MAML": [ + "MAML", + ] + } } diff --git a/docs/source/conf.py b/docs/source/conf.py index 694086fe..a4f23533 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,16 +19,18 @@ # pylint: disable=all -# -- Path setup -------------------------------------------------------------- +# -- Path setup ---------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # +import logging import os import pathlib import sys +import sphinx import sphinxcontrib.katex as katex @@ -38,21 +40,39 @@ def get_version() -> str: sys.path.insert(0, str(PROJECT_ROOT / 'torchopt')) - import version # noqa + import version return version.__version__ -# -- Project information ----------------------------------------------------- +try: + import sphinx_autodoc_typehints +except ImportError: + pass +else: + + class RecursiveForwardRefFilter(logging.Filter): + def filter(self, record: logging.LogRecord) -> bool: + if ( + "name 'TensorTree' is not defined" in record.getMessage() + or "name 'OptionalTensorTree' is not defined" in record.getMessage() + ): + return False + return super().filter(record) + + sphinx_autodoc_typehints._LOGGER.logger.addFilter(RecursiveForwardRefFilter()) + + +# -- Project information ------------------------------------------------------- project = 'TorchOpt' -copyright = '2022 MetaOPT Team' +copyright = '2022-2024 MetaOPT Team' author = 'TorchOpt Contributors' # The full version, including alpha/beta/rc tags release = get_version() -# -- General configuration --------------------------------------------------- +# -- General configuration ----------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom @@ -75,7 +95,7 @@ def get_version() -> str: 'sphinxcontrib.bibtex', 'sphinxcontrib.katex', 'sphinx_autodoc_typehints', - 'myst_nb', # This is used for the .ipynb notebooks + 'myst_nb', # this is used for the .ipynb notebooks ] if not os.getenv('READTHEDOCS', None): @@ -110,8 +130,9 @@ def get_version() -> str: # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'default' -# -- Options for autodoc ----------------------------------------------------- +# -- Options for autodoc ------------------------------------------------------- +autosummary_generate = False autodoc_default_options = { 'member-order': 'bysource', 'undoc-members': True, @@ -120,21 +141,27 @@ def get_version() -> str: 'exclude-members': '__module__, __dict__, __repr__, __str__, __weakref__', } autoclass_content = 'both' +simplify_optional_unions = False + +# -- Options for autosummary --------------------------------------------------- + +autosummary_generate = False +# numpydoc_class_members_toctree = False -# -- Options for bibtex ----------------------------------------------------- +# -- Options for bibtex -------------------------------------------------------- bibtex_bibfiles = ['references.bib'] -# -- Options for myst ------------------------------------------------------- +# -- Options for myst ---------------------------------------------------------- nb_execution_mode = 'force' nb_execution_allow_errors = False -# -- Options for katex ------------------------------------------------------ +# -- Options for katex --------------------------------------------------------- # See: https://sphinxcontrib-katex.readthedocs.io/en/0.4.1/macros.html latex_macros = r""" - \def \d #1{\operatorname{#1}} + \def \d #1{\operatorname{#1}} """ # Translate LaTeX macros to KaTeX and add to options for HTML builder @@ -144,7 +171,7 @@ def get_version() -> str: # Add LaTeX macros for LATEX builder latex_elements = {'preamble': latex_macros} -# -- Options for HTML output ------------------------------------------------- +# -- Options for HTML output --------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. @@ -165,7 +192,7 @@ def get_version() -> str: html_logo = '_static/images/logo.png' -def setup(app): +def setup(app: sphinx.application.Sphinx) -> None: app.add_js_file('https://cdn.jsdelivr.net/npm/vega@5.20.2') app.add_js_file('https://cdn.jsdelivr.net/npm/vega-lite@5.1.0') app.add_js_file('https://cdn.jsdelivr.net/npm/vega-embed@6.17.0') @@ -183,27 +210,27 @@ def setup(app): # # html_sidebars = {} -# -- Source code links ------------------------------------------------------- +# -- Source code links --------------------------------------------------------- extlinks = { 'gitcode': ('https://github.com/metaopt/torchopt/blob/HEAD/%s', '%s'), 'issue': ('https://github.com/metaopt/torchopt/issues/%s', 'issue %s'), } -# -- Extension configuration ------------------------------------------------- +# -- Extension configuration --------------------------------------------------- -# -- Options for napoleon extension ------------------------------------------ +# -- Options for napoleon extension -------------------------------------------- napoleon_include_init_with_doc = True napoleon_include_private_with_doc = False napoleon_include_special_with_doc = True -# -- Options for intersphinx extension --------------------------------------- +# -- Options for intersphinx extension ----------------------------------------- # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = {'python': ('https://docs.python.org/3', None)} -# -- Options for todo extension ---------------------------------------------- +# -- Options for todo extension ------------------------------------------------ # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = True diff --git a/docs/source/developer/contributing.rst b/docs/source/developer/contributing.rst index 93d0cc50..e40a564a 100644 --- a/docs/source/developer/contributing.rst +++ b/docs/source/developer/contributing.rst @@ -12,12 +12,12 @@ Before contributing to TorchOpt, please follow the instructions below to setup. git remote add upstream git@github.com:metaopt/torchopt.git -2. Setup a development environment via `conda `_: +2. Setup a development environment via `conda `_ / `mamba `_: .. code-block:: bash # You may need `CONDA_OVERRIDE_CUDA` if conda fails to detect the NVIDIA driver (e.g. in docker or WSL2) - CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe.yaml + CONDA_OVERRIDE_CUDA=12.1 conda env create --file conda-recipe.yaml conda activate torchopt @@ -43,7 +43,7 @@ in the main directory. This installation is removable by: .. code-block:: bash - pip3 uninstall torchopt + make uninstall Lint Check @@ -51,9 +51,9 @@ Lint Check We use several tools to secure code quality, including: - * PEP8 code style: ``black``, ``isort``, ``pylint``, ``flake8`` + * Python code style: ``black``, ``isort``, ``pylint``, ``flake8``, ``ruff`` * Type hint check: ``mypy`` - * C++ Google-style: ``cpplint``, ``clang-format`` + * C++ Google-style: ``cpplint``, ``clang-format``, ``clang-tidy`` * License: ``addlicense`` * Documentation: ``pydocstyle``, ``doc8`` @@ -91,20 +91,20 @@ To build compatible **manylinux2014** (:pep:`599`) wheels for distribution, you pip3 install --upgrade cibuildwheel - export TEST_TORCH_SPECS="cpu cu113 cu116" # `torch` builds for testing - export CUDA_VERSION="11.6" # version of `nvcc` for compilation + export TEST_TORCH_SPECS="cpu cu118" # `torch` builds for testing + export CUDA_VERSION="12.1" # version of `nvcc` for compilation python3 -m cibuildwheel --platform=linux --output-dir=wheelhouse --config-file=pyproject.toml It will install the CUDA compiler with ``CUDA_VERSION`` in the build container. Then build wheel binaries for all supported CPython versions. The outputs will be placed in the ``wheelhouse`` directory. To build a wheel for a specific CPython version, you can use the |CIBW_BUILD|_ environment variable. -For example, the following command will build a wheel for Python 3.7: +For example, the following command will build a wheel for Python 3.8: .. code-block:: bash - CIBW_BUILD="cp37*manylinux*" python3 -m cibuildwheel --platform=linux --output-dir=wheelhouse --config-file=pyproject.toml + CIBW_BUILD="cp38*manylinux*" python3 -m cibuildwheel --platform=linux --output-dir=wheelhouse --config-file=pyproject.toml -You can change ``cp37*`` to ``cp310*`` to build for Python 3.10. See https://cibuildwheel.readthedocs.io/en/stable/options for more options. +You can change ``cp38*`` to ``cp310*`` to build for Python 3.10. See https://cibuildwheel.readthedocs.io/en/stable/options for more options. .. |cibuildwheel| replace:: ``cibuildwheel`` .. _cibuildwheel: https://github.com/pypa/cibuildwheel diff --git a/docs/source/developer/contributor.rst b/docs/source/developer/contributor.rst index 407b53b0..2358f963 100644 --- a/docs/source/developer/contributor.rst +++ b/docs/source/developer/contributor.rst @@ -3,5 +3,5 @@ Contributor We always welcome contributions to help make TorchOpt better. Below is an incomplete list of our contributors (find more on `this page `_). -* Yao Fu (`future-xy `_) -* Vincent Moens (`vmoens `_) +- Yao Fu (`future-xy `_) +- Vincent Moens (`vmoens `_) diff --git a/docs/source/distributed/distributed.rst b/docs/source/distributed/distributed.rst new file mode 100644 index 00000000..0b1bf536 --- /dev/null +++ b/docs/source/distributed/distributed.rst @@ -0,0 +1,733 @@ +Distributed Training +==================== + +.. currentmodule:: torchopt.distributed + +Distributed training is a technique that allows you to train your pipeline on multiple workers/machines. +This is useful when you have a large model or computation graph that doesn't fit on a single GPU/machine, or when you want to train a model faster by using more resources. + +TorchOpt offers a simple API to train your model on multiple GPUs/machines based on the PyTorch |Distributed RPC|_. +Here are some key concepts that TorchOpt's distributed mechanism relies on: + +- **Remote Procedure Call (RPC)** supports running a function on the specified destination worker with the given arguments and getting the return value back or creating a reference to the return value. + + That is, you can treat the remote worker as an accelerator. You can call a function on a remote worker and get the result back to the local worker. + +- **Distributed Autograd** stitches together local autograd engines on all the workers involved in the forward pass, and automatically reach out to them during the backward pass to compute gradients. + + This is much more flexible to fit the meta-learning use case to have a complex task dependency tree. + +.. |Distributed RPC| replace:: Distributed RPC Framework (``torch.distributed.rpc``) +.. _Distributed RPC: https://pytorch.org/docs/stable/rpc.html + +Here are some useful resources to learn more about distributed training: + +- `Distributed RPC Framework `_ +- `Distributed Autograd Design `_ +- `Remote Reference Protocol `_ +- `RPC tutorials `_ +- `Autograd mechanics `_ +- **Example**: :ref:`Using TorchOpt with Distributed Training ` + +------ + +Why RPC-Based Distributed Training +---------------------------------- + +Due to the Global Interpreter Lock (GIL) in Python, only one thread can execute Python code at a time. +This means that you can't take advantage of multiple cores on your machine. +Distribute the workload across multiple processes, or namely workers, that will run in parallel to gain faster execution performance. +Each worker will have its own Python interpreter and memory namespace. + +Compare to single-process programming, you need to be aware of the following: + +- **Communication**: You need to explicitly send and receive messages between workers. +- **Synchronization**: You need to explicitly synchronize the states between workers. + +Message Passing Interface (MPI) and Distributed Data-Parallel Training (DDP) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +`MPI `_ is a standard for message passing between processes. +It is a popular choice for `Distributed Data-Parallel Training (DDP) `_. +PyTorch has implemented this with several `backends `_, including `Gloo `_, `MPI `_, and `NCCL `_. + +However, MPI-based parallelism has some drawbacks: + +- **MPI is not user-friendly**. + MPI-like APIs only provide low-level primitives for sending and receiving messages. + It requires the users to manage the message passing between workers manually. + The users should be aware of the communication pattern and the synchronization between workers. + +- **MPI is not flexible**. + MPI-like APIs are designed for `Distributed Data-Parallel Training (DDP) `_, which is a widely adopted `single-program multiple-data (SPMD) `_ training paradigm. + However, for meta-learning tasks, the task dependency tree is complex and dynamic. + It may not fit into the SPMD paradigm. + It is hard to implement the distributed autograd engine on top of MPI. + +- **MPI only communicates the value of tensors but not the gradients and graphs**. + This is a limitation of MPI. + The users need to handle the gradients manually across multiple workers. + For example, receive the gradients from other workers and put them as ``grad_outputs`` to function |torch.autograd.grad|_. + +.. |torch.autograd.grad| replace:: ``torch.autograd.grad`` +.. _torch.autograd.grad: https://pytorch.org/docs/stable/generated/torch.autograd.grad.html + +Distributed Autograd with Remote Procedure Call (RPC) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To address the needs of meta-learning tasks, which have complex and dynamic nature of the training process. +TorchOpt uses PyTorch |Distributed RPC|_ to implement the distributed training mechanism. +PyTorch implements the RPC communication operations with appropriate ``RpcSendBackward`` and ``RpcRecvBackward`` functions. +The `Distributed Autograd Engine `_ automatically calls these functions to send and receive the gradients between workers. + +With **RPC** and **Distributed Autograd**, TorchOpt distributes a **differentiable optimization** job across multiple workers and executes the workers in parallel. +It allows the users to build the whole computation graph (**both forward and backward**) across multiple workers. +The users can wrap code in the distributed autograd module and achieve substantial speedup in training time with only a few changes in existing training scripts. (:ref:`example `) + +Here is an example of distributed autograd graph using RPC from `Distributed Backward Pass `_ documentation: + +.. code-block:: python + :emphasize-lines: 13, 18, 28, 31 + + import torch + import torch.distributed.autograd as dist_autograd + import torch.distributed.rpc as rpc + + def my_add(t1, t2): + return torch.add(t1, t2) + + # On worker 0: + + # Setup the autograd context. Computations that take + # part in the distributed backward pass must be within + # the distributed autograd context manager. + with dist_autograd.context() as context_id: + t1 = torch.rand((3, 3), requires_grad=True) + t2 = torch.rand((3, 3), requires_grad=True) + + # Perform some computation remotely. + t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2)) + + # Perform some computation locally based on the remote result. + t4 = torch.rand((3, 3), requires_grad=True) + t5 = torch.mul(t3, t4) + + # Compute some loss. + loss = t5.sum() + + # Run the backward pass. + dist_autograd.backward(context_id, [loss]) + + # Retrieve the gradients from the context. + dist_autograd.get_gradients(context_id) + +.. image:: https://pytorch.org/docs/stable/_images/distributed_dependencies_computed.png + +For more details, please refer to the `Distributed Autograd Design `_ documentation. + +------ + +TorchOpt's Distributed Training +------------------------------- + +TorchOpt's distributed package is built upon the PyTorch |Distributed RPC|_ and |Distributed Autograd Framework|_. + +.. |Distributed Autograd Framework| replace:: Distributed Autograd Framework (``torch.distributed.autograd``) +.. _Distributed Autograd Framework: https://pytorch.org/docs/stable/rpc.html#distributed-autograd-framework + +TorchOpt provides some utility functions to make it easier to use the distributed training mechanism. + +Initialization and Synchronization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + torchopt.distributed.auto_init_rpc + torchopt.distributed.barrier + +Users can wrap their program entry function with the decorator :func:`torchopt.distributed.auto_init_rpc`: + +.. code-block:: python + :emphasize-lines: 13 + + import torchopt.distributed as todist + + def parse_arguments(): + parser = argparse.ArgumentParser() + ... + + return args + + def worker_init_fn(): + # set process title, seeding, etc. + ... + + @todist.auto_init_rpc(worker_init_fn) + def main(): + # Your code here + args = parse_arguments() + ... + + if __name__ == '__main__': + main() + +The decorator will initialize the RPC framework and synchronize the workers on startup. + +.. note:: + + By default, all tensors must move to the CPU before sending them to other workers. + If you want to send/receive the tensors directly between GPUs from different workers, you need to specify the ``rpc_backend_options`` with ``device_maps``. + Please refer to the documentation of |torch.distributed.rpc.init_rpc|_ for more details. + +.. |torch.distributed.rpc.init_rpc| replace:: ``torch.distributed.rpc.init_rpc`` +.. _torch.distributed.rpc.init_rpc: https://pytorch.org/docs/stable/rpc.html#torch.distributed.rpc.init_rpc + +Then, users can use |torchrun|_ to launch the program: + +.. code-block:: bash + + torchrun --nnodes=1 --nproc_per_node=8 YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) + +.. |torchrun| replace:: ``torchrun`` (Elastic Launch) +.. _torchrun: https://pytorch.org/docs/stable/elastic/run.html + +Process group information +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + torchopt.distributed.get_world_info + torchopt.distributed.get_world_rank + torchopt.distributed.get_rank + torchopt.distributed.get_world_size + torchopt.distributed.get_local_rank + torchopt.distributed.get_local_world_size + torchopt.distributed.get_worker_id + +After initializing the RPC server, users can use the above functions to get the process group information. + +For example, use :func:`torchopt.distributed.get_local_rank` to determine which GPU device to use: + +.. code-block:: python + + import torch + import torchopt.distributed as todist + + def worker_init_fn(): + local_rank = todist.get_local_rank() + torch.cuda.set_device(local_rank) + + @todist.auto_init_rpc(worker_init_fn) + def main(): + ... + +Worker selection +~~~~~~~~~~~~~~~~ + +.. autosummary:: + + torchopt.distributed.on_rank + torchopt.distributed.not_on_rank + torchopt.distributed.rank_zero_only + torchopt.distributed.rank_non_zero_only + +TorchOpt provides some decorators to execute the decorated function on specific workers. + +For example, use :func:`torchopt.distributed.rank_zero_only` to execute the function only on the main worker (``worker0``), such as saving checkpoints or logging the results: + +.. code-block:: python + :emphasize-lines: 3, 7, 11 + + import torchopt.distributed as todist + + @todist.rank_non_zero_only + def greet(): + print(f'Greetings from worker(rank={todist.get_rank()})!') + + @todist.rank_zero_only + def save_checkpoint(model): + ... + + @todist.rank_zero_only + def log_results(writer, results): + ... + + @todist.auto_init_rpc() + def main(): + greet() + + ... + + for epoch in range(args.epochs): + ... + + if epoch % args.log_interval == 0: + log_results(writer, results) + + if epoch % args.save_interval == 0: + save_checkpoint(model) + +Remote Procedure Call (RPC) +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + torchopt.distributed.remote_async_call + torchopt.distributed.remote_sync_call + +TorchOpt provides two functions to execute the remote procedure call (RPC) on remote workers. +The asynchronous version :func:`remote_async_call` function returns a |torch.Future|_ object, and the :func:`remote_sync_call` function executes and returns the result directly. + +.. |torch.Future| replace:: ``torch.Future`` +.. _torch.Future: https://pytorch.org/docs/stable/futures.html#torch.futures.Future + +Users can distribute their workload (a function) to a specific worker by: + +.. code-block:: python + :emphasize-lines: 12 + + import torchopt.distributed as todist + + @todist.auto_init_rpc(worker_init_fn) + def main(): + ... + + # Execute the function on the remote worker (asynchronously) + future = todist.remote_async_call( + func, + args=(arg1, arg2, ...), + kwargs={...}, + partitioner=worker_id, + ) + + # Wait for the result + result = future.wait() + + ... + +or + +.. code-block:: python + :emphasize-lines: 12 + + import torchopt.distributed as todist + + @todist.auto_init_rpc(worker_init_fn) + def main(): + ... + + # Execute the function on the remote worker + result = todist.remote_sync_call( + func, + args=(arg1, arg2, ...), + kwargs={...}, + partitioner=worker_id, + ) + + ... + +TorchOpt follows the `MapReduce programming model `_ to distribute the workload. + +The ``partitioner`` argument specifies the worker to execute the function. +The users can optionally specify the ``reducer`` argument to aggregate the results from the workers. +Finally, the caller will get a reference to the result on the local worker. + +- ``partitioner``: a function that takes the ``args`` and ``kwargs`` arguments and returns a list of triplets ``(worker_id, worker_args, worker_kwargs)``. + + The ``partitioner`` is responsible for partitioning the workload (inputs) and distributing them to the remote workers. + + If the ``partitioner`` is given by a worker ID (:class:`int` or :class:`str`), the function will be executed on the specified worker. + + If the ``partitioner`` is not given, the :func:`torchopt.distributed.batch_partitioner` will be used. + +- ``mapper``: the ``func`` argument to be executed on the remote worker. +- ``reducer`` (optional): aggregation function, takes a list of results from the remote workers and returns the final result. + + If the ``reducer`` is not given, returns the original unaggregated list. + +Predefined partitioners and reducers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + torchopt.distributed.dim_partitioner + torchopt.distributed.batch_partitioner + torchopt.distributed.mean_reducer + torchopt.distributed.sum_reducer + +We provide some predefined partitioners and reducers. +Users can combine the :func:`torchopt.distributed.batch_partitioner` and :func:`torchopt.distributed.mean_reducer` to achieve the distributed data parallelism (DDP) easily: + +.. code-block:: python + :emphasize-lines: 18, 19 + + import torchopt.distributed as todist + + def loss_fn(model, batch): + ... + + @todist.rank_zero_only + def train(args): + + for epoch in range(args.epochs): + ... + + for batch in dataloader: + # Partition the data on the batch (first) dimension and distribute them to the remote workers + # Aggregate the results from the remote workers and return the mean loss + loss = todist.remote_sync_call( + loss_fn, + args=(model, batch), + partitioner=todist.batch_partitioner, + reducer=todist.mean_reducer, + ) + + ... + +We also provide a :func:`torchopt.distributed.dim_partitioner` to partition the data on the specified dimension. +While implementing the **Model-Agnostic Meta-Learning** (MAML) :cite:`MAML` algorithm, users can use this to parallel the training for the inner loop: + +.. code-block:: python + :emphasize-lines: 29, 30 + + import torchopt.distributed as todist + + def inner_loop(model, task_batch, args): + # task_batch: shape = (B, *) + inner_model = torchopt.module_clone(model, by='reference', detach_buffers=True) + + # Inner optimization + for inner_step in range(args.inner_steps): + inner_loss = inner_loss_fn(inner_model, task_batch) + + # Update the inner model + ... + + # Compute the outer loss + outer_loss = inner_loss_fn(inner_model, task_batch) + return outer_loss + + @todist.rank_zero_only + def train(args): + + for epoch in range(args.epochs): + ... + + for batch in dataloader: + # batch: shape = (T, B, *) + outer_loss = todist.remote_sync_call( + inner_loop, + args=(model, batch), + partitioner=todist.dim_partitioner(0, exclusive=True, keepdim=False), + reducer=todist.mean_reducer, + ) + + ... + +The ``dim_partitioner(0, exclusive=True, keepdim=False)`` will split the batch of size ``(T, B, *)`` into ``T`` batches of size ``(B, *)``. +Each task will be executed on the remote worker **independently** (``exclusive=True``). +Finally, the results will be aggregated by the :func:`torchopt.distributed.mean_reducer` to compute the mean loss. +Inside the ``inner_loop`` function, users may use another RPC call to further parallelize the inner loop optimization. + +Function parallelization wrappers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + torchopt.distributed.parallelize + torchopt.distributed.parallelize_async + torchopt.distributed.parallelize_sync + +TorchOpt offers wrappers to parallelize the function execution on the remote workers. +It makes the function execution on the remote workers more transparent to the users and makes the code structure clear. + +.. code-block:: python + :emphasize-lines: 3, 9, 10, 11, 12 + + import torchopt.distributed as todist + + @todist.parallelize(partitioner=todist.batch_partitioner, reducer=todist.mean_reducer) + def distributed_data_parallelism(model, batch, args): + # Compute local loss of the given batch + ... + return loss + + @todist.parallelize( + partitioner=todist.dim_partitioner(0, exclusive=True, keepdim=False), + reducer=todist.mean_reducer, + ) + def inner_loop(model, batch, args): # distributed MAML inner loop + # batch: shape = (B, *) + inner_model = torchopt.module_clone(model, by='reference', detach_buffers=True) + + # Inner optimization + ... + + # Compute the outer loss + outer_loss = inner_loss_fn(inner_model, task_batch) + return outer_loss + + @todist.rank_zero_only + def train(args): + + for epoch in range(args.epochs): + ... + + for batch in dataloader: + # batch: shape = (T, B, *) + outer_loss = inner_loop(model, batch, args) + + ... + +Distributed Autograd +~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + torchopt.distributed.autograd.context + torchopt.distributed.autograd.get_gradients + torchopt.distributed.autograd.backward + torchopt.distributed.autograd.grad + +In this section, we will introduce the distributed autograd system. +Please refer to `Autograd mechanics `_ and `Distributed Autograd Design `_ first before going through this section. + +Recap: Autograd mechanics in single-process training +"""""""""""""""""""""""""""""""""""""""""""""""""""" + +In single-process training, the autograd engine will automatically track the operations on the forward pass and compute the gradients on the backward pass. +For each operation, if the input tensors have ``requires_grad=True`` set, the output tensor will have a ``grad_fn`` attribute to trace the computation graph. +On the backward pass, the autograd engine will traverse the computation graph from the output tensors to the input tensors and compute the gradients for each operation. + +The |torch.autograd.grad|_ function will compute the gradients of the given ``outputs`` with respect to the given ``inputs``. + +.. code-block:: python + + import torch + + model = build_model() + loss = compute_loss(model, data) + + params = tuple(model.parameters()) + grads = torch.autograd.grad(loss, params) + + print(grads) + +In practice, users usually use the PyTorch Autograd Engine with ``loss.backward()`` (or |torch.autograd.backward|_) and optimizers: + +.. code-block:: python + + import torch + import torch.optim as optim + + model = build_model() + optimizer = optim.SGD(model.parameters(), lr=lr) + + loss = compute_loss(model, data) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + +Compare to |torch.autograd.grad|_, the |torch.autograd.backward|_ function will sum and update the ``.grad`` attribute of the parameters. + +.. |torch.autograd.backward| replace:: ``torch.autograd.backward`` +.. _torch.autograd.backward: https://pytorch.org/docs/stable/generated/torch.autograd.backward.html + +RPC-based Distributed Autograd +"""""""""""""""""""""""""""""" + +PyTorch RPC framework implements the communication ``send-recv`` operations with appropriate backward functions (``RpcSendBackward`` and ``RpcRecvBackward``). +They can be tracked by the **Distributed Autograd Engine** like the single-process program we discussed above. + +The only difference between the single-process and distributed training is that users need to explicitly create a **Distributed Autograd Context** and wrap around the forward and backward passes. + +.. code-block:: python + :emphasize-lines: 4, 9, 12 + + import torch + import torch.distributed.autograd as dist_autograd + + with dist_autograd.context() as context_id: + # Forward pass + loss = ... # e.g. remote calls + + # Backward pass + dist_autograd.backward(context_id, [loss]) + + # Retrieve the gradients from the context. + grad_dict = dist_autograd.get_gradients(context_id) # type: Dict[Tensor, Tensor] + +.. warning:: + + Sending |torch.nn.Parameter|_\s over RPC will automatically detach from the autograd graph. + This is an intentional behavior of the PyTorch framework because the |torch.nn.Parameter|_\s are always leaf nodes in the graph. + The leaf tensors will not have ``grad_fn`` attribute and thus cannot be tracked by the autograd engine after sending them to other workers. + + To make the graph can be properly tracked across workers, users should convert the |torch.nn.Parameter|_\s to |torch.Tensor|_\s before sending them over RPC. + For example, explicitly ``clone()`` the parameters to tensors before taking them as arguments of the RPC call. + + .. code-block:: python + + import torch + import torch.distributed.rpc as rpc + + def compute_loss(param): + return param.mean() + + param = torch.nn.Parameter(torch.randn(2, 2), requires_grad=True) + + # The RPC call will detach the parameter from the autograd graph on worker1 + loss1 = rpc.rpc_sync('worker1', compute_loss, args=(param,)) + + # The RPC call will keep connection to the parameter in the autograd graph on worker1 + loss2 = rpc.rpc_sync('worker1', compute_loss, args=(param.clone(),)) + + Users can use :func:`torchopt.module_clone` function to clone the module and convert all its parameters to tensors. + The tensors will have a ``grad_fn`` attribute ``CloneBackward`` to track the computation graph to the original parameters. + + .. code-block:: python + + import torch + import torch.nn as nn + import torchopt + + def compute_loss(model, batch): + ... + return loss + + model = nn.Linear(2, 2) + tuple(model.parameters()) # -> `nn.Parameter`s + + cloned_model = torchopt.module_clone(model, by='clone') + tuple(cloned_model.parameters()) # -> `torch.Tensor`s with `CloneBackward` grad_fn + + # The RPC call will detach the parameter from the autograd graph on worker1 + loss1 = rpc.rpc_sync('worker1', compute_loss, args=(model, batch)) + + # The RPC call will keep the connection to the parameter in the autograd graph on worker1 + loss2 = rpc.rpc_sync('worker1', compute_loss, args=(cloned_model, batch)) + +.. |torch.nn.Parameter| replace:: ``torch.nn.Parameter`` +.. _torch.nn.Parameter: https://pytorch.org/docs/stable/generated/torch.nn.parameter.Parameter.html +.. |torch.Tensor| replace:: ``torch.Tensor`` +.. _torch.Tensor: https://pytorch.org/docs/stable/tensors.html + +TorchOpt wraps the distributed autograd context and provides a more convenient interface to use. + +.. code-block:: python + :emphasize-lines: 5, 10 + + import torchopt.distributed as todist + + model = build_model() + + with todist.autograd.context() as context_id: + # Forward pass + loss = ... # e.g. remote calls + + # Backward pass + grads = todist.autograd.grad(context_id, loss, model.parameters()) + +or + +.. code-block:: python + :emphasize-lines: 7, 13 + + import torch + import torchopt.distributed as todist + + model = build_model() + optimizer = torch.optim.SGD(model.parameters(), lr=lr) + + with todist.autograd.context() as context_id: + # Forward pass + loss = ... # e.g. remote calls + + # Backward pass + optimizer.zero_grad() + todist.autograd.backward(context_id, loss) + optimizer.step() + +.. warning:: + + The distributed autograd context is not thread-safe. + Users should not use the same context in multiple threads. + +Users can update their single-process training code to distributed training code with minimum changes: + +#. Add the distributed autograd context around the forward and backward passes. +#. Wrap the functions with :func:`torchopt.distributed.parallelize` to enable parallel execution. +#. Convert the parameters to tensors before sending them over RPC. +#. Replace the ``torch.autograd`` to ``torchopt.distributed.autograd``. + +Here is a full example of converting the single-process training code to distributed training code: + +.. code-block:: python + :emphasize-lines: 17, 32, 40, 42, 43, 47, 52 + :name: distributed-example + + import torch + import torch.nn as nn + import torchopt.distributed as todist + + def parse_arguments(): + parser = argparse.ArgumentParser(description='TorchOpt Distributed Training') + ... + + args = parser.parse_args() + return args + + def worker_init_fn(): + # set process title, seeding, etc. + setproctitle.setproctitle(f'Worker{todist.get_rank()}') + torch.manual_seed(args.seed + todist.get_rank()) + + @todist.parallelize(partitioner=todist.batch_partitioner, reducer=todist.mean_reducer) + def compute_loss(model, batch): + device = torch.device(f'cuda:{todist.get_local_rank()}') + model = model.to(device) + batch = batch.to(device) + + # Compute local loss of the given batch + ... + return loss.cpu() + + def build_model(): + return nn.Sequential( + ... + ) + + @todist.rank_zero_only + def train(args): + model = build_model() + optimizer = torch.optim.SGD(model.parameters(), lr=args.lr) + train_loader = ... + + for epoch in range(args.epochs): + for batch in train_loader: + with todist.autograd.context() as context_id: + # Forward pass + cloned_model = todist.module_clone(model, by='clone') + loss = compute_loss(cloned_model, batch) + + # Backward pass + optimizer.zero_grad() + todist.autograd.backward(context_id, loss) + + # Update parameters + optimizer.step() + + @todist.auto_init_rpc(worker_init_fn) + def main(): + args = parse_arguments() + train(args) + + if __name__ == '__main__': + main() + +Then, users can use |torchrun|_ to launch the program: + +.. code-block:: bash + + torchrun --nnodes=1 --nproc_per_node=8 YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) diff --git a/docs/source/examples/MAML.rst b/docs/source/examples/MAML.rst index bba6c35a..ee5a638c 100644 --- a/docs/source/examples/MAML.rst +++ b/docs/source/examples/MAML.rst @@ -1,7 +1,7 @@ Model-Agnostic Meta-Learning ============================ -Meta reinforcement learning has achieved significant successes in various applications. +Meta-reinforcement learning has achieved significant successes in various applications. **Model-Agnostic Meta-Learning** (MAML) :cite:`MAML` is the pioneer one. In this tutorial, we will show how to train MAML on few-shot Omniglot classification with TorchOpt step by step. The full script is at :gitcode:`examples/few-shot/maml_omniglot.py`. @@ -63,16 +63,17 @@ TorchOpt supports any user-defined PyTorch networks. Here is an example: net = nn.Sequential( nn.Conv2d(1, 64, 3), - nn.BatchNorm2d(64, momentum=1., affine=True), + nn.BatchNorm2d(64, momentum=1.0, affine=True), nn.ReLU(inplace=False), nn.MaxPool2d(2, 2), nn.Conv2d(64, 64, 3), - nn.BatchNorm2d(64, momentum=1., affine=True), + nn.BatchNorm2d(64, momentum=1.0, affine=True), nn.ReLU(inplace=False), nn.MaxPool2d(2, 2), nn.Conv2d(64, 64, 3), - nn.BatchNorm2d(64, momentum=1., affine=True), - nn.ReLU(inplace=False), nn.MaxPool2d(2, 2), + nn.BatchNorm2d(64, momentum=1.0, affine=True), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), nn.Flatten(), nn.Linear(64, args.n_way), ).to(device) @@ -98,8 +99,7 @@ Define the ``train`` function: # Sample a batch of support and query images and labels. x_spt, y_spt, x_qry, y_qry = db.next() - task_num, setsz, c_, h, w = x_spt.size() - querysz = x_qry.size(1) + task_num = x_spt.size(0) # TODO: Maybe pull this out into a separate module so it # doesn't have to be duplicated between `train` and `test`? @@ -128,28 +128,24 @@ Define the ``train`` function: # These will be used to update the model's meta-parameters. qry_logits = net(x_qry[i]) qry_loss = F.cross_entropy(qry_logits, y_qry[i]) - qry_losses.append(qry_loss.detach()) - qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz - qry_accs.append(qry_acc) - - # Update the model's meta-parameters to optimize the query - # losses across all of the tasks sampled in this batch. - # This unrolls through the gradient steps. - qry_loss.backward() + qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean() + qry_losses.append(qry_loss) + qry_accs.append(qry_acc.item()) torchopt.recover_state_dict(net, net_state_dict) torchopt.recover_state_dict(inner_opt, optim_state_dict) + qry_losses = torch.mean(torch.stack(qry_losses)) + qry_losses.backward() meta_opt.step() - qry_losses = sum(qry_losses) / task_num - qry_accs = 100. * sum(qry_accs) / task_num + qry_losses = qry_losses.item() + qry_accs = 100.0 * np.mean(qry_accs) i = epoch + float(batch_idx) / n_train_iter iter_time = time.time() - start_time print( f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}' ) - log.append( { 'epoch': i, @@ -183,8 +179,7 @@ Define the ``test`` function: for batch_idx in range(n_test_iter): x_spt, y_spt, x_qry, y_qry = db.next('test') - task_num, setsz, c_, h, w = x_spt.size() - querysz = x_qry.size(1) + task_num = x_spt.size(0) # TODO: Maybe pull this out into a separate module so it # doesn't have to be duplicated between `train` and `test`? @@ -203,15 +198,17 @@ Define the ``test`` function: # The query loss and acc induced by these parameters. qry_logits = net(x_qry[i]).detach() - qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none') - qry_losses.append(qry_loss.detach()) - qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach()) + qry_loss = F.cross_entropy(qry_logits, y_qry[i]) + qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean() + qry_losses.append(qry_loss.item()) + qry_accs.append(qry_acc.item()) torchopt.recover_state_dict(net, net_state_dict) torchopt.recover_state_dict(inner_opt, optim_state_dict) - qry_losses = torch.cat(qry_losses).mean().item() - qry_accs = 100. * torch.cat(qry_accs).float().mean().item() + qry_losses = np.mean(qry_losses) + qry_accs = 100.0 * np.mean(qry_accs) + print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}') log.append( { diff --git a/docs/source/explicit_diff/explicit_diff.rst b/docs/source/explicit_diff/explicit_diff.rst new file mode 100644 index 00000000..28e06f77 --- /dev/null +++ b/docs/source/explicit_diff/explicit_diff.rst @@ -0,0 +1,169 @@ +Explicit Gradient Differentiation +================================= + +.. currentmodule:: torchopt + +Explicit Gradient +----------------- + +.. image:: /_static/images/explicit-gradient.png + :width: 80% + :align: center + +The idea of explicit gradient is to treat the gradient step as a differentiable function and try to backpropagate through the unrolled optimization path. +Namely, given + +.. math:: + + \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}) \triangleq \boldsymbol{\theta}_0 - \alpha \sum_{i=0}^{K-1} \nabla_{\boldsymbol{\theta}_i} \mathcal{L}^{\text{in}} (\boldsymbol{\phi},\boldsymbol{\theta}_i), + +we would like to compute the gradient :math:`\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})`. +This is usually done by AutoDiff through an inner optimization's unrolled iterates. + +Differentiable Functional Optimizers +------------------------------------ + +By passing the argument ``inplace`` as :data:`False` to the ``update`` functions, we can make the optimization differentiable. +Here is an example of making :func:`torchopt.adam` differentiable. + +.. code-block:: python + + opt = torchopt.adam() + # Define meta and inner parameters + meta_params = ... + fmodel, params = make_functional(model) + # Initialize optimizer state + state = opt.init(params) + + for iter in range(iter_times): + loss = inner_loss(fmodel, params, meta_params) + grads = torch.autograd.grad(loss, params) + # Apply non-inplace parameter update + updates, state = opt.update(grads, state, inplace=False) + params = torchopt.apply_updates(params, updates) + + loss = outer_loss(fmodel, params, meta_params) + meta_grads = torch.autograd.grad(loss, meta_params) + +Differentiable OOP Meta-Optimizers +---------------------------------- + +For PyTorch-like API (e.g., ``step()``), we designed a base class :class:`torchopt.MetaOptimizer` to wrap our functional optimizers to become differentiable OOP meta-optimizers. + +.. autosummary:: + + torchopt.MetaOptimizer + torchopt.MetaAdaDelta + torchopt.MetaAdadelta + torchopt.MetaAdaGrad + torchopt.MetaAdagrad + torchopt.MetaAdam + torchopt.MetaAdamW + torchopt.MetaAdaMax + torchopt.MetaAdamax + torchopt.MetaRAdam + torchopt.MetaRMSProp + torchopt.MetaSGD + +By combining low-level API :class:`torchopt.MetaOptimizer` with the previous functional optimizer, we can achieve high-level API: + +.. code-block:: python + + # Low-level API + optim = torchopt.MetaOptimizer(net, torchopt.sgd(lr=1.0)) + + # High-level API + optim = torchopt.MetaSGD(net, lr=1.0) + +Here is an example of using the OOP API :class:`torchopt.MetaAdam` to conduct meta-gradient calculation. + +.. code-block:: python + + # Define meta and inner parameters + meta_params = ... + model = ... + # Define differentiable optimizer + opt = torchopt.MetaAdam(model) + + for iter in range(iter_times): + # Perform the inner update + loss = inner_loss(model, meta_params) + opt.step(loss) + + loss = outer_loss(model, meta_params) + loss.backward() + +CPU/GPU Accelerated Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +TorchOpt performs the symbolic reduction by manually writing the forward and backward functions using C++ OpenMP (CPU) and CUDA (GPU), which largely increase meta-gradient computational efficiency. +Users can use accelerated optimizer by setting the ``use_accelerated_op`` as :data:`True`. +TorchOpt will automatically detect the device and allocate the corresponding accelerated optimizer. + +.. code-block:: python + + # Check whether the `accelerated_op` is available: + torchopt.accelerated_op_available(torch.device('cpu')) + + torchopt.accelerated_op_available(torch.device('cuda')) + + net = Net(1).cuda() + optim = torchopt.Adam(net.parameters(), lr=1.0, use_accelerated_op=True) + +General Utilities +----------------- + +We provide the :func:`torchopt.extract_state_dict` and :func:`torchopt.recover_state_dict` functions to extract and restore the state of network and optimizer. +By default, the extracted state dictionary is a reference (this design is for accumulating gradient of multi-task batch training, MAML for example). +You can also set ``by='copy'`` to extract the copy of the state dictionary or set ``by='deepcopy'`` to have a detached copy. + +.. autosummary:: + + torchopt.extract_state_dict + torchopt.recover_state_dict + torchopt.stop_gradient + +Here is an usage example. + +.. code-block:: python + + net = Net() + x = nn.Parameter(torch.tensor(2.0), requires_grad=True) + + optim = torchopt.MetaAdam(net, lr=1.0) + + # Get the reference of state dictionary + init_net_state = torchopt.extract_state_dict(net, by='reference') + init_optim_state = torchopt.extract_state_dict(optim, by='reference') + # If set `detach_buffers=True`, the parameters are referenced as references while buffers are detached copies + init_net_state = torchopt.extract_state_dict(net, by='reference', detach_buffers=True) + + # Set `copy` to get the copy of the state dictionary + init_net_state_copy = torchopt.extract_state_dict(net, by='copy') + init_optim_state_copy = torchopt.extract_state_dict(optim, by='copy') + + # Set `deepcopy` to get the detached copy of state dictionary + init_net_state_deepcopy = torchopt.extract_state_dict(net, by='deepcopy') + init_optim_state_deepcopy = torchopt.extract_state_dict(optim, by='deepcopy') + + # Conduct 2 inner-loop optimization + for i in range(2): + inner_loss = net(x) + optim.step(inner_loss) + + print(f'a = {net.a!r}') + + # Recover and reconduct 2 inner-loop optimization + torchopt.recover_state_dict(net, init_net_state) + torchopt.recover_state_dict(optim, init_optim_state) + + for i in range(2): + inner_loss = net(x) + optim.step(inner_loss) + + print(f'a = {net.a!r}') # the same result + +Notebook Tutorial +----------------- + +Check the notebook tutorials at `Meta Optimizer `_ and `Stop Gradient `_. diff --git a/docs/source/implicit_diff/implicit_diff.rst b/docs/source/implicit_diff/implicit_diff.rst new file mode 100644 index 00000000..5544c25f --- /dev/null +++ b/docs/source/implicit_diff/implicit_diff.rst @@ -0,0 +1,246 @@ +Implicit Gradient Differentiation +================================= + +.. currentmodule:: torchopt.diff.implicit + +Implicit Differentiation +------------------------ + +.. image:: /_static/images/implicit-gradient.png + :width: 80% + :align: center + +Implicit differentiation is the task of differentiating through the solution of an optimization problem satisfying a mapping function :math:`T` capturing the optimality conditions of the problem. +The simplest example is to differentiate through the solution of a minimization problem with respect to its inputs. +Namely, given + +.. math:: + + \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}) \triangleq \underset{\boldsymbol{\theta}}{\mathop{\operatorname{argmin}}} ~ \mathcal{L}^{\text{in}} (\boldsymbol{\phi},\boldsymbol{\theta}). + +By treating the solution :math:`\boldsymbol{\theta}^{\prime}` as an implicit function of :math:`\boldsymbol{\phi}`, the idea of implicit differentiation is to directly get analytical best-response derivatives :math:`\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})` by the implicit function theorem. + +Root Finding +~~~~~~~~~~~~ + +This is suitable for algorithms when the inner-level optimality conditions :math:`T` is defined by a root of a function, such as: + +.. math:: + + T (\boldsymbol{\phi}, \boldsymbol{\theta}) = \frac{ \partial \mathcal{L}^{\text{in}} (\boldsymbol{\phi}, \boldsymbol{\theta})}{\partial \boldsymbol{\theta}}, \qquad T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})) = \left. \frac{ \partial \mathcal{L}^{\text{in}} (\boldsymbol{\phi}, \boldsymbol{\theta})}{\partial \boldsymbol{\theta}} \right\rvert_{\boldsymbol{\theta} = \boldsymbol{\theta}^{\prime}} = \boldsymbol{0}. + +In `IMAML `_, the function :math:`F` in the figure means the inner-level optimal solution is obtained by unrolled gradient update: + +.. math:: + + \boldsymbol{\theta}_{k + 1} = F (\boldsymbol{\phi}, \boldsymbol{\theta}_k) = \boldsymbol{\theta}_k - \alpha \nabla_{\boldsymbol{\theta}_k} \mathcal{L}^{\text{in}} (\boldsymbol{\phi}, \boldsymbol{\theta}_k). + +Fixed-point Iteration +~~~~~~~~~~~~~~~~~~~~~ + +Sometimes the inner-level optimal solution can also be achieved by fixed point where the optimality :math:`T` takes the form: + +.. math:: + + \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}) = F (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})) \quad \Longleftrightarrow \quad T (\boldsymbol{\phi}, \boldsymbol{\theta}) = F (\boldsymbol{\phi}, \boldsymbol{\theta}) - \boldsymbol{\theta}, \quad T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})) = \boldsymbol{0}. + +In `DEQ `_, the function :math:`F` in the figure means the inner-level optimal solution is obtained by fixed point update: + +.. math:: + + \boldsymbol{\theta}_{k + 1} = F (\boldsymbol{\phi}, \boldsymbol{\theta}_k). + +This can be seen as a particular case of root of function by defining the optimality function as :math:`T (\boldsymbol{\phi}, \boldsymbol{\theta}) = F (\boldsymbol{\phi}, \boldsymbol{\theta}) - \boldsymbol{\theta}`. +This can be implemented with: + +.. code-block:: python + + def fixed_point_function(phi: TensorTree, theta: TensorTree) -> TensorTree: + ... + return new_theta + + # A root function can be derived from the fixed point function + def root_function(phi: TensorTree, theta: TensorTree) -> TensorTree: + new_theta = fixed_point_function(phi, theta) + return torchopt.pytree.tree_sub(new_theta, theta) + +Custom Solvers +-------------- + +.. autosummary:: + + torchopt.diff.implicit.custom_root + +Let :math:`T (\boldsymbol{\phi}, \boldsymbol{\theta}): \mathbb{R}^n \times \mathbb{R}^d \to \mathbb{R}^d` be a user-provided mapping function, that captures the optimality conditions of a problem. +An optimal solution, denoted :math:`\boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})`, should be a root of :math:`T`: + +.. math:: + + T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi})) = \boldsymbol{0}. + +We can see :math:`\boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})` as an implicitly defined function of :math:`\boldsymbol{\phi} \in \mathbb{R}^n`, i.e., :math:`\boldsymbol{\theta}^{\prime}: \mathbb{R}^n \rightarrow \mathbb{R}^d`. +More precisely, from the `implicit function theorem `_, we know that for :math:`(\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0)` satisfying :math:`T (\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0) = \boldsymbol{0}` with a continuously differentiable :math:`T`, if the Jacobian :math:`\nabla_{\boldsymbol{\theta}^{\prime}} T` evaluated at :math:`(\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0)` is a square invertible matrix, then there exists a function :math:`\boldsymbol{\theta}^{\prime} (\cdot)` defined on a neighborhood of :math:`\boldsymbol{\phi}_0` such that :math:`\boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}_0) = \boldsymbol{\theta}^{\prime}_0`. +Furthermore, for all :math:`\boldsymbol{\phi}` in this neighborhood, we have that :math:`T (\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0) = \boldsymbol{0}` and :math:`\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})` exists. Using the chain rule, the Jacobian :math:`\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi})` satisfies: + +.. math:: + + \frac{d T}{d \boldsymbol{\phi}} = \underbrace{\nabla_{\boldsymbol{\theta}^{\prime}} T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi}))}_{\frac{\partial T}{\partial \boldsymbol{\theta}^{\prime}}} \underbrace{\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})}_{\frac{d \boldsymbol{\theta}^{\prime}}{d \boldsymbol{\phi}}} + \underbrace{\nabla_{\boldsymbol{\phi}} T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}))}_{\frac{\partial T}{\partial \boldsymbol{\phi}}} = \boldsymbol{0}. \qquad ( T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}) = \boldsymbol{0} = \text{const}) + +Computing :math:`\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi})` therefore boils down to the resolution of the linear system of equations + +.. math:: + + \underbrace{\nabla_{\boldsymbol{\theta}^{\prime}} T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi}))}_{A \in \mathbb{R}^{d \times d}} \underbrace{\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})}_{J \in \mathbb{R}^{d \times n}} = \underbrace{- \nabla_{\boldsymbol{\phi}} T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}))}_{B \in \mathbb{R}^{d \times n}}. + +TorchOpt provides a decorator function :func:`custom_root`, for easily adding implicit differentiation on top of any existing inner optimization solver (also called forward optimization). +The :func:`custom_root` decorator requires users to define the stationary conditions for the problem solution (e.g., `KKT conditions `_) and will automatically calculate the gradient for backward gradient computation. + +Here is an example of the :func:`custom_root` decorators, which is also the **functional API** for implicit gradient. + +.. code-block:: python + + # Functional API for implicit gradient + def stationary(params, meta_params, data): + # stationary condition construction + return stationary condition + + # Decorator that wraps the function + # Optionally specify the linear solver (conjugate gradient or Neumann series) + @torchopt.diff.implicit.custom_root(stationary) + def solve(params, meta_params, data): + # Forward optimization process for params + return optimal_params + + # Define params, meta_params and get data + params, meta_prams, data = ..., ..., ... + optimal_params = solve(params, meta_params, data) + loss = outer_loss(optimal_params) + + meta_grads = torch.autograd.grad(loss, meta_params) + +OOP API +~~~~~~~ + +.. autosummary:: + + torchopt.nn.ImplicitMetaGradientModule + +Coupled with PyTorch |torch.nn.Module|_, we also design the OOP API :class:`nn.ImplicitMetaGradientModule` for implicit gradient. +The core idea of :class:`nn.ImplicitMetaGradientModule` is to enable the gradient flow from ``self.parameters()`` (usually lower-level parameters) to ``self.meta_parameters()`` (usually the high-level parameters). +Users need to define the forward process ``forward()``, a stationary function ``optimality()`` (or ``objective()``), and inner-loop optimization ``solve``. + +.. |torch.nn.Module| replace:: ``torch.nn.Module`` +.. _torch.nn.Module: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module + +Here is an example of the OOP API. + +.. code-block:: python + + from torchopt.nn import ImplicitMetaGradientModule + + # Inherited from the class ImplicitMetaGradientModule + class InnerNet(ImplicitMetaGradientModule): + def __init__(self, meta_module): + ... + + def forward(self, batch): + # Forward process + ... + + def optimality(self, batch, labels): + # Stationary condition construction for calculating implicit gradient + # NOTE: If this method is not implemented, it will be automatically derived from the + # gradient of the `objective` function. + ... + + def objective(self, batch, labels): + # Define the inner-loop optimization objective + # NOTE: This method is optional if method `optimality` is implemented. + ... + + def solve(self, batch, labels): + # Conduct the inner-loop optimization + ... + return self # optimized module + + # Get meta_params and data + meta_params, data = ..., ... + inner_net = InnerNet() + + # Solve for inner-loop process related to the meta-parameters + optimal_inner_net = inner_net.solve(meta_params, *data) + + # Get outer-loss and solve for meta-gradient + loss = outer_loss(optimal_inner_net) + meta_grad = torch.autograd.grad(loss, meta_params) + +If the optimization objective is to minimize/maximize an objective function, we offer an ``objective`` method interface to simplify the implementation. +Users only need to define the ``objective`` method, while TorchOpt will automatically analyze it for the stationary (optimality) condition from the KKT condition. + +.. note:: + + In ``__init__`` method, users need to define the inner parameters and meta-parameters. + By default, :class:`nn.ImplicitMetaGradientModule` treats all tensors and modules from the method inputs as ``self.meta_parameters()`` / ``self.meta_modules()``. + For example, statement ``self.yyy = xxx`` will assign ``xxx`` as a meta-parameter with name ``'yyy'`` if ``xxx`` is present in the method inputs (e.g., ``def __init__(self, xxx, ...): ...``). + All tensors and modules defined in the ``__init__`` are regarded as ``self.parameters()`` / ``self.modules()``. + Users can also register parameters and meta-parameters by calling ``self.register_parameter()`` and ``self.register_meta_parameter()`` respectively. + +Linear System Solvers +--------------------- + +.. autosummary:: + + torchopt.linear_solve.solve_cg + torchopt.linear_solve.solve_inv + torchopt.linear_solve.solve_normal_cg + +Usually, the computation of implicit gradient involves the computation of the inverse Hessian matrix. +However, the high-dimensional Hessian matrix also makes direct computation intractable, and this is where linear solver comes into play. +By iteratively solving the linear system problem, we can calculate the inverse Hessian matrix up to some precision. We offer the `conjugate-gradient `_ based solver and `neuman-series `_ based solver. + +Here is an example of the linear solver. + +.. code-block:: python + + import torch + from torchopt import linear_solve + + torch.manual_seed(42) + A = torch.randn(3, 3) + b = torch.randn(3) + + def matvec(x): + return torch.matmul(A, x) + + solve_fn = linear_solve.solve_normal_cg(atol=1e-5) + solution = solve_fn(matvec, b) + print(solution) + + solve_fn = linear_solve.solve_cg(atol=1e-5) + solution = solve_fn(matvec, b) + print(solution) + +Users can also select the corresponding solver in functional and OOP APIs. + +.. code-block:: python + + # For functional API + @torchopt.diff.implicit.custom_root( + functorch.grad(objective_fn, argnums=0), # optimality function + argnums=1, + solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0), + ) + def solve_fn(...): + ... + + # For OOP API + class InnerNet( + torchopt.nn.ImplicitMetaGradientModule, + linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0), + ): + ... + +Notebook Tutorial +----------------- + +Check the notebook tutorial at `Implicit Differentiation `_. diff --git a/docs/source/index.rst b/docs/source/index.rst index fd488b6e..83602090 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -3,20 +3,23 @@ TorchOpt -------- -**TorchOpt** is a high-performance optimizer library built upon `PyTorch `_ for easy implementation of functional optimization and gradient-based meta-learning. It consists of two main features: +**TorchOpt** is an efficient library for differentiable optimization built upon `PyTorch `_. +Torchopt is -* TorchOpt provides functional optimizer which enables `JAX-like `_ composable functional optimizer for PyTorch. With TorchOpt, one can easily conduct neural network optimization in PyTorch with functional style optimizer, similar to `Optax `_ in JAX. -* With the design of functional programming, TorchOpt provides efficient, flexible, and easy-to-implement differentiable optimizer for gradient-based meta-learning research. It largely reduces the efforts required to implement sophisticated meta-learning algorithms. +- **Comprehensive**: TorchOpt provides three differentiation modes - explicit differentiation, implicit differentiation, and zero-order differentiation for handling different differentiable optimization situations. +- **Flexible**: TorchOpt provides both functional and objective-oriented API for users different preferences. Users can implement differentiable optimization in JAX-like or PyTorch-like style. +- **Efficient**: TorchOpt provides (1) CPU/GPU acceleration differentiable optimizer (2) RPC-based distributed training framework (3) Fast Tree Operations, to largely increase the training efficiency for bi-level optimization problems. Installation ------------ Requirements: -* `PyTorch `_ -* (Optional) `Graphviz `_ +- `PyTorch `_ +- (Optional) `Graphviz `_ -Please follow the instructions at https://pytorch.org to install PyTorch in your Python environment first. Then run the following command to install TorchOpt from PyPI: +Please follow the instructions at https://pytorch.org to install PyTorch in your Python environment first. +Then run the following command to install TorchOpt from PyPI: .. code-block:: bash @@ -30,7 +33,8 @@ You can also build shared libraries from source, use: cd torchopt pip3 install . -We provide a `conda `_ environment recipe to install the build toolchain such as `cmake`, `g++`, and `nvcc`: +We provide a `conda `_ environment recipe to install the build toolchain such as ``cmake``, ``g++``, and ``nvcc``. +You can use the following commands with `conda `_ / `mamba `_ to create a new isolated environment. .. code-block:: bash @@ -38,49 +42,58 @@ We provide a `conda `_ environment recipe to ins cd torchopt # You may need `CONDA_OVERRIDE_CUDA` if conda fails to detect the NVIDIA driver (e.g. in docker or WSL2) - CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe.yaml + CONDA_OVERRIDE_CUDA=12.1 conda env create --file conda-recipe-minimal.yaml conda activate torchopt - .. toctree:: - :caption: Getting Started - :maxdepth: 1 - - torchopt101/torchopt-101.rst + :maxdepth: 1 + :caption: Documentation + basics/basics.rst + optimizer/optim.rst + explicit_diff/explicit_diff.rst + implicit_diff/implicit_diff.rst + zero_order_diff/zero_order_diff.rst + distributed/distributed.rst + visualization/visualization.rst .. toctree:: - :caption: Examples - :maxdepth: 1 + :caption: Tutorial Notebooks + :maxdepth: 1 + + torchopt101/torchopt-101.rst - examples/MAML.rst +.. toctree:: + :caption: Examples + :maxdepth: 1 + examples/MAML.rst .. toctree:: - :caption: Developer Documentation - :maxdepth: 1 + :caption: Developer Documentation + :maxdepth: 1 - developer/contributing.rst - developer/contributor.rst + developer/contributing.rst + developer/contributor.rst .. toctree:: - :caption: API Documentation - :maxdepth: 2 + :caption: API Documentation + :maxdepth: 2 - api/api.rst + api/api.rst The Team -------- TorchOpt is a work by -* Jie Ren (`JieRen98 `_) -* Xidong Feng (`waterhorse1 `_) -* Bo Liu (`Benjamin-eecs `_) -* Xuehai Pan (`XuehaiPan `_) -* Luo Mai (`luomai `_) -* Yaodong Yang (`PKU-YYang `_). +- Jie Ren (`JieRen98 `_) +- Xidong Feng (`waterhorse1 `_) +- Bo Liu (`Benjamin-eecs `_) +- Xuehai Pan (`XuehaiPan `_) +- Luo Mai (`luomai `_) +- Yaodong Yang (`PKU-YYang `_). Support ------- @@ -97,3 +110,27 @@ License ------- TorchOpt is licensed under the Apache 2.0 License. + +Citing +------ + +If you find TorchOpt useful, please cite it in your publications. + +.. code-block:: bibtex + + @article{JMLR:TorchOpt, + author = {Jie Ren* and Xidong Feng* and Bo Liu* and Xuehai Pan* and Yao Fu and Luo Mai and Yaodong Yang}, + title = {TorchOpt: An Efficient Library for Differentiable Optimization}, + journal = {Journal of Machine Learning Research}, + year = {2023}, + volume = {24}, + number = {367}, + pages = {1--14}, + url = {http://jmlr.org/papers/v24/23-0191.html} + } + + +Indices and tables +------------------ + +- :ref:`genindex` diff --git a/docs/source/optimizer/optim.rst b/docs/source/optimizer/optim.rst new file mode 100644 index 00000000..4f2e17f8 --- /dev/null +++ b/docs/source/optimizer/optim.rst @@ -0,0 +1,205 @@ +Optimizers +========== + +.. currentmodule:: torchopt + +The core design of TorchOpt follows the philosophy of functional programming. +Aligned with |functorch|_, users can conduct functional-style programming with models, optimizers, and training in PyTorch. +We first introduce our functional optimizers, which treat the optimization process as a functional transformation. + +.. |functorch| replace:: ``functorch`` +.. _functorch: https://pytorch.org/functorch + +Functional Optimizers +--------------------- + +Currently, TorchOpt supports 4 functional optimizers: :func:`sgd`, :func:`adam`, :func:`rmsprop`, and :func:`adamw`. + +.. autosummary:: + + torchopt.FuncOptimizer + torchopt.adadelta + torchopt.adagrad + torchopt.adam + torchopt.adamw + torchopt.adamax + torchopt.radam + torchopt.rmsprop + torchopt.sgd + +Apply Parameter Updates +----------------------- + +TorchOpt offers functional API by passing gradients and optimizer states to the optimizer function to apply updates. + +.. autosummary:: + + torchopt.apply_updates + +Here is an example of functional optimization coupled with |functorch|_: + +.. code-block:: python + + class Net(nn.Module): ... + + class Loader(DataLoader): ... + + net = Net() # init + loader = Loader() + optimizer = torchopt.adam(lr) + + model, params = functorch.make_functional(net) # use functorch extract network parameters + opt_state = optimizer.init(params) # init optimizer + + xs, ys = next(loader) # get data + pred = model(params, xs) # forward + loss = F.cross_entropy(pred, ys) # compute loss + + grads = torch.autograd.grad(loss, params) # compute gradients + updates, opt_state = optimizer.update(grads, opt_state) # get updates + params = torchopt.apply_updates(params, updates) # update network parameters + +We also provide a wrapper :class:`torchopt.FuncOptimizer` to make maintaining the optimizer state easier: + +.. code-block:: python + + net = Net() # init + loader = Loader() + optimizer = torchopt.FuncOptimizer(torchopt.adam()) # wrap with `torchopt.FuncOptimizer` + + model, params = functorch.make_functional(net) # use functorch extract network parameters + + for xs, ys in loader: # get data + pred = model(params, xs) # forward + loss = F.cross_entropy(pred, ys) # compute loss + + params = optimizer.step(loss, params) # update network parameters + +Classic OOP Optimizers +---------------------- + +Combined with the functional optimizers above, we can define our classic OOP optimizers. +We designed a base class :class:`torchopt.Optimizer` that has the same interface as |torch.optim.Optimizer|_. +We offer original PyTorch APIs (e.g., ``zero_grad()`` or ``step()``) for traditional PyTorch-like (OOP) parameter update. + +.. |torch.optim.Optimizer| replace:: ``torch.optim.Optimizer`` +.. _torch.optim.Optimizer: https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer + +.. autosummary:: + + torchopt.Optimizer + torchopt.AdaDelta + torchopt.Adadelta + torchopt.AdaGrad + torchopt.Adagrad + torchopt.Adam + torchopt.AdamW + torchopt.AdaMax + torchopt.Adamax + torchopt.RAdam + torchopt.RMSProp + torchopt.SGD + + +By combining low-level API :class:`torchopt.Optimizer` with the previous functional optimizer, we can achieve high-level API: + +.. code-block:: python + + learning_rate = 1.0 + # High-level API + optim = torchopt.Adam(net.parameters(), lr=learning_rate) + # which can be achieved by low-level API: + optim = torchopt.Optimizer(net.parameters(), torchopt.adam(lr=learning_rate)) + +Here is an example of PyTorch-like APIs: + +.. code-block:: python + + net = Net() # init + loader = Loader() + optimizer = torchopt.Adam(net.parameters()) + + xs, ys = next(loader) # get data + pred = net(xs) # forward + loss = F.cross_entropy(pred, ys) # compute loss + + optimizer.zero_grad() # zero gradients + loss.backward() # backward + optimizer.step() # step updates + +Combining Transformation +------------------------ + +Users always need to conduct multiple gradient transformations (functions) before the final update. +In the designing of TorchOpt, we treat these functions as derivations of :func:`torchopt.chain`. +So we can build our own chain like ``torchopt.chain(torchopt.clip_grad_norm(max_norm=1.), torchopt.sgd(lr=1., moment_requires_grad=True))`` to clip the gradient and update parameters using :func:`sgd`. + +.. autosummary:: + + torchopt.chain + +.. note:: + + :func:`torchopt.chain` will sequentially conduct transformations, so the order matters. + For example, we need to first conduct gradient normalization and then conduct the optimizer step. + The order should be (clip, sgd) in :func:`torchopt.chain` function. + + +Here is an example of chaining :func:`torchopt.clip_grad_norm` and :func:`torchopt.adam` for functional optimizer and OOP optimizer. + +.. code-block:: python + + func_optimizer = torchopt.chain(torchopt.clip_grad_norm(max_norm=2.0), torchopt.adam(1e-1)) + oop_optimizer = torchopt.Optimizer(net.parameters() func_optimizer) + +Optimizer Hooks +--------------- + +Users can also add optimizer hook to control the gradient flow. + +.. autosummary:: + + torchopt.hook.register_hook + torchopt.hook.zero_nan_hook + torchopt.hook.nan_to_num_hook + +For example, :func:`torchopt.hook.zero_nan_hook` registers hook to the first-order gradients. +During the backpropagation, the **NaN** gradients will be set to 0. +Here is an example of such operation coupled with :func:`torchopt.chain`. + +.. code-block:: python + + impl = torchopt.chain(torchopt.hook.register_hook(torchopt.hook.zero_nan_hook), torchopt.adam(1e-1)) + +Optimizer Schedules +------------------- + +TorchOpt also provides implementations of learning rate schedulers, which can be used to control the learning rate during the training process. +TorchOpt mainly offers the linear learning rate scheduler and the polynomial learning rate scheduler. + +.. autosummary:: + + torchopt.schedule.linear_schedule + torchopt.schedule.polynomial_schedule + +Here is an example of combining optimizer with learning rate scheduler. + +.. code-block:: python + + functional_adam = torchopt.adam( + lr=torchopt.schedule.linear_schedule( + init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000 + ) + ) + + adam = torchopt.Adam( + net.parameters(), + lr=torchopt.schedule.linear_schedule( + init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000 + ), + ) + +Notebook Tutorial +----------------- + +Check the notebook tutorial at `Functional Optimizer `_. diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt index ca34dd05..6e0cca78 100644 --- a/docs/source/spelling_wordlist.txt +++ b/docs/source/spelling_wordlist.txt @@ -26,7 +26,7 @@ Pan Yao Fu Jupyter -Colaboratory +Colab Omniglot differentiable Dataset @@ -56,10 +56,14 @@ iterable nan param Graphviz +Autograd autograd attrs +GradientTransformation GradientTransformations args +kwargs +kwds chainable adam Adam @@ -76,5 +80,105 @@ Moens AdamW Loshchilov pytree +pytrees booleans subtrees +optimality +argnums +matvec +hermitian +deepcopy +deepclone +RRef +rref +ints +Karush +Kuhn +Tucker +Neumann +KKT +num +posinf +neginf +backpropagated +backpropagating +backpropagation +backprop +fmt +pragma +broadcasted +keepdim +ndim +partitioner +partitioners +RPC +rpc +MPI +async +parallelization +unaggregated +maxiter +str +bool +algo +const +attr +sys +recurse +boldsymbol +optim +optimizer's +stateful +preload +submodules +prepend +jit +compilable +RMS +LLC +ns +th +treespec +namespace +atol +rtol +pre +numerics +parallelize +parallelizing +JAX +Optax +func +subfn +vjp +jvp +ATen +samplable +conj +TransformInitFn +TransformUpdateFn +argmin +Jacobian +autodiff +backend +reparametrize +reparameterize +rtype +backpropagate +NaN +iteratively +issubclass +abc +ABCMeta +subclasscheck +ctx +Duchi +invertible +AdaGrad +Adadelta +Zeiler +radam +adamax +RAdam +AdaDelta +AdaMax diff --git a/docs/source/torchopt101/torchopt-101.rst b/docs/source/torchopt101/torchopt-101.rst index 87bffd4c..89809691 100644 --- a/docs/source/torchopt101/torchopt-101.rst +++ b/docs/source/torchopt101/torchopt-101.rst @@ -1,9 +1,11 @@ Get Started with Jupyter Notebook ================================= -In this tutorial, we will use Google Colaboratory to show you the most basic usages of TorchOpt. +In this tutorial, we will use Google Colab notebooks to show you the most basic usages of TorchOpt. -- 1: `Functional Optimizer `_ -- 2: `Visualization `_ -- 3: `Meta Optimizer `_ -- 4: `Stop Gradient `_ +- 1: `Functional Optimizer `_ +- 2: `Visualization `_ +- 3: `Meta-Optimizer `_ +- 4: `Stop Gradient `_ +- 5: `Implicit Differentiation `_ +- 6: `Zero-order Differentiation `_ diff --git a/docs/source/visualization/visualization.rst b/docs/source/visualization/visualization.rst new file mode 100644 index 00000000..718c6725 --- /dev/null +++ b/docs/source/visualization/visualization.rst @@ -0,0 +1,146 @@ +Visualization +============= + +.. currentmodule:: torchopt.visual + +In `PyTorch `_, if the attribute ``requires_grad`` of a tensor is :data:`True`, the computation graph will be created if we use the tensor to do any operations. +The computation graph is implemented like a link list -- ``Tensors`` are nodes and they are linked by their attribute ``gran_fn``. +`PyTorchViz `_ is a Python package that uses `Graphviz `_ as a backend for plotting computation graphs. +TorchOpt uses PyTorchViz as the blueprint and provides more easy-to-use visualization functions on the premise of supporting all its functions. + +------ + +Usage +----- + +Let's start with a simple multiplication computation graph. +We declared the variable ``x`` with the flag ``requires_grad=True`` and compute ``y = 2 * x``. Then we visualize the computation graph of ``y``. + +We provide the function :func:`make_dot` which takes a tensor as input. +The visualization code is shown as follows: + +.. code-block:: python + + from IPython.display import display + import torch + import torchopt + + + x = torch.tensor(1.0, requires_grad=True) + y = 2 * x + display(torchopt.visual.make_dot(y)) + +.. image:: /_static/images/visualization-fig1.svg + :width: 20% + :align: center + +The figure shows ``y`` is connected by the multiplication edge. +The gradient of ``y`` will flow through the multiplication backward function and then accumulate on ``x``. +Note that we pass a dictionary for adding node labels. + +To add auxiliary notes to the computation graph, we can pass a dictionary as argument ``params`` to :func:`make_dot`. +The keys are the notes which would be shown in the computation figure and the values are the tensors that need to be noted. +So the code above can be modified as follows: + +.. code-block:: python + + from IPython.display import display + import torch + import torchopt + + + x = torch.tensor(1.0, requires_grad=True) + y = 2 * x + display(torchopt.visual.make_dot(y, params={'x': x, 'y': y})) + +Then let's plot a neural network. +Note that we can pass the generator returned by the method ``named_parameters`` for adding node labels. + +.. code-block:: python + + from IPython.display import display + import torch + from torch import nn + import torchopt + + + class Net(nn.Module): + def __init__(self, dim): + super().__init__() + self.fc = nn.Linear(dim, 1, bias=True) + + def forward(self, x): + return self.fc(x) + + + dim = 5 + batch_size = 2 + net = Net(dim) + xs = torch.ones((batch_size, dim)) + ys = torch.ones((batch_size, 1)) + pred = net(xs) + loss = F.mse_loss(pred, ys) + + display(torchopt.visual.make_dot(loss, params=(net.named_parameters(), {'loss': loss}))) + +.. image:: /_static/images/visualization-fig2.svg + :width: 45% + :align: center + +The computation graph of meta-learning algorithms will be much more complex. +Our visualization tool allows users to take as input the extracted network state for better visualization. + +.. code-block:: python + + from IPython.display import display + import torch + from torch import nn + import torchopt + + class MetaNet(nn.Module): + def __init__(self, dim): + super().__init__() + self.fc = nn.Linear(dim, 1, bias=True) + + def forward(self, x, meta_param): + return self.fc(x) + meta_param + + + dim = 5 + batch_size = 2 + net = MetaNet(dim) + + xs = torch.ones((batch_size, dim)) + ys = torch.ones((batch_size, 1)) + + optimizer = torchopt.MetaSGD(net, lr=1e-3) + meta_param = torch.tensor(1.0, requires_grad=True) + + # Set enable_visual + net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.') + + pred = net(xs, meta_param) + loss = F.mse_loss(pred, ys) + optimizer.step(loss) + + # Set enable_visual + net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.') + + pred = net(xs, meta_param) + loss = F.mse_loss(pred, torch.ones_like(pred)) + + # Draw computation graph + display( + torchopt.visual.make_dot( + loss, [net_state_0, net_state_1, {'meta_param': meta_param, 'loss': loss}] + ) + ) + +.. image:: /_static/images/visualization-fig3.svg + :width: 65% + :align: center + +Notebook Tutorial +----------------- + +Check the notebook tutorial at `Visualization `_. diff --git a/docs/source/zero_order_diff/zero_order_diff.rst b/docs/source/zero_order_diff/zero_order_diff.rst new file mode 100644 index 00000000..4cc7a034 --- /dev/null +++ b/docs/source/zero_order_diff/zero_order_diff.rst @@ -0,0 +1,146 @@ +Zero-order Gradient Differentiation +=================================== + +.. currentmodule:: torchopt.diff.zero_order + +Evolutionary Strategy +--------------------- + +.. image:: /_static/images/zero-order.png + :width: 80% + :align: center + +When the inner-loop process is non-differentiable or one wants to eliminate the heavy computation burdens in the previous two modes (brought by Hessian), one can choose Zeroth-order differentiation. +Zero-order differentiation typically gets gradients based on zero-order estimation, such as finite-difference, or `Evolutionary Strategy `_ (ES). +`ES-MAML `_ and `NAC `_ successfully solve the non-differentiable optimization problem based on ES. + +TorchOpt offers API for ES-based differentiation. +Instead of optimizing the objective :math:`f (\boldsymbol{\theta}): \mathbb{R}^n \to \mathbb{R}`, ES optimizes a Gaussian smoothing objective defined as :math:`\tilde{f}_{\sigma} (\boldsymbol{\theta}) = \mathbb{E}_{\boldsymbol{z} \sim \mathcal{N}( 0, {I}_d )} [ f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) ]`, where :math:`\sigma` denotes the precision. +The gradient of such objective is :math:`\nabla_{\boldsymbol{\theta}} \tilde{f}_{\sigma} (\boldsymbol{\theta}) = \frac{1}{\sigma} \mathbb{E}_{\boldsymbol{z} \sim \mathcal{N}( 0, {I}_d )} [ f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) \cdot \boldsymbol{z} ]`. +Based on such technique, one can treat the bi-level process as a whole to calculate the meta-gradient based on pure forward process. +Refer to `ES-MAML `_ for more explanations. + +Decorators +---------- + +.. autosummary:: + + torchopt.diff.zero_order.zero_order + +Similar to the implicit gradient, we also use the decorator for ES methods. + +Functional API +~~~~~~~~~~~~~~ + +The basic functional API is :func:`torchopt.diff.zero_order.zero_order`, which is used as the decorator for the forward process zero-order gradient procedures. +Users are required to implement the noise sampling function, which will be used as the input of the zero_order decorator. +Here we show the specific meaning for each parameter used in the decorator. + +- ``distribution`` for noise sampling distribution. The distribution :math:`\lambda` should be spherical symmetric and with a constant variance of :math:`1` for each element. I.e.: + + - Spherical symmetric: :math:`\mathbb{E}_{\boldsymbol{z} \sim \lambda} [ \boldsymbol{z} ] = \boldsymbol{0}`. + - Constant variance of :math:`1` for each element: :math:`\mathbb{E}_{\boldsymbol{z} \sim \lambda} [ {\lvert z_i \rvert}^2 ] = 1`. + - For example, the standard multi-dimensional normal distribution :math:`\mathcal{N} (\boldsymbol{0}, \boldsymbol{1})`. + +- ``method`` for different kind of algorithms, we support ``'naive'`` (`ES RL `_), ``'forward'`` (`Forward-FD `_), and ``'antithetic'`` (`antithetic `_). + + .. math:: + + \begin{align*} + \text{naive} \qquad & \nabla_{\boldsymbol{\theta}} \tilde{f}_{\sigma} (\boldsymbol{\theta}) = \frac{1}{\sigma} \mathbb{E}_{\boldsymbol{z} \sim \lambda} [ f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) \cdot \boldsymbol{z} ] \\ + \text{forward} \qquad & \nabla_{\boldsymbol{\theta}} \tilde{f}_{\sigma} (\boldsymbol{\theta}) = \frac{1}{\sigma} \mathbb{E}_{\boldsymbol{z} \sim \lambda} [ ( f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) - f (\boldsymbol{\theta}) ) \cdot \boldsymbol{z} ] \\ + \text{antithetic} \qquad & \nabla_{\boldsymbol{\theta}} \tilde{f}_{\sigma} (\boldsymbol{\theta}) = \frac{1}{2 \sigma} \mathbb{E}_{\boldsymbol{z} \sim \lambda} [ (f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) - f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) ) \cdot \boldsymbol{z} ] + \end{align*} + +- ``argnums`` specifies which parameter we want to trace the meta-gradient. +- ``num_samples`` specifies how many times we want to conduct the sampling. +- ``sigma`` is for precision. This is the scaling factor for the sampling distribution. + +We show the pseudo code in the following part. + +.. code-block:: python + + # Functional API for zero-order differentiation + # 1. Customize the noise distribution via a distribution class + class Distribution: + def sample(self, sample_shape=torch.Size()): + # Sampling function for noise + # NOTE: The distribution should be spherical symmetric and with a constant variance of 1. + ... + return noise_batch + + distribution = Distribution() + + # 2. Customize the noise distribution via a sampling function + def distribution(sample_shape=torch.Size()): + # Sampling function for noise + # NOTE: The distribution should be spherical symmetric and with a constant variance of 1. + ... + return noise_batch + + # 3. Distribution can also be an instance of `torch.distributions.Distribution`, e.g., `torch.distributions.Normal(...)` + distribution = torch.distributions.Normal(loc=0, scale=1) + + # Decorator that wraps the function + @torchopt.diff.zero_order(distribution=distribution, method='naive', argnums=0, num_samples=100, sigma=0.01) + def forward(params, data): + # Forward optimization process for params + ... + return objective # the returned tensor should be a scalar tensor + + # Define params and get data + params, data = ..., ... + + # Forward pass + loss = forward(params, data) + # Backward pass using zero-order differentiation + grads = torch.autograd.grad(loss, params) + +OOP API +~~~~~~~ + +.. autosummary:: + + torchopt.nn.ZeroOrderGradientModule + +Coupled with PyTorch |torch.nn.Module|_, we also design the OOP API :class:`nn.ZeroOrderGradientModule` for ES. +The core idea of :class:`nn.ZeroOrderGradientModule` is to enable the gradient flow forward process to `self.parameters()` (can be the meta-parameters when calculating meta-gradient). +Users need to define the forward process zero-order gradient procedures ``forward()`` and a noise sampling function ``sample()``. + +.. |torch.nn.Module| replace:: ``torch.nn.Module`` +.. _torch.nn.Module: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module + +.. code-block:: python + + from torchopt.nn import ZeroOrderGradientModule + + # Inherited from the class ZeroOrderGradientModule + # Optionally specify the `method` and/or `num_samples` and/or `sigma` used for sampling + class Net(ZeroOrderGradientModule, method='naive', num_samples=100, sigma=0.01): + def __init__(self, ...): + ... + + def forward(self, batch): + # Forward process + ... + return objective # the returned tensor should be a scalar tensor + + def sample(self, sample_shape=torch.Size()): + # Generate a batch of noise samples + # NOTE: The distribution should be spherical symmetric and with a constant variance of 1. + ... + return noise_batch + + # Get model and data + net = Net(...) + data = ... + + # Forward pass + loss = Net(data) + # Backward pass using zero-order differentiation + grads = torch.autograd.grad(loss, net.parameters()) + +Notebook Tutorial +----------------- + +For more details, check the notebook tutorial at `zero-order `_. diff --git a/examples/FuncTorch/maml_omniglot_vmap.py b/examples/FuncTorch/maml_omniglot_vmap.py index 9bbb30ce..2f42e050 100644 --- a/examples/FuncTorch/maml_omniglot_vmap.py +++ b/examples/FuncTorch/maml_omniglot_vmap.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -39,16 +39,10 @@ https://github.com/bamos/HowToTrainYourMAMLPytorch """ - -import os -import sys - - -cur = os.path.abspath(os.path.dirname(__file__)) -root = os.path.split(cur)[0] -sys.path.append(root + '/few-shot') import argparse import functools +import pathlib +import sys import time import functorch @@ -59,12 +53,17 @@ import torch import torch.nn.functional as F import torch.optim as optim -from support.omniglot_loaders import OmniglotNShot from torch import nn import torchopt +CWD = pathlib(__file__).absolute().parent +sys.path.append(str(CWD.parent / 'few-shot')) + +from helpers.omniglot_loaders import OmniglotNShot + + mpl.use('Agg') plt.style.use('bmh') @@ -80,7 +79,10 @@ def main(): argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15) argparser.add_argument('--device', type=str, help='device', default='cuda') argparser.add_argument( - '--task_num', type=int, help='meta batch size, namely task num', default=32 + '--task_num', + type=int, + help='meta batch size, namely task num', + default=32, ) argparser.add_argument('--seed', type=int, help='random seed', default=1) args = argparser.parse_args() @@ -148,8 +150,6 @@ def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry): opt = torchopt.sgd(lr=1e-1) opt_state = opt.init(params) - querysz = x_qry.size(0) - def compute_loss(new_params, buffers, x, y): logits = fnet(new_params, buffers, x) loss = F.cross_entropy(logits, y) @@ -167,7 +167,7 @@ def compute_loss(new_params, buffers, x, y): # These will be used to update the model's meta-parameters. qry_logits = fnet(new_params, buffers, x_qry) qry_loss = F.cross_entropy(qry_logits, y_qry) - qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum() / querysz + qry_acc = (qry_logits.argmax(dim=1) == y_qry).float().mean() return qry_loss, qry_acc @@ -192,18 +192,18 @@ def train(db, net, device, meta_opt, epoch, log): qry_losses, qry_accs = functorch.vmap(compute_loss_for_task)(x_spt, y_spt, x_qry, y_qry) # Compute the maml loss by summing together the returned losses. - qry_losses.sum().backward() - + qry_losses = torch.mean(torch.stack(qry_losses)) + qry_losses.backward() meta_opt.step() - qry_losses = qry_losses.detach().sum() / task_num - qry_accs = 100.0 * qry_accs.sum() / task_num + qry_losses = qry_losses.item() + qry_accs = 100.0 * torch.mean(torch.stack(qry_accs)).item() i = epoch + float(batch_idx) / n_train_iter iter_time = time.time() - start_time + if batch_idx % 4 == 0: print( - f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}' + f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}', ) - log.append( { 'epoch': i, @@ -211,7 +211,7 @@ def train(db, net, device, meta_opt, epoch, log): 'acc': qry_accs, 'mode': 'train', 'time': time.time(), - } + }, ) @@ -227,7 +227,7 @@ def test(db, net, device, epoch, log): qry_losses = [] qry_accs = [] - for batch_idx in range(n_test_iter): + for _ in range(n_test_iter): x_spt, y_spt, x_qry, y_qry = db.next('test') task_num, setsz, c_, h, w = x_spt.size() @@ -249,8 +249,9 @@ def test(db, net, device, epoch, log): qry_losses.append(qry_loss.detach()) qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach()) - qry_losses = torch.cat(qry_losses).mean().item() - qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item() + qry_losses = torch.mean(torch.stack(qry_losses)).item() + qry_accs = 100.0 * torch.mean(torch.stack(qry_accs)).item() + print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}') log.append( { @@ -259,7 +260,7 @@ def test(db, net, device, epoch, log): 'acc': qry_accs, 'mode': 'test', 'time': time.time(), - } + }, ) diff --git a/examples/FuncTorch/parallel_train_torchopt.py b/examples/FuncTorch/parallel_train_torchopt.py index 640763cb..523515e1 100644 --- a/examples/FuncTorch/parallel_train_torchopt.py +++ b/examples/FuncTorch/parallel_train_torchopt.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,8 +15,6 @@ import argparse import math -from collections import namedtuple -from typing import Any, NamedTuple import functorch import torch @@ -137,7 +135,9 @@ def test_parallel_train_step_fn(self, num_models): weights, opt_state = parallel_init_fn(torch.ones(num_models, 1)) for i in range(2000): loss, (weights, opt_states) = parallel_train_step_fn( - (weights, opt_state), points, labels + (weights, opt_state), + points, + labels, ) if i % 200 == 0: print(loss) @@ -188,7 +188,9 @@ def test_parallel_train_step_fn(self, num_models): optimizer = torchopt.adam(lr=0.2) opt_state = optimizer.init(weights) functorch_original = ParallelTrainFunctorchTorchOpt( - loss_fn=loss_fn, optimizer=optimizer, device=DEVICE + loss_fn=loss_fn, + optimizer=optimizer, + device=DEVICE, ) # Step 4: Let's verify this actually trains. # We should see the loss decrease. @@ -201,7 +203,7 @@ def test_parallel_train_step_fn(self, num_models): # Step 7: Now, the flaw with step 6 is that we were training on the same exact # data. This can lead to all of the models in the ensemble overfitting in the # same way. The solution that http://willwhitney.com/parallel-training-jax.html - # applies is to randomly subset the data in a way that the models do not recieve + # applies is to randomly subset the data in a way that the models do not receive # exactly the same data in each training step! # Because the goal of this doc is to show that we can use eager-mode vmap to # achieve similar things as JAX, the rest of this is left as an exercise to the reader. diff --git a/examples/L2R/helper/argument.py b/examples/L2R/helpers/argument.py similarity index 96% rename from examples/L2R/helper/argument.py rename to examples/L2R/helpers/argument.py index 5df9f314..7db6c982 100644 --- a/examples/L2R/helper/argument.py +++ b/examples/L2R/helpers/argument.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/L2R/helper/model.py b/examples/L2R/helpers/model.py similarity index 95% rename from examples/L2R/helper/model.py rename to examples/L2R/helpers/model.py index 80fae8ac..877ad50a 100644 --- a/examples/L2R/helper/model.py +++ b/examples/L2R/helpers/model.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ class LeNet5(nn.Module): def __init__(self, args): - super(LeNet5, self).__init__() + super().__init__() self.model = nn.Sequential( nn.Conv2d(1, 16, 5), nn.ReLU(), @@ -51,7 +51,7 @@ def __init__(self, args): ) self.args = args self.meta_weights = torch.zeros(self.args.batch_size, requires_grad=True).to( - self.args.device + self.args.device, ) self.criterion = nn.BCELoss() diff --git a/examples/L2R/helper/utils.py b/examples/L2R/helpers/utils.py similarity index 93% rename from examples/L2R/helper/utils.py rename to examples/L2R/helpers/utils.py index 954b27b2..ade64236 100644 --- a/examples/L2R/helper/utils.py +++ b/examples/L2R/helpers/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,7 +33,6 @@ def get_imbalance_dataset( class_0=4, class_1=9, ): - ratio = 1 - pos_ratio ratio_test = 0.5 @@ -90,16 +89,10 @@ def get_imbalance_dataset( y_val_subset = np.concatenate([np.zeros([x_val_0.shape[0]]), np.ones([x_val_1.shape[0]])]) y_test_subset = np.concatenate([np.zeros([x_test_0.shape[0]]), np.ones([x_test_1.shape[0]])]) - y_train_pos_subset = np.ones([x_train_1.shape[0]]) - y_train_neg_subset = np.zeros([x_train_0.shape[0]]) - x_train_subset = np.concatenate([x_train_0, x_train_1], axis=0)[:, None, :, :] x_val_subset = np.concatenate([x_val_0, x_val_1], axis=0)[:, None, :, :] x_test_subset = np.concatenate([x_test_0, x_test_1], axis=0)[:, None, :, :] - x_train_pos_subset = x_train_1[:, None, :, :] - x_train_neg_subset = x_train_0[:, None, :, :] - # Final shuffle. idx = np.arange(x_train_subset.shape[0]) np.random.shuffle(idx) @@ -116,7 +109,7 @@ def get_imbalance_dataset( x_test_subset = x_test_subset[idx].astype(np.float32) y_test_subset = y_test_subset[idx].astype(np.float32) - (x_train_subset, y_train_subset, x_val_subset, y_val_subset, x_test_subset, y_test_subset,) = ( + x_train_subset, y_train_subset, x_val_subset, y_val_subset, x_test_subset, y_test_subset = ( torch.tensor(x_train_subset), torch.tensor(y_train_subset), torch.tensor(x_val_subset), @@ -147,7 +140,7 @@ def set_seed(seed, cudnn=True): torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) # note: the below slows down the code but makes it reproducible - # Sets the seed for generating random numbers on all GPUs. It’s safe to + # Sets the seed for generating random numbers on all GPUs. It's safe to # call this function if CUDA is not available; in that case, it is # silently ignored. torch.cuda.manual_seed_all(seed) @@ -158,7 +151,6 @@ def set_seed(seed, cudnn=True): def plot(baseline, l2r): import matplotlib.pyplot as plt - import numpy as np import seaborn as sns sns.set(style='darkgrid') diff --git a/examples/L2R/l2r.py b/examples/L2R/l2r.py index cd093313..a0ae764b 100644 --- a/examples/L2R/l2r.py +++ b/examples/L2R/l2r.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,12 +36,9 @@ from torchvision.datasets import MNIST import torchopt - - -# isort: off -from helper.argument import parse_args -from helper.model import LeNet5 -from helper.utils import get_imbalance_dataset, plot, set_seed +from helpers.argument import parse_args +from helpers.model import LeNet5 +from helpers.utils import get_imbalance_dataset, plot, set_seed def run_baseline(args, mnist_train, mnist_test): @@ -54,14 +51,13 @@ def run_baseline(args, mnist_train, mnist_test): ntest = args.ntest epoch = args.epoch - folder = './result/baseline/' writer = SummaryWriter('./result/baseline') with open('./result/baseline/config.json', 'w') as f: json.dump(args.__dict__, f) args.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') - train_set, val_set, test_set = get_imbalance_dataset( + train_set, _, test_set = get_imbalance_dataset( mnist_train, mnist_test, pos_ratio=pos_ratio, @@ -70,11 +66,10 @@ def run_baseline(args, mnist_train, mnist_test): ntest=ntest, ) train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4) - valid_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=True, num_workers=1) test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=True, num_workers=1) model = LeNet5(args).to(args.device) - model_optimiser = torch.optim.Adam(model.parameters(), lr=args.lr) + model_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) step = 0 running_train_loss = [] @@ -85,16 +80,16 @@ def run_baseline(args, mnist_train, mnist_test): train_x, train_label = train_x.to(args.device), train_label.to(args.device) outer_loss = model.outer_loss(train_x, train_label) - model_optimiser.zero_grad() + model_optimizer.zero_grad() outer_loss.backward() - model_optimiser.step() + model_optimizer.step() running_train_loss.append(outer_loss.item()) writer.add_scalar('train_loss', outer_loss.item(), step) if step % 10 == 0 and step > 0: running_train_mean = np.mean(np.array(running_train_loss)) - print('EPOCH: {}, BATCH: {}, LOSS: {}'.format(_epoch, idx, running_train_mean)) + print(f'EPOCH: {_epoch}, BATCH: {idx}, LOSS: {running_train_mean}') writer.add_scalar('running_train_loss', running_train_mean, step) running_train_loss = [] @@ -109,7 +104,7 @@ def run_baseline(args, mnist_train, mnist_test): writer.add_scalar('train_acc', train_acc, _epoch) writer.add_scalar('test_acc', test_acc, _epoch) test_acc_result.append(test_acc) - print('EPOCH: {}, TRAIN_ACC: {}, TEST_ACC: {}'.format(_epoch, train_acc, test_acc)) + print(f'EPOCH: {_epoch}, TRAIN_ACC: {train_acc}, TEST_ACC: {test_acc}') return test_acc_result @@ -123,7 +118,6 @@ def run_L2R(args, mnist_train, mnist_test): ntest = args.ntest epoch = args.epoch - folder = './result/l2r/' writer = SummaryWriter('./result/l2r/log') with open('./result/l2r/config.json', 'w') as f: json.dump(args.__dict__, f) @@ -142,11 +136,10 @@ def run_L2R(args, mnist_train, mnist_test): valid_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=True, num_workers=1) test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=True, num_workers=1) model = LeNet5(args).to(args.device) - model_optimiser = torchopt.MetaSGD(model, lr=args.lr) - real_model_optimiser = torch.optim.Adam(model.parameters(), lr=args.lr) + model_optimizer = torchopt.MetaSGD(model, lr=args.lr) + real_model_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) step = 0 - time_bp = 0 running_valid_loss = [] valid = iter(valid_loader) running_train_loss = [] @@ -170,13 +163,13 @@ def run_L2R(args, mnist_train, mnist_test): model.reset_meta(size=train_x.size(0)) net_state_dict = torchopt.extract_state_dict(model) - optim_state_dict = torchopt.extract_state_dict(model_optimiser) + optim_state_dict = torchopt.extract_state_dict(model_optimizer) for _ in range(1): inner_loss = model.inner_loss(train_x, train_label) - model_optimiser.step(inner_loss) + model_optimizer.step(inner_loss) - # caclulate outer_loss, deirve meta-gradient and normalise + # calculate outer_loss, derive meta-gradient and normalize outer_loss = model.outer_loss(valid_x, valid_label) model.meta_weights = -torch.autograd.grad(outer_loss, model.meta_weights)[0] model.meta_weights = torch.nn.ReLU()(model.meta_weights) @@ -186,17 +179,17 @@ def run_L2R(args, mnist_train, mnist_test): running_valid_loss.append(outer_loss.item()) writer.add_scalar('validation_loss', outer_loss.item(), step) - # reset the model and model optimiser + # reset the model and model optimizer torchopt.recover_state_dict(model, net_state_dict) - torchopt.recover_state_dict(model_optimiser, optim_state_dict) + torchopt.recover_state_dict(model_optimizer, optim_state_dict) # reuse inner_adapt to conduct real update based on learned meta weights inner_loss = model.inner_loss(train_x, train_label) for _ in range(1): inner_loss = model.inner_loss(train_x, train_label) - real_model_optimiser.zero_grad() + real_model_optimizer.zero_grad() inner_loss.backward() - real_model_optimiser.step() + real_model_optimizer.step() running_train_loss.append(inner_loss.item()) writer.add_scalar('weighted_train_loss', inner_loss.item(), step) @@ -206,8 +199,11 @@ def run_L2R(args, mnist_train, mnist_test): running_train_mean = np.mean(np.array(running_train_loss)) print( 'EPOCH: {}, BATCH: {}, WEIGHTED_TRAIN_LOSS: {}, VALID_LOSS: {}'.format( - _epoch, idx, running_train_mean, running_valid_mean - ) + _epoch, + idx, + running_train_mean, + running_valid_mean, + ), ) running_valid_loss = [] running_train_loss = [] @@ -225,7 +221,7 @@ def run_L2R(args, mnist_train, mnist_test): writer.add_scalar('train_acc', train_acc, _epoch) writer.add_scalar('test_acc', test_acc, _epoch) test_acc_result.append(test_acc) - print('EPOCH: {}, TRAIN_ACC: {}, TEST_ACC: {}'.format(_epoch, train_acc, test_acc)) + print(f'EPOCH: {_epoch}, TRAIN_ACC: {train_acc}, TEST_ACC: {test_acc}') return test_acc_result diff --git a/examples/LOLA/helper/agent.py b/examples/LOLA/helpers/agent.py similarity index 93% rename from examples/LOLA/helper/agent.py rename to examples/LOLA/helpers/agent.py index 8b30a983..78946ee7 100644 --- a/examples/LOLA/helper/agent.py +++ b/examples/LOLA/helpers/agent.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -30,7 +30,6 @@ def __init__(self, theta): class Agent: def __init__(self, args): - self.args = args # init theta and its optimizer self.theta = nn.Parameter(torch.zeros(5, requires_grad=True)) @@ -44,7 +43,7 @@ def __init__(self, args): def set_virtual(self): self.virtual_theta = theta_model(self.theta) - self.virtual_optimiser = torchopt.MetaSGD(self.virtual_theta, lr=self.args.lr_in) + self.virtual_optimizer = torchopt.MetaSGD(self.virtual_theta, lr=self.args.lr_in) def value_update(self, loss): self.value_optimizer.zero_grad() diff --git a/examples/LOLA/helper/argument.py b/examples/LOLA/helpers/argument.py similarity index 96% rename from examples/LOLA/helper/argument.py rename to examples/LOLA/helpers/argument.py index 39618134..ad53c056 100644 --- a/examples/LOLA/helper/argument.py +++ b/examples/LOLA/helpers/argument.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/LOLA/helper/env.py b/examples/LOLA/helpers/env.py similarity index 95% rename from examples/LOLA/helper/env.py rename to examples/LOLA/helpers/env.py index f1ef6e6f..e1576a7d 100644 --- a/examples/LOLA/helper/env.py +++ b/examples/LOLA/helpers/env.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -54,7 +54,7 @@ def __eq__(self, other): class IPD(gym.Env): """ A two-agent vectorized environment. - Possible actions for each agent are (C)ooperate and (D)efect. + Possible actions for each agent are Cooperate (C) and Defect (D). """ # Possible actions diff --git a/examples/LOLA/helper/utils.py b/examples/LOLA/helpers/utils.py similarity index 96% rename from examples/LOLA/helper/utils.py rename to examples/LOLA/helpers/utils.py index afa9e609..4dd436ec 100644 --- a/examples/LOLA/helper/utils.py +++ b/examples/LOLA/helpers/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,7 +27,7 @@ def step(ipd, theta1, theta2, values1, values2, args): (s1, s2), _ = ipd.reset() score1 = 0 score2 = 0 - for t in range(args.len_rollout): + for _ in range(args.len_rollout): a1, lp1, v1 = act(s1, theta1, values1) a2, lp2, v2 = act(s2, theta2, values2) (s1, s2), (r1, r2), _, _ = ipd.step((a1, a2)) @@ -82,7 +82,7 @@ def dice_objective(self, use_baseline=True): if use_baseline: # variance_reduction: baseline_term = torch.mean( - torch.sum((1 - magic_box(stochastic_nodes)) * discounted_values, dim=1) + torch.sum((1 - magic_box(stochastic_nodes)) * discounted_values, dim=1), ) dice_objective = dice_objective + baseline_term @@ -109,7 +109,7 @@ def sample(ipd, policy, value, args): (s1, s2), _ = ipd.reset() memory_agent1 = Memory(args) memory_agent2 = Memory(args) - for t in range(args.len_rollout): + for _ in range(args.len_rollout): a1, lp1, v1 = act(s1, theta1, value1) a2, lp2, v2 = act(s2, theta2, value2) (s1, s2), (r1, r2), _, _ = ipd.step((a1, a2)) diff --git a/examples/LOLA/lola_dice.py b/examples/LOLA/lola_dice.py index 61d2e22c..485de894 100644 --- a/examples/LOLA/lola_dice.py +++ b/examples/LOLA/lola_dice.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,12 +19,10 @@ import numpy as np import torch - -# isort: off -from helper.agent import Agent -from helper.argument import parse_args -from helper.env import IPD -from helper.utils import sample, step +from helpers.agent import Agent +from helpers.argument import parse_args +from helpers.env import IPD +from helpers.utils import sample, step def main(args): @@ -49,7 +47,7 @@ def main(args): args, ) inner_loss = memory1.dice_objective(use_baseline=args.use_baseline) - agent1.virtual_optimiser.step(inner_loss) + agent1.virtual_optimizer.step(inner_loss) # agent 1 assumes that agent 2 conducts n-step lookahead for _ in range(n_lookaheads): @@ -60,7 +58,7 @@ def main(args): args, ) inner_loss = memory2.dice_objective(use_baseline=args.use_baseline) - agent2.virtual_optimiser.step(inner_loss) + agent2.virtual_optimizer.step(inner_loss) # update agent 1 memory1, memory2 = sample( @@ -98,17 +96,16 @@ def main(args): score = step(ipd, agent1.theta, agent2.theta, agent1.values, agent2.values, args) joint_scores.append(0.5 * (score[0] + score[1])) - # print if update % 10 == 0: p1 = [p.item() for p in torch.sigmoid(agent1.theta)] p2 = [p.item() for p in torch.sigmoid(agent2.theta)] print( 'update', update, - 'score (%.3f,%.3f)' % (score[0], score[1]), + f'score ({score[0]:.3f},{score[1]:.3f})', 'policy (agent1) = {%.3f, %.3f, %.3f, %.3f, %.3f}' % (p1[0], p1[1], p1[2], p1[3], p1[4]), - ' (agent2) = {%.3f, %.3f, %.3f, %.3f, %.3f}' % (p2[0], p2[1], p2[2], p2[3], p2[4]), + f' (agent2) = {{{p2[0]:.3f}, {p2[1]:.3f}, {p2[2]:.3f}, {p2[3]:.3f}, {p2[4]:.3f}}}', ) return joint_scores @@ -116,7 +113,7 @@ def main(args): if __name__ == '__main__': args = parse_args() - joint_score = dict() + joint_score = {} for nla in range(3): args.n_lookaheads = nla joint_score[nla] = main(args) diff --git a/examples/LOLA/visualize.py b/examples/LOLA/visualize.py index 6dc54ddf..7af19b21 100755 --- a/examples/LOLA/visualize.py +++ b/examples/LOLA/visualize.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/MAML-RL/func_maml.py b/examples/MAML-RL/func_maml.py new file mode 100644 index 00000000..475c1b12 --- /dev/null +++ b/examples/MAML-RL/func_maml.py @@ -0,0 +1,201 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import argparse +from typing import NamedTuple + +import functorch +import gym +import numpy as np +import torch +import torch.optim as optim + +import torchopt +from helpers.policy import CategoricalMLPPolicy + + +TASK_NUM = 40 +TRAJ_NUM = 20 +TRAJ_LEN = 10 + +STATE_DIM = 10 +ACTION_DIM = 5 + +GAMMA = 0.99 +LAMBDA = 0.95 + +outer_iters = 500 +inner_iters = 1 + + +class Traj(NamedTuple): + obs: np.ndarray + acs: np.ndarray + next_obs: np.ndarray + rews: np.ndarray + gammas: np.ndarray + + +def sample_traj(env, task, fpolicy, params): + env.reset_task(task) + obs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM, STATE_DIM), dtype=np.float32) + next_obs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM, STATE_DIM), dtype=np.float32) + acs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.int8) + rews_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.float32) + gammas_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.float32) + with torch.no_grad(): + for batch in range(TRAJ_NUM): + ob = env.reset() + for step in range(TRAJ_LEN): + ob_tensor = torch.from_numpy(ob) + pi, _ = fpolicy(params, ob_tensor) + ac_tensor = pi.sample() + ac = ac_tensor.cpu().numpy() + next_ob, rew, done, info = env.step(ac) + + obs_buf[step][batch] = ob + next_obs_buf[step][batch] = next_ob + acs_buf[step][batch] = ac + rews_buf[step][batch] = rew + gammas_buf[step][batch] = done * GAMMA + ob = next_ob + return Traj(obs=obs_buf, acs=acs_buf, next_obs=next_obs_buf, rews=rews_buf, gammas=gammas_buf) + + +def a2c_loss(traj, fpolicy, params, value_coef): + lambdas = np.ones_like(traj.gammas) * LAMBDA + _, next_values = fpolicy(params, torch.from_numpy(traj.next_obs)) + next_values = torch.squeeze(next_values, -1).detach().numpy() + # Work backwards to compute `G_{T-1}`, ..., `G_0`. + returns = [] + g = next_values[-1, :] + for i in reversed(range(next_values.shape[0])): + g = traj.rews[i, :] + traj.gammas[i, :] * ( + (1 - lambdas[i, :]) * next_values[i, :] + lambdas[i, :] * g + ) + returns.insert(0, g) + lambda_returns = torch.from_numpy(np.array(returns)) + pi, values = fpolicy(params, torch.from_numpy(traj.obs)) + log_probs = pi.log_prob(torch.from_numpy(traj.acs)) + advs = lambda_returns - torch.squeeze(values, -1) + action_loss = -(advs.detach() * log_probs).mean() + value_loss = advs.pow(2).mean() + + loss = action_loss + value_coef * value_loss + return loss + + +def evaluate(env, seed, task_num, fpolicy, params): + pre_reward_ls = [] + post_reward_ls = [] + inner_opt = torchopt.MetaSGD(lr=0.5) + env = gym.make( + 'TabularMDP-v0', + num_states=STATE_DIM, + num_actions=ACTION_DIM, + max_episode_steps=TRAJ_LEN, + seed=args.seed, + ) + tasks = env.sample_tasks(num_tasks=task_num) + + for idx in range(task_num): + for _ in range(inner_iters): + pre_trajs = sample_traj(env, tasks[idx], fpolicy, params) + + inner_loss = a2c_loss(pre_trajs, fpolicy, params, value_coef=0.5) + params = inner_opt.step(inner_loss, params) + post_trajs = sample_traj(env, tasks[idx], fpolicy, params) + + # Logging + pre_reward_ls.append(np.sum(pre_trajs.rews, axis=0).mean()) + post_reward_ls.append(np.sum(post_trajs.rews, axis=0).mean()) + + return pre_reward_ls, post_reward_ls + + +def main(args): + # init training + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + # Env + env = gym.make( + 'TabularMDP-v0', + num_states=STATE_DIM, + num_actions=ACTION_DIM, + max_episode_steps=TRAJ_LEN, + seed=args.seed, + ) + # Policy + policy = CategoricalMLPPolicy(input_size=STATE_DIM, output_size=ACTION_DIM) + fpolicy, params = functorch.make_functional(policy) + + inner_opt = torchopt.MetaSGD(lr=0.5) + outer_opt = optim.Adam(params, lr=1e-3) + train_pre_reward = [] + train_post_reward = [] + test_pre_reward = [] + test_post_reward = [] + + for i in range(outer_iters): + tasks = env.sample_tasks(num_tasks=TASK_NUM) + train_pre_reward_ls = [] + train_post_reward_ls = [] + + outer_opt.zero_grad() + + param_orig = [p.detach().clone().requires_grad_() for p in params] + _params = list(params) + for idx in range(TASK_NUM): + for _ in range(inner_iters): + pre_trajs = sample_traj(env, tasks[idx], fpolicy, _params) + inner_loss = a2c_loss(pre_trajs, fpolicy, _params, value_coef=0.5) + _params = inner_opt.step(inner_loss, _params) + post_trajs = sample_traj(env, tasks[idx], fpolicy, _params) + outer_loss = a2c_loss(post_trajs, fpolicy, _params, value_coef=0.5) + outer_loss.backward() + _params = [p.detach().clone().requires_grad_() for p in param_orig] + + # Logging + train_pre_reward_ls.append(np.sum(pre_trajs.rews, axis=0).mean()) + train_post_reward_ls.append(np.sum(post_trajs.rews, axis=0).mean()) + outer_opt.step() + + test_pre_reward_ls, test_post_reward_ls = evaluate( + env, + args.seed, + TASK_NUM, + fpolicy, + params, + ) + + train_pre_reward.append(sum(train_pre_reward_ls) / TASK_NUM) + train_post_reward.append(sum(train_post_reward_ls) / TASK_NUM) + test_pre_reward.append(sum(test_pre_reward_ls) / TASK_NUM) + test_post_reward.append(sum(test_post_reward_ls) / TASK_NUM) + + print('Train_iters', i) + print('train_pre_reward', sum(train_pre_reward_ls) / TASK_NUM) + print('train_post_reward', sum(train_post_reward_ls) / TASK_NUM) + print('test_pre_reward', sum(test_pre_reward_ls) / TASK_NUM) + print('test_post_reward', sum(test_post_reward_ls) / TASK_NUM) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Reinforcement learning with Model-Agnostic Meta-Learning (MAML) - Train', + ) + parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)') + args = parser.parse_args() + main(args) diff --git a/examples/MAML-RL/helpers/__init__.py b/examples/MAML-RL/helpers/__init__.py index 9855e0b3..31d45c37 100644 --- a/examples/MAML-RL/helpers/__init__.py +++ b/examples/MAML-RL/helpers/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/MAML-RL/helpers/policy.py b/examples/MAML-RL/helpers/policy.py index 0a43a5b1..0bf8e188 100644 --- a/examples/MAML-RL/helpers/policy.py +++ b/examples/MAML-RL/helpers/policy.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/MAML-RL/helpers/policy_torchrl.py b/examples/MAML-RL/helpers/policy_torchrl.py index 103a4ec5..5bcb3863 100644 --- a/examples/MAML-RL/helpers/policy_torchrl.py +++ b/examples/MAML-RL/helpers/policy_torchrl.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,9 +13,7 @@ # limitations under the License. # ============================================================================== -import torch import torch.nn as nn -from torch.distributions import Categorical from torchrl.modules import ( ActorValueOperator, OneHotCategorical, diff --git a/examples/MAML-RL/helpers/tabular_mdp.py b/examples/MAML-RL/helpers/tabular_mdp.py index 3a6bee60..0d8f5f7f 100644 --- a/examples/MAML-RL/helpers/tabular_mdp.py +++ b/examples/MAML-RL/helpers/tabular_mdp.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -49,7 +49,10 @@ def __init__(self, num_states, num_actions, max_episode_steps, seed, task=None): self.action_space = spaces.Discrete(num_actions) self.observation_space = spaces.Box( - low=0.0, high=1.0, shape=(num_states,), dtype=np.float32 + low=0.0, + high=1.0, + shape=(num_states,), + dtype=np.float32, ) self._task = task @@ -62,7 +65,8 @@ def __init__(self, num_states, num_actions, max_episode_steps, seed, task=None): ), ) self._rewards_mean = task.get( - 'rewards_mean', np.zeros((num_states, num_actions), dtype=np.float32) + 'rewards_mean', + np.zeros((num_states, num_actions), dtype=np.float32), ) self._state = 0 self._elapsed_steps = None @@ -79,7 +83,9 @@ def sample_tasks(self, num_tasks): size=(num_tasks, self.num_states, self.num_actions), ) rewards_mean = self.np_random.normal( - 1.0, 1.0, size=(num_tasks, self.num_states, self.num_actions) + 1.0, + 1.0, + size=(num_tasks, self.num_states, self.num_actions), ) tasks = [ {'transitions': transition, 'rewards_mean': reward_mean} @@ -93,7 +99,6 @@ def reset_task(self, task): self._rewards_mean = task['rewards_mean'] def reset(self): - # From [1]: "an episode always starts on the first state" self._state = 0 observation = np.zeros(self.num_states, dtype=np.float32) observation[self._state] = 1.0 @@ -107,13 +112,11 @@ def step(self, action): reward = self.np_random.normal(mean, 1.0) self._state = self.np_random.choice( - self.num_states, p=self._transitions[self._state, action] + self.num_states, + p=self._transitions[self._state, action], ) observation = np.zeros(self.num_states, dtype=np.float32) observation[self._state] = 1.0 self._elapsed_steps += 1 - if self._elapsed_steps >= self.max_episode_steps: - done = True - else: - done = False + done = self._elapsed_steps >= self.max_episode_steps return observation, reward, done, {'task': self._task} diff --git a/examples/MAML-RL/maml.py b/examples/MAML-RL/maml.py index f2bb38e9..0cb57a92 100644 --- a/examples/MAML-RL/maml.py +++ b/examples/MAML-RL/maml.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,9 +22,7 @@ import torch.optim as optim import torchopt - - -from helpers.policy import CategoricalMLPPolicy # isort: skip +from helpers.policy import CategoricalMLPPolicy TASK_NUM = 40 @@ -99,8 +97,9 @@ def a2c_loss(traj, policy, value_coef): advs = lambda_returns - torch.squeeze(values, -1) action_loss = -(advs.detach() * log_probs).mean() value_loss = advs.pow(2).mean() - a2c_loss = action_loss + value_coef * value_loss - return a2c_loss + + loss = action_loss + value_coef * value_loss + return loss def evaluate(env, seed, task_num, policy): @@ -109,12 +108,10 @@ def evaluate(env, seed, task_num, policy): inner_opt = torchopt.MetaSGD(policy, lr=0.1) env = gym.make( 'TabularMDP-v0', - **dict( - num_states=STATE_DIM, - num_actions=ACTION_DIM, - max_episode_steps=TRAJ_LEN, - seed=args.seed, - ), + num_states=STATE_DIM, + num_actions=ACTION_DIM, + max_episode_steps=TRAJ_LEN, + seed=args.seed, ) tasks = env.sample_tasks(num_tasks=task_num) policy_state_dict = torchopt.extract_state_dict(policy) @@ -142,12 +139,10 @@ def main(args): # Env env = gym.make( 'TabularMDP-v0', - **dict( - num_states=STATE_DIM, - num_actions=ACTION_DIM, - max_episode_steps=TRAJ_LEN, - seed=args.seed, - ), + num_states=STATE_DIM, + num_actions=ACTION_DIM, + max_episode_steps=TRAJ_LEN, + seed=args.seed, ) # Policy policy = CategoricalMLPPolicy(input_size=STATE_DIM, output_size=ACTION_DIM) @@ -198,7 +193,7 @@ def main(args): if __name__ == '__main__': parser = argparse.ArgumentParser( - description='Reinforcement learning with Model-Agnostic Meta-Learning (MAML) - Train' + description='Reinforcement learning with Model-Agnostic Meta-Learning (MAML) - Train', ) parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)') args = parser.parse_args() diff --git a/examples/MAML-RL/maml_torchrl.py b/examples/MAML-RL/maml_torchrl.py index 9d1bfe56..56db91ef 100644 --- a/examples/MAML-RL/maml_torchrl.py +++ b/examples/MAML-RL/maml_torchrl.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,9 +14,7 @@ # ============================================================================== import argparse -import time -import numpy as np import torch import torch.optim as optim import tqdm @@ -25,9 +23,7 @@ from torchrl.objectives.returns.functional import td_lambda_advantage_estimate import torchopt - - -from helpers.policy_torchrl import ActorCritic # isort: skip +from helpers.policy_torchrl import ActorCritic TASK_NUM = 40 @@ -62,8 +58,6 @@ def a2c_loss(traj, policy_module, value_module, value_coef): next_traj = step_tensordict(traj) next_value = value_module(next_traj).get('state_value').detach() - # tderror = TDEstimate(GAMMA, value_module, gradient_mode=True) - # tderror = TDLambdaEstimate(GAMMA, LAMBDA, value_module, gradient_mode=True) advantage = td_lambda_advantage_estimate(GAMMA, LAMBDA, value, next_value, reward, done) action_loss = -(advantage.detach() * log_probs.view_as(advantage)).mean() value_error = advantage @@ -133,14 +127,17 @@ def main(args): # init training torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) + # Env - lambda_env = lambda: GymEnv( - 'TabularMDP-v0', - num_states=STATE_DIM, - num_actions=ACTION_DIM, - max_episode_steps=TRAJ_LEN, - device=device, - ) + def lambda_env(): + return GymEnv( + 'TabularMDP-v0', + num_states=STATE_DIM, + num_actions=ACTION_DIM, + max_episode_steps=TRAJ_LEN, + device=device, + ) + if args.parallel: env = ParallelEnv( NUM_ENVS, @@ -173,8 +170,7 @@ def main(args): dummy_env.set_seed(args.seed) pbar = tqdm.tqdm(range(outer_iters)) - for i in pbar: - # print("i: ", i) + for _ in pbar: tasks = dummy_env.sample_tasks(num_tasks=TASK_NUM) train_pre_reward_ls = [] train_post_reward_ls = [] @@ -186,7 +182,7 @@ def main(args): env.reset_task(tasks[idx]) policy_module = actor_critic_module.get_policy_operator() value_module = actor_critic_module.get_value_operator() - for k in range(inner_iters): + for __ in range(inner_iters): with set_exploration_mode('random'), torch.no_grad(): pre_traj_td = ( env.rollout( @@ -238,7 +234,7 @@ def main(args): f'train_pre_reward: {train_pre_reward[-1]: 4.4f}, ' f'train_post_reward: {train_post_reward[-1]: 4.4f}, ' f'test_pre_reward: {test_pre_reward[-1]: 4.4f}, ' - f'test_post_reward: {test_post_reward[-1]: 4.4f}, ' + f'test_post_reward: {test_post_reward[-1]: 4.4f}, ', ) env.close() @@ -246,7 +242,7 @@ def main(args): if __name__ == '__main__': parser = argparse.ArgumentParser( - description='Reinforcement learning with Model-Agnostic Meta-Learning (MAML) - Train' + description='Reinforcement learning with Model-Agnostic Meta-Learning (MAML) - Train', ) parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)') parser.add_argument('--parallel', action='store_true', help='run envs in parallel') diff --git a/examples/MGRL/mgrl.py b/examples/MGRL/mgrl.py index 152e4177..bdc4b140 100644 --- a/examples/MGRL/mgrl.py +++ b/examples/MGRL/mgrl.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -55,7 +55,7 @@ def forward(self, x): meta_optimizer = torchopt.SGD([gamma], lr=5e-1) net_state = torchopt.extract_state_dict(net) for i in range(outer_iters): - for j in range(inner_iters): + for _ in range(inner_iters): trajectory, state = Rollout.get() backup = Rollout.rollout(trajectory, torch.sigmoid(gamma)) pred_value = net(state.float()) diff --git a/examples/distributed/few-shot/README.md b/examples/distributed/few-shot/README.md new file mode 100644 index 00000000..a0a758fa --- /dev/null +++ b/examples/distributed/few-shot/README.md @@ -0,0 +1,18 @@ +# MAML few-shot Omniglot classification-examples + +Code on MAML few-shot Omniglot classification in paper [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks](https://arxiv.org/abs/1703.03400) using TorchOpt. We use `MetaSGD` as the inner-loop optimizer. + +## Usage + +```bash +### Run +torchrun --nnode 1 --nproc_per_node 8 maml_omniglot.py +``` + +## Results + +The figure illustrate the experimental result. + +
+ +
diff --git a/examples/few-shot/support/omniglot_loaders.py b/examples/distributed/few-shot/helpers/omniglot_loaders.py similarity index 90% rename from examples/few-shot/support/omniglot_loaders.py rename to examples/distributed/few-shot/helpers/omniglot_loaders.py index d857d386..52fab28a 100644 --- a/examples/few-shot/support/omniglot_loaders.py +++ b/examples/distributed/few-shot/helpers/omniglot_loaders.py @@ -80,7 +80,7 @@ def __len__(self): def _check_exists(self): return os.path.exists( - os.path.join(self.root, self.processed_folder, 'images_evaluation') + os.path.join(self.root, self.processed_folder, 'images_evaluation'), ) and os.path.exists(os.path.join(self.root, self.processed_folder, 'images_background')) def download(self): @@ -118,7 +118,7 @@ def download(self): def find_classes(root_dir): retour = [] - for (root, dirs, files) in os.walk(root_dir): + for root, _, files in os.walk(root_dir): for f in files: if f.endswith('png'): r = root.split('/') @@ -164,14 +164,14 @@ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=Non lambda x: np.reshape(x, (imgsz, imgsz, 1)), lambda x: np.transpose(x, [2, 0, 1]), lambda x: x / 255.0, - ] + ], ), ) # {label: [img1, img2..., img20], label2: [img1, img2, ...], ... 1623 labels in total} temp = {} - for (img, label) in self.x: - if label in temp.keys(): + for img, label in self.x: + if label in temp: temp[label].append(img) else: temp[label] = [img] @@ -196,12 +196,9 @@ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=Non self.x = np.load(os.path.join(root, 'omniglot.npy')) print('load from omniglot.npy.') - # [1623, 20, 84, 84, 1] - # TODO: can not shuffle here, we must keep training and test set distinct! + # NOTE: do not shuffle here, we must keep training and test set distinct! self.x_train, self.x_test = self.x[:1200], self.x[1200:] - # self.normalization() - self.batchsz = batchsz self.n_cls = self.x.shape[0] # 1623 self.n_way = n_way # n way @@ -230,7 +227,6 @@ def normalization(self): self.std = np.std(self.x_train) self.max = np.max(self.x_train) self.min = np.min(self.x_train) - # print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std) self.x_train = (self.x_train - self.mean) / self.std self.x_test = (self.x_test - self.mean) / self.std @@ -239,8 +235,6 @@ def normalization(self): self.max = np.max(self.x_train) self.min = np.min(self.x_train) - # print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std) - def load_data_cache(self, data_pack): """ Collects several batches data for N-shot learning @@ -253,17 +247,13 @@ def load_data_cache(self, data_pack): querysz = self.k_query * self.n_way data_cache = [] - # print('preload next 50 caches of batchsz of batch.') - for sample in range(10): # num of episodes - + for _sample in range(10): # num of episodes x_spts, y_spts, x_qrys, y_qrys = [], [], [], [] - for i in range(self.batchsz): # one batch means one set - + for _ in range(self.batchsz): # one batch means one set x_spt, y_spt, x_qry, y_qry = [], [], [], [] selected_cls = self.rng.choice(data_pack.shape[0], self.n_way, False) for j, cur_class in enumerate(selected_cls): - selected_img = self.rng.choice(20, self.k_shot + self.k_query, False) # meta-training and meta-test @@ -275,12 +265,18 @@ def load_data_cache(self, data_pack): # shuffle inside a batch perm = self.rng.permutation(self.n_way * self.k_shot) x_spt = np.array(x_spt).reshape( - self.n_way * self.k_shot, 1, self.resize, self.resize + self.n_way * self.k_shot, + 1, + self.resize, + self.resize, )[perm] y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm] perm = self.rng.permutation(self.n_way * self.k_query) x_qry = np.array(x_qry).reshape( - self.n_way * self.k_query, 1, self.resize, self.resize + self.n_way * self.k_query, + 1, + self.resize, + self.resize, )[perm] y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm] @@ -290,20 +286,29 @@ def load_data_cache(self, data_pack): x_qrys.append(x_qry) y_qrys.append(y_qry) - # [b, setsz, 1, 84, 84] x_spts = np.array(x_spts, dtype=np.float32).reshape( - self.batchsz, setsz, 1, self.resize, self.resize - ) - y_spts = np.array(y_spts, dtype=np.int).reshape(self.batchsz, setsz) - # [b, qrysz, 1, 84, 84] + self.batchsz, + setsz, + 1, + self.resize, + self.resize, + ) # [b, setsz, 1, 84, 84] + y_spts = np.array(y_spts, dtype=np.int).reshape( + self.batchsz, + setsz, + ) # [b, qrysz, 1, 84, 84] x_qrys = np.array(x_qrys, dtype=np.float32).reshape( - self.batchsz, querysz, 1, self.resize, self.resize + self.batchsz, + querysz, + 1, + self.resize, + self.resize, ) y_qrys = np.array(y_qrys, dtype=np.int).reshape(self.batchsz, querysz) - x_spts, y_spts, x_qrys, y_qrys = [ + x_spts, y_spts, x_qrys, y_qrys = ( torch.from_numpy(z).to(self.device) for z in [x_spts, y_spts, x_qrys, y_qrys] - ] + ) data_cache.append([x_spts, y_spts, x_qrys, y_qrys]) diff --git a/examples/distributed/few-shot/maml-accs.png b/examples/distributed/few-shot/maml-accs.png new file mode 100644 index 00000000..8d70607c Binary files /dev/null and b/examples/distributed/few-shot/maml-accs.png differ diff --git a/examples/distributed/few-shot/maml_omniglot.py b/examples/distributed/few-shot/maml_omniglot.py new file mode 100644 index 00000000..f840e65e --- /dev/null +++ b/examples/distributed/few-shot/maml_omniglot.py @@ -0,0 +1,314 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/facebookresearch/higher/blob/main/examples/maml-omniglot.py +# ============================================================================== +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This example shows how to use TorchOpt to do Model Agnostic Meta Learning (MAML) +for few-shot Omniglot classification. +For more details see the original MAML paper: +https://arxiv.org/abs/1703.03400 +This code has been modified from Jackie Loong's PyTorch MAML implementation: +https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py +Our MAML++ fork and experiments are available at: +https://github.com/bamos/HowToTrainYourMAMLPytorch +""" + +import argparse +import os +import random +import time + +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from setproctitle import getproctitle, setproctitle + +import torchopt +import torchopt.distributed as todist +from helpers.omniglot_loaders import OmniglotNShot + + +mpl.use('Agg') +plt.style.use('bmh') + + +def worker_init(): + world_info = todist.get_world_info() + + proctitle = f'{world_info.worker_name}: {getproctitle().strip()}' + print(f'Worker init:=> {proctitle}') + setproctitle(proctitle) + + seed = world_info.local_rank + + os.environ['PYTHONHASHSEED'] = str(seed) + + random.seed(seed) + np.random.seed(seed) + + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + if world_info.local_rank < torch.cuda.device_count(): + torch.cuda.set_device(world_info.local_rank) + + +def build_model(args, device): + return nn.Sequential( + nn.Conv2d(1, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Flatten(), + nn.Linear(64, args.n_way), + ).to(device) + + +@todist.rank_zero_only +def get_data_loader(args, device): + rng = np.random.default_rng(args.seed) + + return OmniglotNShot( + '/tmp/omniglot-data', + batchsz=args.task_num, + n_way=args.n_way, + k_shot=args.k_spt, + k_query=args.k_qry, + imgsz=28, + rng=rng, + device=device, + ) + + +@todist.auto_init_rpc(worker_init) +def main(): + argparser = argparse.ArgumentParser() + argparser.add_argument('--n_way', type=int, help='n way', default=5) + argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=5) + argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15) + argparser.add_argument( + '--task_num', + type=int, + help='meta batch size, namely task num', + default=32, + ) + argparser.add_argument('--seed', type=int, help='random seed', default=1) + args = argparser.parse_args() + + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + np.random.seed(args.seed) + + # Set up the Omniglot loader. + db = get_data_loader(args, device=torch.device('cpu')) + + # Create a vanilla PyTorch neural network. + net = build_model(args, device=torch.device('cpu')) + + # We will use Adam to (meta-)optimize the initial parameters + # to be adapted. + meta_opt = optim.Adam(net.parameters(), lr=1e-3) + + log = [] + test(db, net, epoch=-1, log=log) + for epoch in range(10): + train(db, net, meta_opt, epoch=epoch, log=log) + test(db, net, epoch=epoch, log=log) + plot(log) + + +def transpose_mean_reducer(results): + qry_losses, qry_accs = tuple(zip(*results)) + qry_loss = torch.mean(torch.stack(qry_losses)) + qry_acc = np.mean(qry_accs) + return qry_loss, qry_acc + + +@todist.parallelize( + partitioner=todist.dim_partitioner(dim=0, exclusive=True, keepdim=False), + reducer=transpose_mean_reducer, +) +def inner_loop(net_rref, x_spt, y_spt, x_qry, y_qry, n_inner_iter): + if torch.cuda.is_available(): + device = torch.device(f'cuda:{todist.get_local_rank() % torch.cuda.device_count()}') + torch.cuda.set_device(device) + else: + device = None + + original_net = net_rref.to_here() + # The local net can be shared across multiple RPC calls on the current worker + # We need to detach the buffers to avoid sharing the same buffers across + net = torchopt.module_clone(original_net, by='reference', detach_buffers=True, device=device) + if device is not None: + x_spt = x_spt.to(device) + y_spt = y_spt.to(device) + x_qry = x_qry.to(device) + y_qry = y_qry.to(device) + + inner_opt = torchopt.MetaSGD(net, lr=1e-1) + + for _ in range(n_inner_iter): + spt_logits = net(x_spt) + spt_loss = F.cross_entropy(spt_logits, y_spt) + inner_opt.step(spt_loss) + + qry_logits = net(x_qry) + qry_loss = F.cross_entropy(qry_logits, y_qry).cpu() + qry_acc = (qry_logits.argmax(dim=1) == y_qry).float().mean().item() + + return qry_loss, qry_acc + + +@todist.rank_zero_only +def train(db: OmniglotNShot, net: nn.Module, meta_opt: optim.Adam, epoch: int, log: list): + net.train() + n_train_iter = db.x_train.shape[0] // db.batchsz + + for batch_idx in range(n_train_iter): + start_time = time.time() + # Sample a batch of support and query images and labels. + x_spt, y_spt, x_qry, y_qry = db.next() + + # TODO: Maybe pull this out into a separate module so it + # doesn't have to be duplicated between `train` and `test`? + + # Initialize the inner optimizer to adapt the parameters to + # the support set. + n_inner_iter = 5 + + meta_opt.zero_grad() + # Sending modules contains nn.Parameter will detach from the current computation graph + # Here we explicitly convert the parameters to tensors with `CloneBackward` + net_rref = todist.rpc.RRef(torchopt.module_clone(net, by='copy')) + with todist.autograd.context() as context_id: + qry_loss, qry_acc = inner_loop(net_rref, x_spt, y_spt, x_qry, y_qry, n_inner_iter) + todist.autograd.backward(context_id, qry_loss) + meta_opt.step() + + qry_loss = qry_loss.item() + qry_acc = 100.0 * qry_acc + i = epoch + float(batch_idx) / n_train_iter + iter_time = time.time() - start_time + + print( + f'[Epoch {i:.2f}] Train Loss: {qry_loss:.2f} | Acc: {qry_acc:.2f} | Time: {iter_time:.2f}', + ) + log.append( + { + 'epoch': i, + 'loss': qry_loss, + 'acc': qry_acc, + 'mode': 'train', + 'time': time.time(), + }, + ) + + +@todist.rank_zero_only +def test(db, net, epoch, log): + # Crucially in our testing procedure here, we do *not* fine-tune + # the model during testing for simplicity. + # Most research papers using MAML for this task do an extra + # stage of fine-tuning here that should be added if you are + # adapting this code for research. + net.train() + n_test_iter = db.x_test.shape[0] // db.batchsz + + qry_losses = [] + qry_accs = [] + + net_rref = todist.rpc.RRef(net) + for _ in range(n_test_iter): + x_spt, y_spt, x_qry, y_qry = db.next('test') + + # TODO: Maybe pull this out into a separate module so it + # doesn't have to be duplicated between `train` and `test`? + n_inner_iter = 5 + + qry_loss, qry_acc = inner_loop(net_rref, x_spt, y_spt, x_qry, y_qry, n_inner_iter) + qry_losses.append(qry_loss.item()) + qry_accs.append(qry_acc) + + qry_losses = np.mean(qry_losses) + qry_accs = 100.0 * np.mean(qry_accs) + + print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}') + log.append( + { + 'epoch': epoch + 1, + 'loss': qry_losses, + 'acc': qry_accs, + 'mode': 'test', + 'time': time.time(), + }, + ) + + +@todist.rank_zero_only +def plot(log): + # Generally you should pull your plotting code out of your training + # script but we are doing it here for brevity. + df = pd.DataFrame(log) + + fig, ax = plt.subplots(figsize=(8, 4), dpi=250) + train_df = df[df['mode'] == 'train'] + test_df = df[df['mode'] == 'test'] + ax.plot(train_df['epoch'], train_df['acc'], label='Train') + ax.plot(test_df['epoch'], test_df['acc'], label='Test') + ax.set_xlabel('Epoch') + ax.set_ylabel('Accuracy') + ax.set_ylim(85, 100) + ax.set_title('Distributed MAML Omniglot') + ax.legend(ncol=2, loc='lower right') + fig.tight_layout() + fname = 'maml-accs.png' + print(f'--- Plotting accuracy to {fname}') + fig.savefig(fname) + plt.close(fig) + + +if __name__ == '__main__': + main() diff --git a/examples/distributed/few-shot/maml_omniglot_local_loader.py b/examples/distributed/few-shot/maml_omniglot_local_loader.py new file mode 100644 index 00000000..fb737d4f --- /dev/null +++ b/examples/distributed/few-shot/maml_omniglot_local_loader.py @@ -0,0 +1,358 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/facebookresearch/higher/blob/main/examples/maml-omniglot.py +# ============================================================================== +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This example shows how to use TorchOpt to do Model Agnostic Meta Learning (MAML) +for few-shot Omniglot classification. +For more details see the original MAML paper: +https://arxiv.org/abs/1703.03400 +This code has been modified from Jackie Loong's PyTorch MAML implementation: +https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py +Our MAML++ fork and experiments are available at: +https://github.com/bamos/HowToTrainYourMAMLPytorch +""" + +import argparse +import copy +import os +import random +import threading +import time + +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from setproctitle import getproctitle, setproctitle + +import torchopt +import torchopt.distributed as todist +from helpers.omniglot_loaders import OmniglotNShot + + +mpl.use('Agg') +plt.style.use('bmh') + + +LOCK = threading.Lock() +LOCAL_DATA_LOADER = None +TASK_DATA_LOADERS = {} +LOCAL_DEVICE = None + + +def worker_init(): + global LOCAL_DEVICE + + world_info = todist.get_world_info() + + proctitle = f'{world_info.worker_name}: {getproctitle().strip()}' + print(f'Worker init:=> {proctitle}') + setproctitle(proctitle) + + seed = world_info.world_rank + local_rank = world_info.local_rank + + os.environ['PYTHONHASHSEED'] = str(seed) + + random.seed(seed) + np.random.seed(seed) + + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + if world_info.local_rank < torch.cuda.device_count(): + torch.cuda.set_device(world_info.local_rank) + + if torch.cuda.is_available(): + device = torch.device(f'cuda:{local_rank % torch.cuda.device_count()}') + torch.cuda.set_device(device) + else: + device = None + LOCAL_DEVICE = device + + +def build_model(args, device): + return nn.Sequential( + nn.Conv2d(1, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Flatten(), + nn.Linear(64, args.n_way), + ).to(device) + + +def set_local_data_loader(args, device): + global LOCAL_DATA_LOADER + + if LOCAL_DATA_LOADER is None: + rng = np.random.default_rng(args.seed) + + with LOCK: + LOCAL_DATA_LOADER = OmniglotNShot( + '/tmp/omniglot-data', + batchsz=args.task_num, + n_way=args.n_way, + k_shot=args.k_spt, + k_query=args.k_qry, + imgsz=28, + rng=rng, + device=device, + ) + + return LOCAL_DATA_LOADER + + +def get_next_batch(task_id, mode): + assert LOCAL_DATA_LOADER is not None + + if task_id not in TASK_DATA_LOADERS: + with LOCK: + TASK_DATA_LOADERS[task_id] = copy.deepcopy(LOCAL_DATA_LOADER) + + db = TASK_DATA_LOADERS[task_id] + x_spt, y_spt, x_qry, y_qry = db.next(mode) + x_spt, y_spt, x_qry, y_qry = x_spt[task_id], y_spt[task_id], x_qry[task_id], y_qry[task_id] + return x_qry, y_qry, x_spt, y_spt + + +@todist.auto_init_rpc(worker_init) +def main(): + argparser = argparse.ArgumentParser() + argparser.add_argument('--n_way', type=int, help='n way', default=5) + argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=5) + argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15) + argparser.add_argument( + '--task_num', + type=int, + help='meta batch size, namely task num', + default=32, + ) + argparser.add_argument('--seed', type=int, help='random seed', default=1) + args = argparser.parse_args() + + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + np.random.seed(args.seed) + + # Set up the Omniglot loader. + set_local_data_loader(args, device=LOCAL_DEVICE) + todist.barrier() # ensure that all workers have loaded the data + + # Create a vanilla PyTorch neural network. + net = build_model(args, device=torch.device('cpu')) + + # We will use Adam to (meta-)optimize the initial parameters + # to be adapted. + meta_opt = optim.Adam(net.parameters(), lr=1e-3) + + log = [] + test(net, epoch=-1, log=log) + for epoch in range(10): + train(net, meta_opt, epoch=epoch, log=log) + test(net, epoch=epoch, log=log) + plot(log) + + +def args_replicator(net_rref, n_inner_iter, task_id, task_num, mode): + del task_id + num_workers = todist.get_world_size() + return [ + (task_id % num_workers, (net_rref, n_inner_iter, task_id, task_num, mode), None) + for task_id in range(task_num) + ] + + +def transpose_mean_reducer(results): + qry_losses, qry_accs = tuple(zip(*results)) + qry_loss = torch.mean(torch.stack(qry_losses)) + qry_acc = np.mean(qry_accs) + return qry_loss, qry_acc + + +@todist.parallelize(partitioner=args_replicator, reducer=transpose_mean_reducer) +def inner_loop(net_rref, n_inner_iter, task_id, task_num, mode): + device = LOCAL_DEVICE + + original_net = net_rref.to_here() + # The local net can be shared across multiple RPC calls on the current worker + # We need to detach the buffers to avoid sharing the same buffers across + net = torchopt.module_clone(original_net, by='reference', detach_buffers=True, device=device) + + x_spt, y_spt, x_qry, y_qry = get_next_batch(task_id, mode) + if device is not None: + x_spt = x_spt.to(device) + y_spt = y_spt.to(device) + x_qry = x_qry.to(device) + y_qry = y_qry.to(device) + + inner_opt = torchopt.MetaSGD(net, lr=1e-1) + + for _ in range(n_inner_iter): + spt_logits = net(x_spt) + spt_loss = F.cross_entropy(spt_logits, y_spt) + inner_opt.step(spt_loss) + + qry_logits = net(x_qry) + qry_loss = F.cross_entropy(qry_logits, y_qry).cpu() + qry_acc = (qry_logits.argmax(dim=1) == y_qry).float().mean().item() + + return qry_loss, qry_acc + + +@todist.rank_zero_only +def train(net: nn.Module, meta_opt: optim.Adam, epoch: int, log: list): + net.train() + + db = LOCAL_DATA_LOADER + n_train_iter = db.x_train.shape[0] // db.batchsz + task_num = db.x_train.shape[1] + + net_rref = todist.rpc.RRef(net) + for batch_idx in range(n_train_iter): + start_time = time.time() + + # TODO: Maybe pull this out into a separate module so it + # doesn't have to be duplicated between `train` and `test`? + + # Initialize the inner optimizer to adapt the parameters to + # the support set. + n_inner_iter = 5 + + meta_opt.zero_grad() + # Sending modules contains nn.Parameter will detach from the current computation graph + # Here we explicitly convert the parameters to tensors with `CloneBackward` + net_rref = todist.rpc.RRef(torchopt.module_clone(net, by='copy')) + with todist.autograd.context() as context_id: + qry_loss, qry_acc = inner_loop(net_rref, n_inner_iter, None, task_num, 'train') + todist.autograd.backward(context_id, qry_loss) + meta_opt.step() + + qry_loss = qry_loss.item() + qry_acc = 100.0 * qry_acc + i = epoch + float(batch_idx) / n_train_iter + iter_time = time.time() - start_time + + print( + f'[Epoch {i:.2f}] Train Loss: {qry_loss:.2f} | Acc: {qry_acc:.2f} | Time: {iter_time:.2f}', + ) + log.append( + { + 'epoch': i, + 'loss': qry_loss, + 'acc': qry_acc, + 'mode': 'train', + 'time': time.time(), + }, + ) + + +@todist.rank_zero_only +def test(net, epoch, log): + # Crucially in our testing procedure here, we do *not* fine-tune + # the model during testing for simplicity. + # Most research papers using MAML for this task do an extra + # stage of fine-tuning here that should be added if you are + # adapting this code for research. + net.train() + + db = LOCAL_DATA_LOADER + n_test_iter = db.x_test.shape[0] // db.batchsz + task_num = db.x_train.shape[1] + + qry_losses = [] + qry_accs = [] + + net_rref = todist.rpc.RRef(net) + for _ in range(n_test_iter): + # TODO: Maybe pull this out into a separate module so it + # doesn't have to be duplicated between `train` and `test`? + n_inner_iter = 5 + + qry_loss, qry_acc = inner_loop(net_rref, n_inner_iter, None, task_num, 'test') + qry_losses.append(qry_loss.item()) + qry_accs.append(qry_acc) + + qry_losses = np.mean(qry_losses) + qry_accs = 100.0 * np.mean(qry_accs) + + print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}') + log.append( + { + 'epoch': epoch + 1, + 'loss': qry_losses, + 'acc': qry_accs, + 'mode': 'test', + 'time': time.time(), + }, + ) + + +@todist.rank_zero_only +def plot(log): + # Generally you should pull your plotting code out of your training + # script but we are doing it here for brevity. + df = pd.DataFrame(log) + + fig, ax = plt.subplots(figsize=(8, 4), dpi=250) + train_df = df[df['mode'] == 'train'] + test_df = df[df['mode'] == 'test'] + ax.plot(train_df['epoch'], train_df['acc'], label='Train') + ax.plot(test_df['epoch'], test_df['acc'], label='Test') + ax.set_xlabel('Epoch') + ax.set_ylabel('Accuracy') + ax.set_ylim(85, 100) + ax.set_title('Distributed MAML Omniglot') + ax.legend(ncol=2, loc='lower right') + fig.tight_layout() + fname = 'maml-accs.png' + print(f'--- Plotting accuracy to {fname}') + fig.savefig(fname) + plt.close(fig) + + +if __name__ == '__main__': + main() diff --git a/examples/few-shot/README.md b/examples/few-shot/README.md index d25eafc4..df6578f3 100644 --- a/examples/few-shot/README.md +++ b/examples/few-shot/README.md @@ -14,5 +14,5 @@ python3 maml_omniglot.py The figure illustrate the experimental result.
- +
diff --git a/examples/few-shot/helpers/omniglot_loaders.py b/examples/few-shot/helpers/omniglot_loaders.py new file mode 100644 index 00000000..52fab28a --- /dev/null +++ b/examples/few-shot/helpers/omniglot_loaders.py @@ -0,0 +1,332 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# These Omniglot loaders are from Jackie Loong's PyTorch MAML implementation: +# https://github.com/dragen1860/MAML-Pytorch +# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot.py +# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglotNShot.py +# ============================================================================== + +import errno +import os + +import numpy as np +import torch +import torch.utils.data as data +import torchvision.transforms as transforms +from PIL import Image + + +class Omniglot(data.Dataset): + """ + The items are ``(filename, category)``. The index of all the categories can be found in + :attr:`idx_classes`. + + Args: + root: the directory where the dataset will be stored + transform: how to transform the input + target_transform: how to transform the target + download: need to download the dataset + """ + + urls = [ + 'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip', + 'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip', + ] + raw_folder = 'raw' + processed_folder = 'processed' + training_file = 'training.pt' + test_file = 'test.pt' + + def __init__(self, root, transform=None, target_transform=None, download=False): + self.root = root + self.transform = transform + self.target_transform = target_transform + + if not self._check_exists(): + if download: + self.download() + else: + raise RuntimeError('Dataset not found. You can use download=True to download it') + + self.all_items = find_classes(os.path.join(self.root, self.processed_folder)) + self.idx_classes = index_classes(self.all_items) + + def __getitem__(self, index): + filename = self.all_items[index][0] + img = str.join('/', [self.all_items[index][2], filename]) + + target = self.idx_classes[self.all_items[index][1]] + if self.transform is not None: + img = self.transform(img) + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self): + return len(self.all_items) + + def _check_exists(self): + return os.path.exists( + os.path.join(self.root, self.processed_folder, 'images_evaluation'), + ) and os.path.exists(os.path.join(self.root, self.processed_folder, 'images_background')) + + def download(self): + import zipfile + + from six.moves import urllib + + if self._check_exists(): + return + + # download files + try: + os.makedirs(os.path.join(self.root, self.raw_folder)) + os.makedirs(os.path.join(self.root, self.processed_folder)) + except OSError as e: + if e.errno == errno.EEXIST: + pass + else: + raise + + for url in self.urls: + print('== Downloading ' + url) + data = urllib.request.urlopen(url) + filename = url.rpartition('/')[2] + file_path = os.path.join(self.root, self.raw_folder, filename) + with open(file_path, 'wb') as f: + f.write(data.read()) + file_processed = os.path.join(self.root, self.processed_folder) + print('== Unzip from ' + file_path + ' to ' + file_processed) + zip_ref = zipfile.ZipFile(file_path, 'r') + zip_ref.extractall(file_processed) + zip_ref.close() + print('Download finished.') + + +def find_classes(root_dir): + retour = [] + for root, _, files in os.walk(root_dir): + for f in files: + if f.endswith('png'): + r = root.split('/') + lr = len(r) + retour.append((f, r[lr - 2] + '/' + r[lr - 1], root)) + print('== Found %d items ' % len(retour)) + return retour + + +def index_classes(items): + idx = {} + for i in items: + if i[1] not in idx: + idx[i[1]] = len(idx) + print('== Found %d classes' % len(idx)) + return idx + + +class OmniglotNShot: + def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=None): + """ + Different from mnistNShot, the + :param root: + :param batchsz: task num + :param n_way: + :param k_shot: + :param k_qry: + :param imgsz: + """ + + self.resize = imgsz + self.rng = rng + self.device = device + if not os.path.isfile(os.path.join(root, 'omniglot.npy')): + # if root/data.npy does not exist, just download it + self.x = Omniglot( + root, + download=True, + transform=transforms.Compose( + [ + lambda x: Image.open(x).convert('L'), + lambda x: x.resize((imgsz, imgsz)), + lambda x: np.reshape(x, (imgsz, imgsz, 1)), + lambda x: np.transpose(x, [2, 0, 1]), + lambda x: x / 255.0, + ], + ), + ) + + # {label: [img1, img2..., img20], label2: [img1, img2, ...], ... 1623 labels in total} + temp = {} + for img, label in self.x: + if label in temp: + temp[label].append(img) + else: + temp[label] = [img] + + self.x = [] + for ( + label, + imgs, + ) in temp.items(): # labels info deserted , each label contains 20imgs + self.x.append(np.array(imgs)) + + # as different class may have different number of imgs + self.x = np.array(self.x).astype(np.float) # [[20 imgs],..., 1623 classes in total] + # each character contains 20 imgs + print('data shape:', self.x.shape) # [1623, 20, 84, 84, 1] + temp = [] # Free memory + # save all dataset into npy file. + np.save(os.path.join(root, 'omniglot.npy'), self.x) + print('write into omniglot.npy.') + else: + # if data.npy exists, just load it. + self.x = np.load(os.path.join(root, 'omniglot.npy')) + print('load from omniglot.npy.') + + # NOTE: do not shuffle here, we must keep training and test set distinct! + self.x_train, self.x_test = self.x[:1200], self.x[1200:] + + self.batchsz = batchsz + self.n_cls = self.x.shape[0] # 1623 + self.n_way = n_way # n way + self.k_shot = k_shot # k shot + self.k_query = k_query # k query + assert (k_shot + k_query) <= 20 + + # save pointer of current read batch in total cache + self.indexes = {'train': 0, 'test': 0} + self.datasets = { + 'train': self.x_train, + 'test': self.x_test, + } # original data cached + print('DB: train', self.x_train.shape, 'test', self.x_test.shape) + + self.datasets_cache = { + 'train': self.load_data_cache(self.datasets['train']), # current epoch data cached + 'test': self.load_data_cache(self.datasets['test']), + } + + def normalization(self): + """ + Normalizes our data, to have a mean of 0 and sdt of 1 + """ + self.mean = np.mean(self.x_train) + self.std = np.std(self.x_train) + self.max = np.max(self.x_train) + self.min = np.min(self.x_train) + self.x_train = (self.x_train - self.mean) / self.std + self.x_test = (self.x_test - self.mean) / self.std + + self.mean = np.mean(self.x_train) + self.std = np.std(self.x_train) + self.max = np.max(self.x_train) + self.min = np.min(self.x_train) + + def load_data_cache(self, data_pack): + """ + Collects several batches data for N-shot learning + :param data_pack: [cls_num, 20, 84, 84, 1] + :return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks + """ + + # take 5 way 1 shot as example: 5 * 1 + setsz = self.k_shot * self.n_way + querysz = self.k_query * self.n_way + data_cache = [] + + for _sample in range(10): # num of episodes + x_spts, y_spts, x_qrys, y_qrys = [], [], [], [] + for _ in range(self.batchsz): # one batch means one set + x_spt, y_spt, x_qry, y_qry = [], [], [], [] + selected_cls = self.rng.choice(data_pack.shape[0], self.n_way, False) + + for j, cur_class in enumerate(selected_cls): + selected_img = self.rng.choice(20, self.k_shot + self.k_query, False) + + # meta-training and meta-test + x_spt.append(data_pack[cur_class][selected_img[: self.k_shot]]) + x_qry.append(data_pack[cur_class][selected_img[self.k_shot :]]) + y_spt.append([j for _ in range(self.k_shot)]) + y_qry.append([j for _ in range(self.k_query)]) + + # shuffle inside a batch + perm = self.rng.permutation(self.n_way * self.k_shot) + x_spt = np.array(x_spt).reshape( + self.n_way * self.k_shot, + 1, + self.resize, + self.resize, + )[perm] + y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm] + perm = self.rng.permutation(self.n_way * self.k_query) + x_qry = np.array(x_qry).reshape( + self.n_way * self.k_query, + 1, + self.resize, + self.resize, + )[perm] + y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm] + + # append [sptsz, 1, 84, 84] => [b, setsz, 1, 84, 84] + x_spts.append(x_spt) + y_spts.append(y_spt) + x_qrys.append(x_qry) + y_qrys.append(y_qry) + + x_spts = np.array(x_spts, dtype=np.float32).reshape( + self.batchsz, + setsz, + 1, + self.resize, + self.resize, + ) # [b, setsz, 1, 84, 84] + y_spts = np.array(y_spts, dtype=np.int).reshape( + self.batchsz, + setsz, + ) # [b, qrysz, 1, 84, 84] + x_qrys = np.array(x_qrys, dtype=np.float32).reshape( + self.batchsz, + querysz, + 1, + self.resize, + self.resize, + ) + y_qrys = np.array(y_qrys, dtype=np.int).reshape(self.batchsz, querysz) + + x_spts, y_spts, x_qrys, y_qrys = ( + torch.from_numpy(z).to(self.device) for z in [x_spts, y_spts, x_qrys, y_qrys] + ) + + data_cache.append([x_spts, y_spts, x_qrys, y_qrys]) + + return data_cache + + def next(self, mode='train'): + """ + Gets next batch from the dataset with name. + :param mode: The name of the splitting (one of "train", "val", "test") + :return: + """ + + # update cache if indexes is larger cached num + if self.indexes[mode] >= len(self.datasets_cache[mode]): + self.indexes[mode] = 0 + self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode]) + + next_batch = self.datasets_cache[mode][self.indexes[mode]] + self.indexes[mode] += 1 + + return next_batch diff --git a/examples/few-shot/maml-accs.png b/examples/few-shot/maml-accs.png index a3a0f4ce..df0b37db 100644 Binary files a/examples/few-shot/maml-accs.png and b/examples/few-shot/maml-accs.png differ diff --git a/examples/few-shot/maml_omniglot.py b/examples/few-shot/maml_omniglot.py index 30b10559..7f7f67fe 100644 --- a/examples/few-shot/maml_omniglot.py +++ b/examples/few-shot/maml_omniglot.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -52,9 +52,7 @@ import torch.optim as optim import torchopt - - -from support.omniglot_loaders import OmniglotNShot # isort: skip +from helpers.omniglot_loaders import OmniglotNShot mpl.use('Agg') @@ -67,7 +65,10 @@ def main(): argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=5) argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15) argparser.add_argument( - '--task_num', type=int, help='meta batch size, namely task num', default=32 + '--task_num', + type=int, + help='meta batch size, namely task num', + default=32, ) argparser.add_argument('--seed', type=int, help='random seed', default=1) args = argparser.parse_args() @@ -75,11 +76,13 @@ def main(): torch.manual_seed(args.seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(args.seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True np.random.seed(args.seed) rng = np.random.default_rng(args.seed) # Set up the Omniglot loader. - device = torch.device('cuda:0') + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') db = OmniglotNShot( '/tmp/omniglot-data', batchsz=args.task_num, @@ -114,9 +117,10 @@ def main(): meta_opt = optim.Adam(net.parameters(), lr=1e-3) log = [] + test(db, net, epoch=-1, log=log) for epoch in range(10): - train(db, net, meta_opt, epoch, log) - test(db, net, epoch, log) + train(db, net, meta_opt, epoch=epoch, log=log) + test(db, net, epoch=epoch, log=log) plot(log) @@ -130,8 +134,7 @@ def train(db, net, meta_opt, epoch, log): # Sample a batch of support and query images and labels. x_spt, y_spt, x_qry, y_qry = db.next() - task_num, setsz, c_, h, w = x_spt.size() - querysz = x_qry.size(1) + task_num = x_spt.size(0) # TODO: Maybe pull this out into a separate module so it # doesn't have to be duplicated between `train` and `test`? @@ -144,8 +147,8 @@ def train(db, net, meta_opt, epoch, log): qry_accs = [] meta_opt.zero_grad() - net_state_dict = torchopt.extract_state_dict(net) - optim_state_dict = torchopt.extract_state_dict(inner_opt) + net_state_dict = torchopt.extract_state_dict(net, by='reference', detach_buffers=True) + optim_state_dict = torchopt.extract_state_dict(inner_opt, by='reference') for i in range(task_num): # Optimize the likelihood of the support set by taking # gradient steps w.r.t. the model's parameters. @@ -162,28 +165,24 @@ def train(db, net, meta_opt, epoch, log): # These will be used to update the model's meta-parameters. qry_logits = net(x_qry[i]) qry_loss = F.cross_entropy(qry_logits, y_qry[i]) - qry_losses.append(qry_loss.detach()) - qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz - qry_accs.append(qry_acc) - - # Update the model's meta-parameters to optimize the query - # losses across all of the tasks sampled in this batch. - # This unrolls through the gradient steps. - qry_loss.backward() + qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean() + qry_losses.append(qry_loss) + qry_accs.append(qry_acc.item()) torchopt.recover_state_dict(net, net_state_dict) torchopt.recover_state_dict(inner_opt, optim_state_dict) + qry_losses = torch.mean(torch.stack(qry_losses)) + qry_losses.backward() meta_opt.step() - qry_losses = sum(qry_losses) / task_num - qry_accs = 100.0 * sum(qry_accs) / task_num + qry_losses = qry_losses.item() + qry_accs = 100.0 * np.mean(qry_accs) i = epoch + float(batch_idx) / n_train_iter iter_time = time.time() - start_time print( - f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}' + f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}', ) - log.append( { 'epoch': i, @@ -191,7 +190,7 @@ def train(db, net, meta_opt, epoch, log): 'acc': qry_accs, 'mode': 'train', 'time': time.time(), - } + }, ) @@ -208,18 +207,17 @@ def test(db, net, epoch, log): qry_losses = [] qry_accs = [] - for batch_idx in range(n_test_iter): + for _ in range(n_test_iter): x_spt, y_spt, x_qry, y_qry = db.next('test') - task_num, setsz, c_, h, w = x_spt.size() - querysz = x_qry.size(1) + task_num = x_spt.size(0) # TODO: Maybe pull this out into a separate module so it # doesn't have to be duplicated between `train` and `test`? n_inner_iter = 5 - net_state_dict = torchopt.extract_state_dict(net) - optim_state_dict = torchopt.extract_state_dict(inner_opt) + net_state_dict = torchopt.extract_state_dict(net, by='reference', detach_buffers=True) + optim_state_dict = torchopt.extract_state_dict(inner_opt, by='reference') for i in range(task_num): # Optimize the likelihood of the support set by taking # gradient steps w.r.t. the model's parameters. @@ -231,15 +229,17 @@ def test(db, net, epoch, log): # The query loss and acc induced by these parameters. qry_logits = net(x_qry[i]).detach() - qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none') - qry_losses.append(qry_loss.detach()) - qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach()) + qry_loss = F.cross_entropy(qry_logits, y_qry[i]) + qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean() + qry_losses.append(qry_loss.item()) + qry_accs.append(qry_acc.item()) torchopt.recover_state_dict(net, net_state_dict) torchopt.recover_state_dict(inner_opt, optim_state_dict) - qry_losses = torch.cat(qry_losses).mean().item() - qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item() + qry_losses = np.mean(qry_losses) + qry_accs = 100.0 * np.mean(qry_accs) + print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}') log.append( { @@ -248,7 +248,7 @@ def test(db, net, epoch, log): 'acc': qry_accs, 'mode': 'test', 'time': time.time(), - } + }, ) @@ -257,15 +257,16 @@ def plot(log): # script but we are doing it here for brevity. df = pd.DataFrame(log) - fig, ax = plt.subplots(figsize=(6, 4)) + fig, ax = plt.subplots(figsize=(8, 4), dpi=250) train_df = df[df['mode'] == 'train'] test_df = df[df['mode'] == 'test'] ax.plot(train_df['epoch'], train_df['acc'], label='Train') ax.plot(test_df['epoch'], test_df['acc'], label='Test') ax.set_xlabel('Epoch') ax.set_ylabel('Accuracy') - ax.set_ylim(70, 100) - fig.legend(ncol=2, loc='lower right') + ax.set_ylim(85, 100) + ax.set_title('MAML Omniglot') + ax.legend(ncol=2, loc='lower right') fig.tight_layout() fname = 'maml-accs.png' print(f'--- Plotting accuracy to {fname}') diff --git a/examples/iMAML/README.md b/examples/iMAML/README.md new file mode 100644 index 00000000..6208bc81 --- /dev/null +++ b/examples/iMAML/README.md @@ -0,0 +1,23 @@ +# implicit MAML few-shot Omniglot classification-examples + +Code on implicit MAML few-shot Omniglot classification in paper [Meta-Learning with Implicit Gradients](https://arxiv.org/abs/1909.04630) using TorchOpt. We use `torchopt.sgd` as the inner-loop optimizer. + +## Usage + +```bash +### Run +python3 imaml_omniglot.py --inner_steps 5 # use OOP APIs +python3 imaml_omniglot_functional.py --inner_steps 5 # use functional APIs +``` + +## Results + +The figure illustrate the experimental result. + +
+ +
+ +
+ +
diff --git a/examples/iMAML/helpers/omniglot_loaders.py b/examples/iMAML/helpers/omniglot_loaders.py new file mode 100644 index 00000000..52fab28a --- /dev/null +++ b/examples/iMAML/helpers/omniglot_loaders.py @@ -0,0 +1,332 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# These Omniglot loaders are from Jackie Loong's PyTorch MAML implementation: +# https://github.com/dragen1860/MAML-Pytorch +# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot.py +# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglotNShot.py +# ============================================================================== + +import errno +import os + +import numpy as np +import torch +import torch.utils.data as data +import torchvision.transforms as transforms +from PIL import Image + + +class Omniglot(data.Dataset): + """ + The items are ``(filename, category)``. The index of all the categories can be found in + :attr:`idx_classes`. + + Args: + root: the directory where the dataset will be stored + transform: how to transform the input + target_transform: how to transform the target + download: need to download the dataset + """ + + urls = [ + 'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip', + 'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip', + ] + raw_folder = 'raw' + processed_folder = 'processed' + training_file = 'training.pt' + test_file = 'test.pt' + + def __init__(self, root, transform=None, target_transform=None, download=False): + self.root = root + self.transform = transform + self.target_transform = target_transform + + if not self._check_exists(): + if download: + self.download() + else: + raise RuntimeError('Dataset not found. You can use download=True to download it') + + self.all_items = find_classes(os.path.join(self.root, self.processed_folder)) + self.idx_classes = index_classes(self.all_items) + + def __getitem__(self, index): + filename = self.all_items[index][0] + img = str.join('/', [self.all_items[index][2], filename]) + + target = self.idx_classes[self.all_items[index][1]] + if self.transform is not None: + img = self.transform(img) + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self): + return len(self.all_items) + + def _check_exists(self): + return os.path.exists( + os.path.join(self.root, self.processed_folder, 'images_evaluation'), + ) and os.path.exists(os.path.join(self.root, self.processed_folder, 'images_background')) + + def download(self): + import zipfile + + from six.moves import urllib + + if self._check_exists(): + return + + # download files + try: + os.makedirs(os.path.join(self.root, self.raw_folder)) + os.makedirs(os.path.join(self.root, self.processed_folder)) + except OSError as e: + if e.errno == errno.EEXIST: + pass + else: + raise + + for url in self.urls: + print('== Downloading ' + url) + data = urllib.request.urlopen(url) + filename = url.rpartition('/')[2] + file_path = os.path.join(self.root, self.raw_folder, filename) + with open(file_path, 'wb') as f: + f.write(data.read()) + file_processed = os.path.join(self.root, self.processed_folder) + print('== Unzip from ' + file_path + ' to ' + file_processed) + zip_ref = zipfile.ZipFile(file_path, 'r') + zip_ref.extractall(file_processed) + zip_ref.close() + print('Download finished.') + + +def find_classes(root_dir): + retour = [] + for root, _, files in os.walk(root_dir): + for f in files: + if f.endswith('png'): + r = root.split('/') + lr = len(r) + retour.append((f, r[lr - 2] + '/' + r[lr - 1], root)) + print('== Found %d items ' % len(retour)) + return retour + + +def index_classes(items): + idx = {} + for i in items: + if i[1] not in idx: + idx[i[1]] = len(idx) + print('== Found %d classes' % len(idx)) + return idx + + +class OmniglotNShot: + def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=None): + """ + Different from mnistNShot, the + :param root: + :param batchsz: task num + :param n_way: + :param k_shot: + :param k_qry: + :param imgsz: + """ + + self.resize = imgsz + self.rng = rng + self.device = device + if not os.path.isfile(os.path.join(root, 'omniglot.npy')): + # if root/data.npy does not exist, just download it + self.x = Omniglot( + root, + download=True, + transform=transforms.Compose( + [ + lambda x: Image.open(x).convert('L'), + lambda x: x.resize((imgsz, imgsz)), + lambda x: np.reshape(x, (imgsz, imgsz, 1)), + lambda x: np.transpose(x, [2, 0, 1]), + lambda x: x / 255.0, + ], + ), + ) + + # {label: [img1, img2..., img20], label2: [img1, img2, ...], ... 1623 labels in total} + temp = {} + for img, label in self.x: + if label in temp: + temp[label].append(img) + else: + temp[label] = [img] + + self.x = [] + for ( + label, + imgs, + ) in temp.items(): # labels info deserted , each label contains 20imgs + self.x.append(np.array(imgs)) + + # as different class may have different number of imgs + self.x = np.array(self.x).astype(np.float) # [[20 imgs],..., 1623 classes in total] + # each character contains 20 imgs + print('data shape:', self.x.shape) # [1623, 20, 84, 84, 1] + temp = [] # Free memory + # save all dataset into npy file. + np.save(os.path.join(root, 'omniglot.npy'), self.x) + print('write into omniglot.npy.') + else: + # if data.npy exists, just load it. + self.x = np.load(os.path.join(root, 'omniglot.npy')) + print('load from omniglot.npy.') + + # NOTE: do not shuffle here, we must keep training and test set distinct! + self.x_train, self.x_test = self.x[:1200], self.x[1200:] + + self.batchsz = batchsz + self.n_cls = self.x.shape[0] # 1623 + self.n_way = n_way # n way + self.k_shot = k_shot # k shot + self.k_query = k_query # k query + assert (k_shot + k_query) <= 20 + + # save pointer of current read batch in total cache + self.indexes = {'train': 0, 'test': 0} + self.datasets = { + 'train': self.x_train, + 'test': self.x_test, + } # original data cached + print('DB: train', self.x_train.shape, 'test', self.x_test.shape) + + self.datasets_cache = { + 'train': self.load_data_cache(self.datasets['train']), # current epoch data cached + 'test': self.load_data_cache(self.datasets['test']), + } + + def normalization(self): + """ + Normalizes our data, to have a mean of 0 and sdt of 1 + """ + self.mean = np.mean(self.x_train) + self.std = np.std(self.x_train) + self.max = np.max(self.x_train) + self.min = np.min(self.x_train) + self.x_train = (self.x_train - self.mean) / self.std + self.x_test = (self.x_test - self.mean) / self.std + + self.mean = np.mean(self.x_train) + self.std = np.std(self.x_train) + self.max = np.max(self.x_train) + self.min = np.min(self.x_train) + + def load_data_cache(self, data_pack): + """ + Collects several batches data for N-shot learning + :param data_pack: [cls_num, 20, 84, 84, 1] + :return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks + """ + + # take 5 way 1 shot as example: 5 * 1 + setsz = self.k_shot * self.n_way + querysz = self.k_query * self.n_way + data_cache = [] + + for _sample in range(10): # num of episodes + x_spts, y_spts, x_qrys, y_qrys = [], [], [], [] + for _ in range(self.batchsz): # one batch means one set + x_spt, y_spt, x_qry, y_qry = [], [], [], [] + selected_cls = self.rng.choice(data_pack.shape[0], self.n_way, False) + + for j, cur_class in enumerate(selected_cls): + selected_img = self.rng.choice(20, self.k_shot + self.k_query, False) + + # meta-training and meta-test + x_spt.append(data_pack[cur_class][selected_img[: self.k_shot]]) + x_qry.append(data_pack[cur_class][selected_img[self.k_shot :]]) + y_spt.append([j for _ in range(self.k_shot)]) + y_qry.append([j for _ in range(self.k_query)]) + + # shuffle inside a batch + perm = self.rng.permutation(self.n_way * self.k_shot) + x_spt = np.array(x_spt).reshape( + self.n_way * self.k_shot, + 1, + self.resize, + self.resize, + )[perm] + y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm] + perm = self.rng.permutation(self.n_way * self.k_query) + x_qry = np.array(x_qry).reshape( + self.n_way * self.k_query, + 1, + self.resize, + self.resize, + )[perm] + y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm] + + # append [sptsz, 1, 84, 84] => [b, setsz, 1, 84, 84] + x_spts.append(x_spt) + y_spts.append(y_spt) + x_qrys.append(x_qry) + y_qrys.append(y_qry) + + x_spts = np.array(x_spts, dtype=np.float32).reshape( + self.batchsz, + setsz, + 1, + self.resize, + self.resize, + ) # [b, setsz, 1, 84, 84] + y_spts = np.array(y_spts, dtype=np.int).reshape( + self.batchsz, + setsz, + ) # [b, qrysz, 1, 84, 84] + x_qrys = np.array(x_qrys, dtype=np.float32).reshape( + self.batchsz, + querysz, + 1, + self.resize, + self.resize, + ) + y_qrys = np.array(y_qrys, dtype=np.int).reshape(self.batchsz, querysz) + + x_spts, y_spts, x_qrys, y_qrys = ( + torch.from_numpy(z).to(self.device) for z in [x_spts, y_spts, x_qrys, y_qrys] + ) + + data_cache.append([x_spts, y_spts, x_qrys, y_qrys]) + + return data_cache + + def next(self, mode='train'): + """ + Gets next batch from the dataset with name. + :param mode: The name of the splitting (one of "train", "val", "test") + :return: + """ + + # update cache if indexes is larger cached num + if self.indexes[mode] >= len(self.datasets_cache[mode]): + self.indexes[mode] = 0 + self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode]) + + next_batch = self.datasets_cache[mode][self.indexes[mode]] + self.indexes[mode] += 1 + + return next_batch diff --git a/examples/iMAML/imaml-accs-functional.png b/examples/iMAML/imaml-accs-functional.png new file mode 100644 index 00000000..34922bc0 Binary files /dev/null and b/examples/iMAML/imaml-accs-functional.png differ diff --git a/examples/iMAML/imaml-accs.png b/examples/iMAML/imaml-accs.png new file mode 100644 index 00000000..1a6a5636 Binary files /dev/null and b/examples/iMAML/imaml-accs.png differ diff --git a/examples/iMAML/imaml_omniglot.py b/examples/iMAML/imaml_omniglot.py new file mode 100644 index 00000000..1db08427 --- /dev/null +++ b/examples/iMAML/imaml_omniglot.py @@ -0,0 +1,290 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +This example shows how to use TorchOpt to do iMAML-GD (see [1] for more details) +for few-shot Omniglot classification. + +[1] Rajeswaran, A., Finn, C., Kakade, S. M., & Levine, S. (2019). + Meta-learning with implicit gradients. In Advances in Neural Information Processing Systems (pp. 113-124). + https://arxiv.org/abs/1909.04630 +""" + +import argparse +import time + +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F + +import torchopt +from helpers.omniglot_loaders import OmniglotNShot +from torchopt.diff.implicit import ImplicitMetaGradientModule + + +mpl.use('Agg') +plt.style.use('bmh') + + +class InnerNet( + ImplicitMetaGradientModule, + linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0), +): + def __init__(self, meta_net, n_inner_iter, reg_param): + super().__init__() + self.meta_net = meta_net + self.net = torchopt.module_clone(meta_net, by='deepcopy', detach_buffers=True) + self.n_inner_iter = n_inner_iter + self.reg_param = reg_param + self.reset_parameters() + + def reset_parameters(self): + with torch.no_grad(): + for p1, p2 in zip(self.parameters(), self.meta_parameters()): + p1.data.copy_(p2.data) + p1.detach_().requires_grad_() + + def forward(self, x): + return self.net(x) + + def objective(self, x, y): + y_pred = self(x) + loss = F.cross_entropy(y_pred, y) + regularization_loss = 0 + for p1, p2 in zip(self.parameters(), self.meta_parameters()): + regularization_loss += 0.5 * self.reg_param * torch.sum(torch.square(p1 - p2)) + return loss + regularization_loss + + def solve(self, x, y): + params = tuple(self.parameters()) + inner_optim = torchopt.SGD(params, lr=1e-1) + with torch.enable_grad(): + # Temporarily enable gradient computation for conducting the optimization + for _ in range(self.n_inner_iter): + loss = self.objective(x, y) + inner_optim.zero_grad() + loss.backward(inputs=params) + inner_optim.step() + return self + + +def main(): + argparser = argparse.ArgumentParser() + argparser.add_argument('--n_way', type=int, help='n way', default=5) + argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=5) + argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=5) + argparser.add_argument('--inner_steps', type=int, help='number of inner steps', default=5) + argparser.add_argument( + '--reg_params', + type=float, + help='regularization parameters', + default=2.0, + ) + argparser.add_argument( + '--task_num', + type=int, + help='meta batch size, namely task num', + default=16, + ) + argparser.add_argument('--seed', type=int, help='random seed', default=1) + args = argparser.parse_args() + + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + np.random.seed(args.seed) + rng = np.random.default_rng(args.seed) + + # Set up the Omniglot loader. + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + db = OmniglotNShot( + '/tmp/omniglot-data', + batchsz=args.task_num, + n_way=args.n_way, + k_shot=args.k_spt, + k_query=args.k_qry, + imgsz=28, + rng=rng, + device=device, + ) + + # Create a vanilla PyTorch neural network. + net = nn.Sequential( + nn.Conv2d(1, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True, track_running_stats=False), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True, track_running_stats=False), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True, track_running_stats=False), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Flatten(), + nn.Linear(64, args.n_way), + ).to(device) + + # We will use Adam to (meta-)optimize the initial parameters + # to be adapted. + net.train() + meta_opt = torchopt.Adam(net.parameters(), lr=1e-3) + + log = [] + test(db, net, epoch=-1, log=log, args=args) + for epoch in range(10): + train(db, net, meta_opt, epoch, log, args) + test(db, net, epoch, log, args) + plot(log) + + +def train(db, net, meta_opt, epoch, log, args): + n_train_iter = db.x_train.shape[0] // db.batchsz + n_inner_iter = args.inner_steps + reg_param = args.reg_params + task_num = args.task_num + + inner_nets = [InnerNet(net, n_inner_iter, reg_param) for _ in range(task_num)] + for batch_idx in range(n_train_iter): + start_time = time.time() + # Sample a batch of support and query images and labels. + x_spt, y_spt, x_qry, y_qry = db.next() + + qry_losses = [] + qry_accs = [] + meta_opt.zero_grad() + + for i in range(task_num): + # Optimize the likelihood of the support set by taking + # gradient steps w.r.t. the model's parameters. + # This adapts the model's meta-parameters to the task. + + inner_net = inner_nets[i] + inner_net.reset_parameters() + optimal_inner_net = inner_net.solve(x_spt[i], y_spt[i]) + + # The final set of adapted parameters will induce some + # final loss and accuracy on the query dataset. + # These will be used to update the model's meta-parameters. + qry_logits = optimal_inner_net(x_qry[i]) + qry_loss = F.cross_entropy(qry_logits, y_qry[i]) + qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean() + qry_losses.append(qry_loss) + qry_accs.append(qry_acc.item()) + + qry_losses = torch.mean(torch.stack(qry_losses)) + qry_losses.backward() + meta_opt.step() + qry_losses = qry_losses.item() + qry_accs = 100.0 * np.mean(qry_accs) + i = epoch + float(batch_idx) / n_train_iter + iter_time = time.time() - start_time + + print( + f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}', + ) + log.append( + { + 'epoch': i, + 'loss': qry_losses, + 'acc': qry_accs, + 'mode': 'train', + 'time': time.time(), + }, + ) + + +def test(db, net, epoch, log, args): + # Crucially in our testing procedure here, we do *not* fine-tune + # the model during testing for simplicity. + # Most research papers using MAML for this task do an extra + # stage of fine-tuning here that should be added if you are + # adapting this code for research. + n_test_iter = db.x_test.shape[0] // db.batchsz + + qry_losses = [] + qry_accs = [] + + # TODO: Maybe pull this out into a separate module so it + # doesn't have to be duplicated between `train` and `test`? + n_inner_iter = args.inner_steps + reg_param = args.reg_params + + for _ in range(n_test_iter): + x_spt, y_spt, x_qry, y_qry = db.next('test') + + task_num = x_spt.size(0) + + for i in range(task_num): + # Optimize the likelihood of the support set by taking + # gradient steps w.r.t. the model's parameters. + # This adapts the model's meta-parameters to the task. + + inner_net = InnerNet(net, n_inner_iter, reg_param) + with torch.no_grad(): + optimal_inner_net = inner_net.solve(x_spt[i], y_spt[i]) + + # The query loss and acc induced by these parameters. + qry_logits = optimal_inner_net(x_qry[i]) + qry_loss = F.cross_entropy(qry_logits, y_qry[i]) + qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean() + qry_losses.append(qry_loss.item()) + qry_accs.append(qry_acc.item()) + + qry_losses = np.mean(qry_losses) + qry_accs = 100.0 * np.mean(qry_accs) + + print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}') + log.append( + { + 'epoch': epoch + 1, + 'loss': qry_losses, + 'acc': qry_accs, + 'mode': 'test', + 'time': time.time(), + }, + ) + + +def plot(log): + # Generally you should pull your plotting code out of your training + # script but we are doing it here for brevity. + df = pd.DataFrame(log) + + fig, ax = plt.subplots(figsize=(8, 4), dpi=250) + train_df = df[df['mode'] == 'train'] + test_df = df[df['mode'] == 'test'] + ax.plot(train_df['epoch'], train_df['acc'], label='Train') + ax.plot(test_df['epoch'], test_df['acc'], label='Test') + ax.set_xlabel('Epoch') + ax.set_ylabel('Accuracy') + ax.set_ylim(80, 100) + ax.set_title('iMAML Omniglot') + ax.legend(ncol=2, loc='lower right') + fig.tight_layout() + fname = 'imaml-accs.png' + print(f'--- Plotting accuracy to {fname}') + fig.savefig(fname) + plt.close(fig) + + +if __name__ == '__main__': + main() diff --git a/examples/iMAML/imaml_omniglot_functional.py b/examples/iMAML/imaml_omniglot_functional.py new file mode 100644 index 00000000..7bc1e9da --- /dev/null +++ b/examples/iMAML/imaml_omniglot_functional.py @@ -0,0 +1,340 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +This example shows how to use TorchOpt to do iMAML-GD (see [1] for more details) +for few-shot Omniglot classification. + +[1] Rajeswaran, A., Finn, C., Kakade, S. M., & Levine, S. (2019). + Meta-learning with implicit gradients. In Advances in Neural Information Processing Systems (pp. 113-124). + https://arxiv.org/abs/1909.04630 +""" + +import argparse +import time + +import functorch +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F + +import torchopt +from helpers.omniglot_loaders import OmniglotNShot +from torchopt import pytree + + +mpl.use('Agg') +plt.style.use('bmh') + + +def main(): + argparser = argparse.ArgumentParser() + argparser.add_argument('--n_way', type=int, help='n way', default=5) + argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=5) + argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=5) + argparser.add_argument('--inner_steps', type=int, help='number of inner steps', default=5) + argparser.add_argument( + '--reg_params', + type=float, + help='regularization parameters', + default=2.0, + ) + argparser.add_argument( + '--task_num', + type=int, + help='meta batch size, namely task num', + default=16, + ) + argparser.add_argument('--seed', type=int, help='random seed', default=1) + args = argparser.parse_args() + + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + np.random.seed(args.seed) + rng = np.random.default_rng(args.seed) + + # Set up the Omniglot loader. + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + db = OmniglotNShot( + '/tmp/omniglot-data', + batchsz=args.task_num, + n_way=args.n_way, + k_shot=args.k_spt, + k_query=args.k_qry, + imgsz=28, + rng=rng, + device=device, + ) + + # Create a vanilla PyTorch neural network. + net = nn.Sequential( + nn.Conv2d(1, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True, track_running_stats=False), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True, track_running_stats=False), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True, track_running_stats=False), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Flatten(), + nn.Linear(64, args.n_way), + ).to(device) + + # We will use Adam to (meta-)optimize the initial parameters + # to be adapted. + net.train() + fnet, meta_params = model = functorch.make_functional(net) + meta_opt = torchopt.adam(lr=1e-3) + meta_opt_state = meta_opt.init(meta_params) + + log = [] + test(db, model, epoch=-1, log=log, args=args) + for epoch in range(10): + meta_opt, meta_opt_state = train(db, model, (meta_opt, meta_opt_state), epoch, log, args) + test(db, model, epoch, log, args) + plot(log) + + +def train(db, model, meta_opt_and_state, epoch, log, args): + n_train_iter = db.x_train.shape[0] // db.batchsz + fnet, meta_params = model + meta_opt, meta_opt_state = meta_opt_and_state + # Given this module we've created, rip out the parameters and buffers + # and return a functional version of the module. `fnet` is stateless + # and can be called with `fnet(params, buffers, args, kwargs)` + # fnet, params, buffers = functorch.make_functional_with_buffers(net) + + for batch_idx in range(n_train_iter): + start_time = time.time() + # Sample a batch of support and query images and labels. + x_spt, y_spt, x_qry, y_qry = db.next() + + task_num = x_spt.size(0) + + n_inner_iter = args.inner_steps + reg_param = args.reg_params + + qry_losses = [] + qry_accs = [] + + for i in range(task_num): + # Optimize the likelihood of the support set by taking + # gradient steps w.r.t. the model's parameters. + # This adapts the model's meta-parameters to the task. + + init_params = pytree.tree_map( + lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), + meta_params, + ) + optimal_params = train_imaml_inner_solver( + init_params, + meta_params, + (x_spt[i], y_spt[i]), + (fnet, n_inner_iter, reg_param), + ) + # The final set of adapted parameters will induce some + # final loss and accuracy on the query dataset. + # These will be used to update the model's meta-parameters. + qry_logits = fnet(optimal_params, x_qry[i]) + qry_loss = F.cross_entropy(qry_logits, y_qry[i]) + qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean() + qry_losses.append(qry_loss) + qry_accs.append(qry_acc.item()) + + qry_losses = torch.mean(torch.stack(qry_losses)) + meta_grads = torch.autograd.grad(qry_losses, meta_params) + meta_updates, meta_opt_state = meta_opt.update(meta_grads, meta_opt_state) + meta_params = torchopt.apply_updates(meta_params, meta_updates) + qry_losses = qry_losses.item() + qry_accs = 100.0 * np.mean(qry_accs) + i = epoch + float(batch_idx) / n_train_iter + iter_time = time.time() - start_time + + print( + f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}', + ) + log.append( + { + 'epoch': i, + 'loss': qry_losses, + 'acc': qry_accs, + 'mode': 'train', + 'time': time.time(), + }, + ) + + return (meta_opt, meta_opt_state) + + +def test(db, model, epoch, log, args): + # Crucially in our testing procedure here, we do *not* fine-tune + # the model during testing for simplicity. + # Most research papers using MAML for this task do an extra + # stage of fine-tuning here that should be added if you are + # adapting this code for research. + fnet, meta_params = model + n_test_iter = db.x_test.shape[0] // db.batchsz + + n_inner_iter = args.inner_steps + reg_param = args.reg_params + qry_losses = [] + qry_accs = [] + + for _ in range(n_test_iter): + x_spt, y_spt, x_qry, y_qry = db.next('test') + + task_num = x_spt.size(0) + + for i in range(task_num): + # Optimize the likelihood of the support set by taking + # gradient steps w.r.t. the model's parameters. + # This adapts the model's meta-parameters to the task. + + init_params = pytree.tree_map( + lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), + meta_params, + ) + optimal_params = test_imaml_inner_solver( + init_params, + meta_params, + (x_spt[i], y_spt[i]), + (fnet, n_inner_iter, reg_param), + ) + + # The query loss and acc induced by these parameters. + qry_logits = fnet(optimal_params, x_qry[i]) + qry_loss = F.cross_entropy(qry_logits, y_qry[i]) + qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean() + qry_losses.append(qry_loss.item()) + qry_accs.append(qry_acc.item()) + + qry_losses = np.mean(qry_losses) + qry_accs = 100.0 * np.mean(qry_accs) + + print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}') + log.append( + { + 'epoch': epoch + 1, + 'loss': qry_losses, + 'acc': qry_accs, + 'mode': 'test', + 'time': time.time(), + }, + ) + + +def imaml_objective(params, meta_params, data, aux): + x_spt, y_spt = data + fnet, n_inner_iter, reg_param = aux + y_pred = fnet(params, x_spt) + regularization_loss = 0 + for p1, p2 in zip(params, meta_params): + regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2)) + loss = F.cross_entropy(y_pred, y_spt) + regularization_loss + return loss + + +@torchopt.diff.implicit.custom_root( + functorch.grad(imaml_objective, argnums=0), + argnums=1, + has_aux=False, + solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0), +) +def train_imaml_inner_solver(params, meta_params, data, aux): + x_spt, y_spt = data + fnet, n_inner_iter, reg_param = aux + # Initial functional optimizer based on TorchOpt + inner_opt = torchopt.sgd(lr=1e-1) + inner_opt_state = inner_opt.init(params) + with torch.enable_grad(): + # Temporarily enable gradient computation for conducting the optimization + for _ in range(n_inner_iter): + pred = fnet(params, x_spt) + loss = F.cross_entropy(pred, y_spt) # compute loss + # Compute regularization loss + regularization_loss = 0 + for p1, p2 in zip(params, meta_params): + regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2)) + final_loss = loss + regularization_loss + grads = torch.autograd.grad(final_loss, params) # compute gradients + updates, inner_opt_state = inner_opt.update( + grads, + inner_opt_state, + inplace=True, + ) # get updates + params = torchopt.apply_updates(params, updates, inplace=True) + return params + + +def test_imaml_inner_solver(params, meta_params, data, aux): + x_spt, y_spt = data + fnet, n_inner_iter, reg_param = aux + # Initial functional optimizer based on TorchOpt + inner_opt = torchopt.sgd(lr=1e-1) + inner_opt_state = inner_opt.init(params) + with torch.enable_grad(): + # Temporarily enable gradient computation for conducting the optimization + for _ in range(n_inner_iter): + pred = fnet(params, x_spt) + loss = F.cross_entropy(pred, y_spt) # compute loss + # Compute regularization loss + regularization_loss = 0 + for p1, p2 in zip(params, meta_params): + regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2)) + final_loss = loss + regularization_loss + grads = torch.autograd.grad(final_loss, params) # compute gradients + updates, inner_opt_state = inner_opt.update( + grads, + inner_opt_state, + inplace=True, + ) # get updates + params = torchopt.apply_updates(params, updates, inplace=True) + return params + + +def plot(log): + # Generally you should pull your plotting code out of your training + # script but we are doing it here for brevity. + df = pd.DataFrame(log) + + fig, ax = plt.subplots(figsize=(8, 4), dpi=250) + train_df = df[df['mode'] == 'train'] + test_df = df[df['mode'] == 'test'] + ax.plot(train_df['epoch'], train_df['acc'], label='Train') + ax.plot(test_df['epoch'], test_df['acc'], label='Test') + ax.set_xlabel('Epoch') + ax.set_ylabel('Accuracy') + ax.set_ylim(80, 100) + ax.set_title('iMAML Omniglot (Functional)') + ax.legend(ncol=2, loc='lower right') + fig.tight_layout() + fname = 'imaml-accs-functional.png' + print(f'--- Plotting accuracy to {fname}') + fig.savefig(fname) + plt.close(fig) + + +if __name__ == '__main__': + main() diff --git a/examples/requirements.txt b/examples/requirements.txt index 66636aad..48945c62 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -1,7 +1,6 @@ ---extra-index-url https://download.pytorch.org/whl/cu116 -torch >= 1.12 +--extra-index-url https://download.pytorch.org/whl/cu121 +torch >= 2.0 torchvision -functorch >= 0.2 --requirement ../requirements.txt @@ -12,3 +11,4 @@ seaborn torchviz torchrl pillow +setproctitle diff --git a/examples/visualize.py b/examples/visualize.py index 56de2bd5..067fa511 100644 --- a/examples/visualize.py +++ b/examples/visualize.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -66,7 +66,8 @@ def draw_torchopt(): loss = F.mse_loss(pred, torch.ones_like(pred)) # draw computation graph torchopt.visual.make_dot(loss, [net_state_0, net_state_1, {meta_param: 'meta_param'}]).render( - 'torchopt_graph', format='svg' + 'torchopt_graph', + format='svg', ) diff --git a/image/TorchOpt.png b/image/TorchOpt.png deleted file mode 100644 index 04a90032..00000000 Binary files a/image/TorchOpt.png and /dev/null differ diff --git a/image/diffmode.png b/image/diffmode.png new file mode 100644 index 00000000..e33df7a9 Binary files /dev/null and b/image/diffmode.png differ diff --git a/image/time.png b/image/time.png deleted file mode 100644 index 6d246d2c..00000000 Binary files a/image/time.png and /dev/null differ diff --git a/image/torchviz_torchopt.jpg b/image/torchviz-vs-torchopt.jpg similarity index 100% rename from image/torchviz_torchopt.jpg rename to image/torchviz-vs-torchopt.jpg diff --git a/include/adam_op/adam_op.h b/include/adam_op/adam_op.h index 8b7ae2bf..2d0abcd3 100644 --- a/include/adam_op/adam_op.h +++ b/include/adam_op/adam_op.h @@ -1,4 +1,4 @@ -// Copyright 2022 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2024 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ // ============================================================================= #pragma once + #include #include @@ -67,9 +68,10 @@ TensorArray<2> adamBackwardUpdates(const torch::Tensor &dupdates, const torch::Tensor &new_nu, const pyfloat_t b1, const pyfloat_t b2, + const pyfloat_t eps_root, const pyuint_t count); -void buildSubmodule(py::module &mod); // NOLINT +void buildSubmodule(py::module &mod); // NOLINT[runtime/references] } // namespace adam_op } // namespace torchopt diff --git a/include/adam_op/adam_op_impl_cpu.h b/include/adam_op/adam_op_impl_cpu.h index 3e8da376..4d54377e 100644 --- a/include/adam_op/adam_op_impl_cpu.h +++ b/include/adam_op/adam_op_impl_cpu.h @@ -1,4 +1,4 @@ -// Copyright 2022 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2024 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ // ============================================================================= #pragma once + #include #include @@ -63,6 +64,7 @@ TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor &dupdates, const torch::Tensor &new_nu, const pyfloat_t b1, const pyfloat_t b2, + const pyfloat_t eps_root, const pyuint_t count); } // namespace adam_op } // namespace torchopt diff --git a/include/adam_op/adam_op_impl_cuda.cuh b/include/adam_op/adam_op_impl_cuda.cuh index a7ddb937..17002b36 100644 --- a/include/adam_op/adam_op_impl_cuda.cuh +++ b/include/adam_op/adam_op_impl_cuda.cuh @@ -1,4 +1,4 @@ -// Copyright 2022 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2024 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ // ============================================================================= #pragma once + #include #include @@ -63,6 +64,7 @@ TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates, const torch::Tensor &new_nu, const pyfloat_t b1, const pyfloat_t b2, + const pyfloat_t eps_root, const pyuint_t count); } // namespace adam_op } // namespace torchopt diff --git a/include/common.h b/include/common.h index 5353e48e..256b0ca1 100644 --- a/include/common.h +++ b/include/common.h @@ -1,4 +1,4 @@ -// Copyright 2022 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2024 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ // ============================================================================= #pragma once + #include #include diff --git a/include/utils.h b/include/utils.h index 714f98d4..cefabfac 100644 --- a/include/utils.h +++ b/include/utils.h @@ -1,4 +1,4 @@ -// Copyright 2022 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2024 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ // ============================================================================= #pragma once + #include #include diff --git a/pyproject.toml b/pyproject.toml index 47af443f..d343e04a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,14 +1,17 @@ # Package ###################################################################### [build-system] -requires = ["setuptools", "torch >= 1.12", "numpy", "pybind11"] +# Sync with project.dependencies +requires = ["setuptools", "torch >= 2.0", "numpy", "pybind11 >= 2.11.1"] build-backend = "setuptools.build_meta" [project] name = "torchopt" -description = "A Jax-style optimizer for PyTorch." +description = "An efficient library for differentiable optimization for PyTorch." readme = "README.md" -requires-python = ">= 3.7" +# Change this if wheels for `torch` is available +# Search "requires-python" and update all corresponding items +requires-python = ">= 3.8" authors = [ { name = "TorchOpt Contributors" }, { name = "Jie Ren", email = "jieren9806@gmail.com" }, @@ -19,7 +22,7 @@ authors = [ license = { text = "Apache License, Version 2.0" } keywords = [ "PyTorch", - "functorch", + "FuncTorch", "JAX", "Meta-Learning", "Optimizer", @@ -29,12 +32,17 @@ keywords = [ classifiers = [ "Development Status :: 4 - Beta", "License :: OSI Approved :: Apache Software License", + # Sync with requires-python "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Operating System :: Microsoft :: Windows", "Operating System :: POSIX :: Linux", + "Operating System :: MacOS", "Environment :: GPU", "Environment :: GPU :: NVIDIA CUDA", "Intended Audience :: Developers", @@ -44,8 +52,9 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dependencies = [ - "torch >= 1.12", - "optree", + # See also build-system.requires and project.requires-python + "torch >= 2.0", + "optree >= 0.4.1", "numpy", "graphviz", "typing-extensions", @@ -61,43 +70,58 @@ Documentation = "https://torchopt.readthedocs.io" [project.optional-dependencies] lint = [ "isort", - "black >= 22.6.0", - "pylint", + "black[jupyter]", + "pylint[spelling]", "mypy", "flake8", "flake8-bugbear", - "doc8 < 1.0.0a0", - "pydocstyle", + "flake8-comprehensions", + "flake8-docstrings", + "flake8-pyi", + "flake8-simplify", + "ruff", + "doc8", + "pydocstyle[toml]", "pyenchant", "cpplint", "pre-commit", ] test = [ - 'functorch >= 0.2', - 'pytest', - 'pytest-cov', - 'pytest-xdist', + "pytest", + "pytest-cov", + "pytest-xdist", + "jax[cpu] >= 0.4; platform_system != 'Windows'", + "jaxopt; platform_system != 'Windows'", + "optax; platform_system != 'Windows'", ] +[tool.setuptools] +include-package-data = true + [tool.setuptools.packages.find] include = ["torchopt", "torchopt.*"] +[tool.setuptools.package-data] +torchopt = ['*.so', '*.pyd'] + # Wheel builder ################################################################ # Reference: https://cibuildwheel.readthedocs.io [tool.cibuildwheel] -archs = ["x86_64"] +archs = ["auto64"] build = "*manylinux*" -skip = "pp*" +skip = "pp* *musllinux*" build-frontend = "pip" build-verbosity = 3 environment.USE_FP16 = "ON" environment.CUDACXX = "/usr/local/cuda/bin/nvcc" environment.TORCH_CUDA_ARCH_LIST = "Common" -environment.DEFAULT_CUDA_VERSION = "11.6" -environment.DEFAULT_TEST_TORCH_SPECS = "cpu cu113 cu116" +environment.DEFAULT_CUDA_VERSION = "12.1" +environment.DEFAULT_TEST_TORCH_SPECS = "cpu cu118" environment-pass = ["CUDA_VERSION", "TEST_TORCH_SPECS"] container-engine = "docker" +test-extras = ["test"] +[tool.cibuildwheel.linux] before-all = """ CUDA_VERSION="${CUDA_VERSION:-"${DEFAULT_CUDA_VERSION}"}" if [[ "${CUDA_VERSION}" == "None" || "${CUDA_VERSION}" == "none" ]]; then @@ -111,32 +135,8 @@ before-all = """ yum install -y nvidia-driver-latest-libs "cuda-minimal-build-${CUDA_PKG_SUFFIX}" fi echo "cat torchopt/version.py"; cat torchopt/version.py - """ -test-extras = ["test"] -test-command = """ - SITE_PACKAGES="$(python -c 'print(__import__("sysconfig").get_path("purelib"))')" - TORCH_LIB_PATH="${SITE_PACKAGES}/torch/lib" - echo "LD_LIBRARY_PATH='${LD_LIBRARY_PATH}'" - echo "ls ${TORCH_LIB_PATH}"; ls -lh "${TORCH_LIB_PATH}" - find "${SITE_PACKAGES}/torchopt" -name "*.so" -print0 | - xargs -0 -I '{}' bash -c "echo 'ldd {}'; ldd '{}'; echo 'patchelf --print-rpath {}'; patchelf --print-rpath '{}'" - make -C "{project}" test || exit 1 - TORCH_VERSION="$(python -c 'print(__import__("torch").__version__.partition("+")[0])')" - TEST_TORCH_SPECS="${TEST_TORCH_SPECS:-"${DEFAULT_TEST_TORCH_SPECS}"}" - for spec in ${TEST_TORCH_SPECS}; do - python -m pip uninstall -y torch - export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/${spec}" - echo "PIP_EXTRA_INDEX_URL='${PIP_EXTRA_INDEX_URL}'" - python -m pip install "torch==${TORCH_VERSION}" - echo "ls ${TORCH_LIB_PATH}"; ls -lh "${TORCH_LIB_PATH}" - find "${SITE_PACKAGES}/torchopt" -name "*.so" -print0 | - xargs -0 -I '{}' bash -c "echo 'ldd {}'; ldd '{}'; echo 'patchelf --print-rpath {}'; patchelf --print-rpath '{}'" - make -C "{project}" test || exit 1 - done - rm -rf ~/.pip/cache ~/.cache/pip - """ - -[tool.cibuildwheel.linux] + touch .first-python +""" repair-wheel-command = """ python -m pip install -r requirements.txt SITE_PACKAGES="$(python -c 'print(__import__("sysconfig").get_path("purelib"))')" @@ -148,35 +148,64 @@ repair-wheel-command = """ python -m auditwheel lddtree "{wheel}" python -m auditwheel repair --no-copy-site-libs --wheel-dir="{dest_dir}" "{wheel}" ) - """ +""" +test-command = """ + SITE_PACKAGES="$(python -c 'print(__import__("sysconfig").get_path("purelib"))')" + TORCH_LIB_PATH="${SITE_PACKAGES}/torch/lib" + echo "LD_LIBRARY_PATH='${LD_LIBRARY_PATH}'" + echo "ls ${TORCH_LIB_PATH}"; ls -lh "${TORCH_LIB_PATH}" + find "${SITE_PACKAGES}/torchopt" -name "*.so" -print0 | + xargs -0 -I '{}' bash -c "echo 'ldd {}'; ldd '{}'; echo 'patchelf --print-rpath {}'; patchelf --print-rpath '{}'" + make -C "{project}" test || exit 1 + TORCH_VERSION="$(python -c 'print(__import__("torch").__version__.partition("+")[0])')" + if [[ -f .first-python ]]; then + TEST_TORCH_SPECS="${TEST_TORCH_SPECS:-"${DEFAULT_TEST_TORCH_SPECS}"}" + for spec in ${TEST_TORCH_SPECS}; do + python -m pip uninstall -y torch + export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/${spec}" + echo "PIP_EXTRA_INDEX_URL='${PIP_EXTRA_INDEX_URL}'" + python -m pip install "torch==${TORCH_VERSION}" + echo "ls ${TORCH_LIB_PATH}"; ls -lh "${TORCH_LIB_PATH}" + find "${SITE_PACKAGES}/torchopt" -name "*.so" -print0 | + xargs -0 -I '{}' bash -c "echo 'ldd {}'; ldd '{}'; echo 'patchelf --print-rpath {}'; patchelf --print-rpath '{}'" + make -C "{project}" test || exit 1 + done + rm -f .first-python + fi + rm -rf ~/.pip/cache ~/.cache/pip +""" # Linter tools ################################################################# [tool.black] -safe = true line-length = 100 skip-string-normalization = true -target-version = ["py37", "py38", "py39", "py310"] +# Sync with requires-python +target-version = ["py38"] [tool.isort] +atomic = true profile = "black" src_paths = ["torchopt", "examples", "tests"] +extra_standard_library = ["typing_extensions"] indent = 4 line_length = 100 lines_after_imports = 2 multi_line_output = 3 [tool.mypy] -allow_redefinition = true -check_untyped_defs = true -disallow_incomplete_defs = false -disallow_untyped_defs = false -ignore_missing_imports = true -no_implicit_optional = true +# Sync with requires-python +python_version = "3.8" pretty = true show_error_codes = true show_error_context = true show_traceback = true +allow_redefinition = true +check_untyped_defs = true +disallow_incomplete_defs = true +disallow_untyped_defs = true +ignore_missing_imports = true +no_implicit_optional = true strict_equality = true strict_optional = true warn_no_return = true @@ -190,3 +219,116 @@ convention = "google" [tool.doc8] max-line-length = 500 + +[tool.codespell] +ignore-words = "docs/source/spelling_wordlist.txt" + +[tool.ruff] +# Sync with requires-python +target-version = "py38" +line-length = 100 +output-format = "full" +src = ["torchopt", "tests"] +extend-exclude = ["examples"] + +[tool.ruff.lint] +select = [ + "E", "W", # pycodestyle + "F", # pyflakes + "C90", # mccabe + "UP", # pyupgrade + "ANN", # flake8-annotations + "S", # flake8-bandit + "BLE", # flake8-blind-except + "B", # flake8-bugbear + "COM", # flake8-commas + "C4", # flake8-comprehensions + "EXE", # flake8-executable + "FA", # flake8-future-annotations + "LOG", # flake8-logging + "ISC", # flake8-implicit-str-concat + "INP", # flake8-no-pep420 + "PIE", # flake8-pie + "PYI", # flake8-pyi + "Q", # flake8-quotes + "RSE", # flake8-raise + "RET", # flake8-return + "SIM", # flake8-simplify + "TID", # flake8-tidy-imports + "TCH", # flake8-type-checking + "PERF", # perflint + "FURB", # refurb + "TRY", # tryceratops + "RUF", # ruff +] +ignore = [ + # E501: line too long + # W505: doc line too long + # too long docstring due to long example blocks + "E501", + "W505", + # ANN101: missing type annotation for `self` in method + # ANN102: missing type annotation for `cls` in classmethod + "ANN101", + "ANN102", + # ANN401: dynamically typed expressions (typing.Any) are disallowed + "ANN401", + # S101: use of `assert` detected + # internal use and may never raise at runtime + "S101", + # TRY003: avoid specifying long messages outside the exception class + # long messages are necessary for clarity + "TRY003", +] +typing-modules = ["torchopt.typing"] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = [ + "F401", # unused-import +] +"torchopt/pytree.py" = [ + "F401", # unused-import + "F403", # import-star + "F405", # import-star-usage +] +"setup.py" = [ + "ANN", # flake8-annotations +] +"tests/**/*.py" = [ + "ANN", # flake8-annotations + "S", # flake8-bandit + "BLE", # flake8-blind-except +] +"tests/test_import.py" = [ + "B018", # useless-expression + "F401", # unused-import + "F811", # redefined-while-unused +] +"docs/source/conf.py" = [ + "INP001", # flake8-no-pep420 +] + +[tool.ruff.lint.flake8-annotations] +allow-star-arg-any = true + +[tool.ruff.lint.flake8-quotes] +docstring-quotes = "double" +multiline-quotes = "double" +inline-quotes = "single" + +[tool.ruff.lint.flake8-tidy-imports] +ban-relative-imports = "all" + +[tool.ruff.lint.pylint] +allow-magic-value-types = ["int", "str", "float"] + +[tool.pytest.ini_options] +filterwarnings = [ + "error", + 'ignore:Explicitly requested dtype float64 requested in .* is not available, and will be truncated to dtype float32\.:UserWarning', + 'ignore:jax\.numpy\.DeviceArray is deprecated\. Use jax\.Array\.:DeprecationWarning', + 'ignore:(ast\.Str|ast\.NameConstant|Attribute s) is deprecated and will be removed in Python 3\.14:DeprecationWarning', + 'ignore:.*functorch.*deprecate.*:UserWarning', + 'ignore:.*Apple Paravirtual device.*:UserWarning', + 'ignore:.*NVML.*:UserWarning', +] diff --git a/requirements.txt b/requirements.txt index a2ced2f2..a5151c36 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ -torch >= 1.12 -optree +# Sync with project.dependencies +torch >= 2.0 +optree >= 0.4.1 numpy graphviz typing-extensions diff --git a/setup.py b/setup.py index e0df95db..c50ba5ed 100644 --- a/setup.py +++ b/setup.py @@ -1,28 +1,25 @@ +import contextlib import os import pathlib +import platform +import re import shutil import sys +import sysconfig +from importlib.util import module_from_spec, spec_from_file_location -from setuptools import setup +from setuptools import Extension, setup +from setuptools.command.build_ext import build_ext -try: - from pybind11.setup_helpers import Pybind11Extension as Extension - from pybind11.setup_helpers import build_ext -except ImportError: - from setuptools import Extension - from setuptools.command.build_ext import build_ext - HERE = pathlib.Path(__file__).absolute().parent -sys.path.insert(0, str(HERE / 'torchopt')) -import version # noqa - class CMakeExtension(Extension): - def __init__(self, name, source_dir='.', **kwargs): + def __init__(self, name, source_dir='.', target=None, **kwargs): super().__init__(name, sources=[], **kwargs) self.source_dir = os.path.abspath(source_dir) + self.target = target if target is not None else name.rpartition('.')[-1] class cmake_build_ext(build_ext): @@ -31,60 +28,124 @@ def build_extension(self, ext): super().build_extension(ext) return - import pybind11 from torch.utils import cpp_extension cmake = shutil.which('cmake') if cmake is None: raise RuntimeError('Cannot find CMake executable.') - build_temp = pathlib.Path(self.build_temp) + ext_path = pathlib.Path(self.get_ext_fullpath(ext.name)).absolute() + build_temp = pathlib.Path(self.build_temp).absolute() build_temp.mkdir(parents=True, exist_ok=True) config = 'Debug' if self.debug else 'Release' - - extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) - print(self.get_ext_fullpath(ext.name)) - - PYTHON_INCLUDE_DIR = ';'.join(self.include_dirs) - TORCH_INCLUDE_PATH = ';'.join(cpp_extension.include_paths()) - TORCH_LIBRARY_PATH = ';'.join(cpp_extension.library_paths()) - cmake_args = [ f'-DCMAKE_BUILD_TYPE={config}', - f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{config.upper()}={extdir}', - f'-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY_{config.upper()}={self.build_temp}', + f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{config.upper()}={ext_path.parent}', + f'-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY_{config.upper()}={build_temp}', f'-DPYTHON_EXECUTABLE={sys.executable}', - f'-DPYBIND11_CMAKE_DIR={pybind11.get_cmake_dir()}', - f'-DPYTHON_INCLUDE_DIR={PYTHON_INCLUDE_DIR}', - f'-DTORCH_INCLUDE_PATH={TORCH_INCLUDE_PATH}', - f'-DTORCH_LIBRARY_PATH={TORCH_LIBRARY_PATH}', + f'-DPYTHON_INCLUDE_DIR={sysconfig.get_path("platinclude")}', + f'-DTORCH_INCLUDE_PATH={";".join(cpp_extension.include_paths())}', + f'-DTORCH_LIBRARY_PATH={";".join(cpp_extension.library_paths())}', ] - build_args = ['--config', config] + if platform.system() == 'Darwin': + # Cross-compile support for macOS - respect ARCHFLAGS if set + archs = re.findall(r'-arch (\S+)', os.environ.get('ARCHFLAGS', '')) + if archs: + cmake_args.append(f'-DCMAKE_OSX_ARCHITECTURES={";".join(archs)}') + + try: + import pybind11 + + cmake_args.append(f'-DPYBIND11_CMAKE_DIR={pybind11.get_cmake_dir()}') + except ImportError: + pass + build_args = ['--config', config] if ( 'CMAKE_BUILD_PARALLEL_LEVEL' not in os.environ and hasattr(self, 'parallel') and self.parallel ): - build_args.append(f'--parallel={self.parallel}') + build_args.extend(['--parallel', str(self.parallel)]) else: build_args.append('--parallel') + build_args.extend(['--target', ext.target, '--']) + + cwd = os.getcwd() try: os.chdir(build_temp) - self.spawn(['cmake', ext.source_dir] + cmake_args) + self.spawn([cmake, ext.source_dir, *cmake_args]) if not self.dry_run: - self.spawn(['cmake', '--build', '.'] + build_args) + self.spawn([cmake, '--build', '.', *build_args]) finally: - os.chdir(HERE) - - -setup( - version=version.__version__, - package_data={'sharedlib': ['*.so', '*.pyd']}, - include_package_data=True, - cmdclass={'build_ext': cmake_build_ext}, - ext_modules=[CMakeExtension('torchopt._C', source_dir=HERE)], -) + os.chdir(cwd) + + +@contextlib.contextmanager +def vcs_version(name, path): + path = pathlib.Path(path).absolute() + assert path.is_file() + module_spec = spec_from_file_location(name=name, location=path) + assert module_spec is not None + assert module_spec.loader is not None + module = sys.modules.get(name) + if module is None: + module = module_from_spec(module_spec) + sys.modules[name] = module + module_spec.loader.exec_module(module) + + if module.__release__: + yield module + return + + content = None + try: + try: + content = path.read_text(encoding='utf-8') + path.write_text( + data=re.sub( + r"""__version__\s*=\s*('[^']+'|"[^"]+")""", + f'__version__ = {module.__version__!r}', + string=content, + ), + encoding='utf-8', + ) + except OSError: + content = None + + yield module + finally: + if content is not None: + with path.open(mode='wt', encoding='utf-8', newline='') as file: + file.write(content) + + +CIBUILDWHEEL = os.getenv('CIBUILDWHEEL', '0') == '1' +LINUX = platform.system() == 'Linux' +MACOS = platform.system() == 'Darwin' +WINDOWS = platform.system() == 'Windows' +ext_kwargs = { + 'cmdclass': {'build_ext': cmake_build_ext}, + 'ext_modules': [ + CMakeExtension( + 'torchopt._C', + source_dir=HERE, + optional=not (LINUX and CIBUILDWHEEL), + ), + ], +} + +TORCHOPT_NO_EXTENSIONS = bool(os.getenv('TORCHOPT_NO_EXTENSIONS', '')) or WINDOWS or MACOS +if TORCHOPT_NO_EXTENSIONS: + ext_kwargs.clear() + + +with vcs_version(name='torchopt.version', path=(HERE / 'torchopt' / 'version.py')) as version: + setup( + name='torchopt', + version=version.__version__, + **ext_kwargs, + ) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 6e3bebc9..30d18335 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,10 +23,10 @@ endif() list(APPEND torchopt_csrc "${adam_op_src}") -pybind11_add_module(_C THIN_LTO "${torchopt_csrc}") +pybind11_add_module(_C MODULE THIN_LTO "${torchopt_csrc}") target_link_libraries( _C PRIVATE - ${TORCH_LIBRARIES} + "${TORCH_LIBRARIES}" OpenMP::OpenMP_CXX ) diff --git a/src/adam_op/adam_op.cpp b/src/adam_op/adam_op.cpp index 01412126..47f5d7f1 100644 --- a/src/adam_op/adam_op.cpp +++ b/src/adam_op/adam_op.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2024 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -104,11 +104,11 @@ TensorArray<2> adamBackwardMu(const torch::Tensor &dmu, const pyfloat_t b1) { #if defined(__USE_CUDA__) if (dmu.device().is_cuda()) { - return adamBackwardMuCUDA(dmu, updates, mu, b1); + return adamBackwardMuCUDA(dmu.contiguous(), updates, mu, b1); } #endif if (dmu.device().is_cpu()) { - return adamBackwardMuCPU(dmu, updates, mu, b1); + return adamBackwardMuCPU(dmu.contiguous(), updates, mu, b1); } else { throw std::runtime_error("Not implemented"); } @@ -120,11 +120,11 @@ TensorArray<2> adamBackwardNu(const torch::Tensor &dnu, const pyfloat_t b2) { #if defined(__USE_CUDA__) if (dnu.device().is_cuda()) { - return adamBackwardNuCUDA(dnu, updates, nu, b2); + return adamBackwardNuCUDA(dnu.contiguous(), updates, nu, b2); } #endif if (dnu.device().is_cpu()) { - return adamBackwardNuCPU(dnu, updates, nu, b2); + return adamBackwardNuCPU(dnu.contiguous(), updates, nu, b2); } else { throw std::runtime_error("Not implemented"); } @@ -136,20 +136,23 @@ TensorArray<2> adamBackwardUpdates(const torch::Tensor &dupdates, const torch::Tensor &new_nu, const pyfloat_t b1, const pyfloat_t b2, + const pyfloat_t eps_root, const pyuint_t count) { #if defined(__USE_CUDA__) if (dupdates.device().is_cuda()) { - return adamBackwardUpdatesCUDA(dupdates, updates, new_mu, new_nu, b1, b2, count); + return adamBackwardUpdatesCUDA( + dupdates.contiguous(), updates, new_mu, new_nu, b1, b2, eps_root, count); } #endif if (dupdates.device().is_cpu()) { - return adamBackwardUpdatesCPU(dupdates, updates, new_mu, new_nu, b1, b2, count); + return adamBackwardUpdatesCPU( + dupdates.contiguous(), updates, new_mu, new_nu, b1, b2, eps_root, count); } else { throw std::runtime_error("Not implemented"); } } -void buildSubmodule(py::module &mod) { // NOLINT +void buildSubmodule(py::module &mod) { // NOLINT[runtime/references] py::module m = mod.def_submodule("adam_op", "Adam Ops"); m.def("forward_", &adamForwardInplace, @@ -162,19 +165,19 @@ void buildSubmodule(py::module &mod) { // NOLINT py::arg("eps"), py::arg("eps_root"), py::arg("count")); - m.def("forwardMu", + m.def("forward_mu", &adamForwardMu, "Adam forward mu", py::arg("updates"), py::arg("mu"), py::arg("b1")); - m.def("forwardNu", + m.def("forward_nu", &adamForwardNu, "Adam forward nu", py::arg("updates"), py::arg("nu"), py::arg("b2")); - m.def("forwardUpdates", + m.def("forward_updates", &adamForwardUpdates, "Adam forward updates", py::arg("new_mu"), @@ -184,21 +187,21 @@ void buildSubmodule(py::module &mod) { // NOLINT py::arg("eps"), py::arg("eps_root"), py::arg("count")); - m.def("backwardMu", + m.def("backward_mu", &adamBackwardMu, "Adam backward mu", py::arg("dmu"), py::arg("updates"), py::arg("mu"), py::arg("b1")); - m.def("backwardNu", + m.def("backward_nu", &adamBackwardNu, "Adam backward nu", py::arg("dnu"), py::arg("updates"), py::arg("nu"), py::arg("b1")); - m.def("backwardUpdates", + m.def("backward_updates", &adamBackwardUpdates, "Adam backward updates", py::arg("dupdates"), @@ -207,6 +210,7 @@ void buildSubmodule(py::module &mod) { // NOLINT py::arg("new_nu"), py::arg("b1"), py::arg("b2"), + py::arg("eps_root"), py::arg("count")); } diff --git a/src/adam_op/adam_op_impl_cpu.cpp b/src/adam_op/adam_op_impl_cpu.cpp index 82accd8c..9c460685 100644 --- a/src/adam_op/adam_op_impl_cpu.cpp +++ b/src/adam_op/adam_op_impl_cpu.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2024 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -27,6 +27,8 @@ using std::size_t; namespace adam_op { +constexpr size_t MIN_NUMEL_USE_OMP = 1000; + template void adamForwardInplaceCPUKernel(const other_t b1, const other_t inv_one_minus_pow_b1, @@ -38,7 +40,9 @@ void adamForwardInplaceCPUKernel(const other_t b1, scalar_t *__restrict__ updates_ptr, scalar_t *__restrict__ mu_ptr, scalar_t *__restrict__ nu_ptr) { -#pragma omp parallel for num_threads(omp_get_num_procs()) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t mu = mu_ptr[tid]; @@ -46,8 +50,10 @@ void adamForwardInplaceCPUKernel(const other_t b1, const scalar_t mu_out = b1 * mu + (1 - b1) * updates; const scalar_t nu_out = b2 * nu + (1 - b2) * updates * updates; - const scalar_t updates_out = - mu_out * inv_one_minus_pow_b1 / (sqrt(nu_out * inv_one_minus_pow_b2 + eps_root) + eps); + const scalar_t mu_hat = mu_out * inv_one_minus_pow_b1; + const scalar_t nu_hat = nu_out * inv_one_minus_pow_b2; + + const scalar_t updates_out = mu_hat / (sqrt(nu_hat + eps_root) + eps); mu_ptr[tid] = mu_out; nu_ptr[tid] = nu_out; @@ -90,7 +96,9 @@ void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr, const other_t b1, const size_t n, scalar_t *__restrict__ mu_out_ptr) { -#pragma omp parallel for num_threads(omp_get_num_procs()) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t mu = mu_ptr[tid]; @@ -122,12 +130,14 @@ void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr, const other_t b2, const size_t n, scalar_t *__restrict__ nu_out_ptr) { -#pragma omp parallel for num_threads(omp_get_num_procs()) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t nu = nu_ptr[tid]; - const scalar_t nu_out = b2 * nu + (1 - b2) * pow(updates, 2); + const scalar_t nu_out = b2 * nu + (1 - b2) * updates * updates; nu_out_ptr[tid] = nu_out; } } @@ -158,7 +168,9 @@ void adamForwardUpdatesCPUKernel(const scalar_t *__restrict__ new_mu_ptr, const other_t eps_root, const size_t n, scalar_t *__restrict__ updates_out_ptr) { -#pragma omp parallel for num_threads(omp_get_num_procs()) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t new_mu = new_mu_ptr[tid]; const scalar_t new_nu = new_nu_ptr[tid]; @@ -176,14 +188,11 @@ torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu, const pyfloat_t eps_root, const pyuint_t count) { using other_t = pyfloat_t; + const other_t inv_one_minus_pow_b1 = 1 / (1 - std::pow(b1, count)); + const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count)); auto updates_out = torch::empty_like(new_mu); - const other_t one_minus_pow_b1 = 1 - std::pow(b1, count); - const other_t inv_one_minus_pow_b1 = 1 / one_minus_pow_b1; - const other_t one_minus_pow_b2 = 1 - std::pow(b2, count); - const other_t inv_one_minus_pow_b2 = 1 / one_minus_pow_b2; - const size_t n = getTensorPlainSize(new_mu); AT_DISPATCH_SCALAR_TYPES(new_mu.scalar_type(), "adamForwardUpdatesCPU", ([&] { adamForwardUpdatesCPUKernel( @@ -205,7 +214,9 @@ void adamBackwardMuCPUKernel(const scalar_t *__restrict__ dmu_ptr, const size_t n, scalar_t *__restrict__ dupdates_out_ptr, scalar_t *__restrict__ dmu_out_ptr) { -#pragma omp parallel for num_threads(omp_get_num_procs()) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t dmu = dmu_ptr[tid]; @@ -240,7 +251,9 @@ void adamBackwardNuCPUKernel(const scalar_t *__restrict__ dnu_ptr, const size_t n, scalar_t *__restrict__ dupdates_out_ptr, scalar_t *__restrict__ dnu_out_ptr) { -#pragma omp parallel for num_threads(omp_get_num_procs()) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t dnu = dnu_ptr[tid]; const scalar_t updates = updates_ptr[tid]; @@ -279,7 +292,9 @@ void adamBackwardUpdatesCPUKernel(const scalar_t *__restrict__ dupdates_ptr, const size_t n, scalar_t *__restrict__ dnew_mu_out_ptr, scalar_t *__restrict__ dnew_nu_out_ptr) { -#pragma omp parallel for num_threads(omp_get_num_procs()) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t dupdates = dupdates_ptr[tid]; const scalar_t updates = updates_ptr[tid]; @@ -307,16 +322,15 @@ TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor &dupdates, const torch::Tensor &new_nu, const pyfloat_t b1, const pyfloat_t b2, + const pyfloat_t eps_root, const pyuint_t count) { using other_t = pyfloat_t; + const other_t one_minus_pow_b1 = 1 - std::pow(b1, count); + const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count) + eps_root); auto dmu_out = torch::empty_like(new_mu); auto dnu_out = torch::empty_like(new_nu); - const other_t one_minus_pow_b1 = 1 - std::pow(b1, count); - const other_t one_minus_pow_b2 = 1 - std::pow(b2, count); - const other_t inv_one_minus_pow_b2 = 1 / one_minus_pow_b2; - const size_t n = getTensorPlainSize(dupdates); AT_DISPATCH_SCALAR_TYPES(dupdates.scalar_type(), "adamBackwardUpdatesCPU", ([&] { adamBackwardUpdatesCPUKernel( diff --git a/src/adam_op/adam_op_impl_cuda.cu b/src/adam_op/adam_op_impl_cuda.cu index c77d1790..a12eca4f 100644 --- a/src/adam_op/adam_op_impl_cuda.cu +++ b/src/adam_op/adam_op_impl_cuda.cu @@ -1,4 +1,4 @@ -// Copyright 2022 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2024 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -24,7 +24,10 @@ namespace torchopt { namespace adam_op { -template +constexpr int UNROLL_SIZE = 4; +constexpr int BLOCK_SIZE = 256; + +template __global__ void adamForwardInplaceCUDAKernel(const other_t b1, const other_t inv_one_minus_pow_b1, const other_t b2, @@ -35,22 +38,28 @@ __global__ void adamForwardInplaceCUDAKernel(const other_t b1, scalar_t *__restrict__ updates_ptr, scalar_t *__restrict__ mu_ptr, scalar_t *__restrict__ nu_ptr) { - unsigned tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= n) { - return; + const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size; +#pragma unroll + for (int i = 0; i < unroll_size; ++i) { + size_t tid = toffset + i; + if (tid >= n) { + return; + } + const scalar_t updates = updates_ptr[tid]; + const scalar_t mu = mu_ptr[tid]; + const scalar_t nu = nu_ptr[tid]; + + const scalar_t mu_out = b1 * mu + (1 - b1) * updates; + const scalar_t nu_out = b2 * nu + (1 - b2) * updates * updates; + const scalar_t mu_hat = mu_out * inv_one_minus_pow_b1; + const scalar_t nu_hat = nu_out * inv_one_minus_pow_b2; + + const scalar_t updates_out = mu_hat / (sqrt(nu_hat + eps_root) + eps); + + mu_ptr[tid] = mu_out; + nu_ptr[tid] = nu_out; + updates_ptr[tid] = updates_out; } - const scalar_t updates = updates_ptr[tid]; - const scalar_t mu = mu_ptr[tid]; - const scalar_t nu = nu_ptr[tid]; - - const scalar_t mu_out = b1 * mu + (1 - b1) * updates; - const scalar_t nu_out = b2 * nu + (1 - b2) * updates * updates; - const scalar_t updates_out = - mu_out * inv_one_minus_pow_b1 / (sqrt(nu_out * inv_one_minus_pow_b2 + eps_root) + eps); - - mu_ptr[tid] = mu_out; - nu_ptr[tid] = nu_out; - updates_ptr[tid] = updates_out; } TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates, @@ -66,39 +75,61 @@ TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates, const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count)); const size_t n = getTensorPlainSize(updates); - const dim3 block(std::min(n, size_t(256))); - const dim3 grid((n - 1) / block.x + 1); - AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardInplaceCUDA", ([&] { - adamForwardInplaceCUDAKernel - <<>>(scalar_t(b1), - scalar_t(inv_one_minus_pow_b1), - scalar_t(b2), - scalar_t(inv_one_minus_pow_b2), - scalar_t(eps), - scalar_t(eps_root), - n, - updates.data_ptr(), - mu.data_ptr(), - nu.data_ptr()); - })); + if (n < BLOCK_SIZE * UNROLL_SIZE) { + const dim3 block(std::min(n, size_t(BLOCK_SIZE))); + const dim3 grid((n - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardInplaceCUDA", ([&] { + adamForwardInplaceCUDAKernel + <<>>(scalar_t(b1), + scalar_t(inv_one_minus_pow_b1), + scalar_t(b2), + scalar_t(inv_one_minus_pow_b2), + scalar_t(eps), + scalar_t(eps_root), + n, + updates.data_ptr(), + mu.data_ptr(), + nu.data_ptr()); + })); + } else { + const dim3 block(std::min(n / UNROLL_SIZE, size_t(BLOCK_SIZE))); + const dim3 grid((n / UNROLL_SIZE - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardInplaceCUDA", ([&] { + adamForwardInplaceCUDAKernel + <<>>(scalar_t(b1), + scalar_t(inv_one_minus_pow_b1), + scalar_t(b2), + scalar_t(inv_one_minus_pow_b2), + scalar_t(eps), + scalar_t(eps_root), + n, + updates.data_ptr(), + mu.data_ptr(), + nu.data_ptr()); + })); + } return TensorArray<3>{updates, mu, nu}; } -template +template __global__ void adamForwardMuCUDAKernel(const scalar_t *__restrict__ updates_ptr, const scalar_t *__restrict__ mu_ptr, const other_t b1, const size_t n, scalar_t *__restrict__ mu_out_ptr) { - size_t tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= n) { - return; + const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size; +#pragma unroll + for (int i = 0; i < unroll_size; ++i) { + size_t tid = toffset + i; + if (tid >= n) { + return; + } + + const scalar_t updates = updates_ptr[tid]; + const scalar_t mu = mu_ptr[tid]; + const scalar_t mu_out = b1 * mu + (1 - b1) * updates; + mu_out_ptr[tid] = mu_out; } - - const scalar_t updates = updates_ptr[tid]; - const scalar_t mu = mu_ptr[tid]; - const scalar_t mu_out = b1 * mu + (1 - b1) * updates; - mu_out_ptr[tid] = mu_out; } torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates, @@ -107,35 +138,52 @@ torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates, auto mu_out = torch::empty_like(mu); const size_t n = getTensorPlainSize(updates); - const dim3 block(std::min(n, size_t(256))); - const dim3 grid((n - 1) / block.x + 1); - AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardMuCUDA", ([&] { - adamForwardMuCUDAKernel - <<>>(updates.data_ptr(), - mu.data_ptr(), - scalar_t(b1), - n, - mu_out.data_ptr()); - })); + if (n < BLOCK_SIZE * UNROLL_SIZE) { + const dim3 block(std::min(n, size_t(BLOCK_SIZE))); + const dim3 grid((n - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardMuCUDA", ([&] { + adamForwardMuCUDAKernel + <<>>(updates.data_ptr(), + mu.data_ptr(), + scalar_t(b1), + n, + mu_out.data_ptr()); + })); + } else { + const dim3 block(std::min(n / UNROLL_SIZE, size_t(BLOCK_SIZE))); + const dim3 grid((n / UNROLL_SIZE - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardMuCUDA", ([&] { + adamForwardMuCUDAKernel + <<>>(updates.data_ptr(), + mu.data_ptr(), + scalar_t(b1), + n, + mu_out.data_ptr()); + })); + } return mu_out; } -template +template __global__ void adamForwardNuCUDAKernel(const scalar_t *__restrict__ updates_ptr, const scalar_t *__restrict__ nu_ptr, const other_t b2, const size_t n, scalar_t *__restrict__ nu_out_ptr) { - size_t tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= n) { - return; + const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size; +#pragma unroll + for (int i = 0; i < unroll_size; ++i) { + size_t tid = toffset + i; + if (tid >= n) { + return; + } + + const scalar_t updates = updates_ptr[tid]; + const scalar_t nu = nu_ptr[tid]; + + const scalar_t nu_out = b2 * nu + (1 - b2) * updates * updates; + nu_out_ptr[tid] = nu_out; } - - const scalar_t updates = updates_ptr[tid]; - const scalar_t nu = nu_ptr[tid]; - - const scalar_t nu_out = b2 * nu + (1 - b2) * pow(updates, 2); - nu_out_ptr[tid] = nu_out; } torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates, @@ -144,20 +192,33 @@ torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates, auto nu_out = torch::empty_like(nu); const size_t n = getTensorPlainSize(updates); - const dim3 block(std::min(n, size_t(256))); - const dim3 grid((n - 1) / block.x + 1); - AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardNuCUDA", ([&] { - adamForwardNuCUDAKernel - <<>>(updates.data_ptr(), - nu.data_ptr(), - scalar_t(b2), - n, - nu_out.data_ptr()); - })); + if (n < BLOCK_SIZE * UNROLL_SIZE) { + const dim3 block(std::min(n, size_t(BLOCK_SIZE))); + const dim3 grid((n - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardNuCUDA", ([&] { + adamForwardNuCUDAKernel + <<>>(updates.data_ptr(), + nu.data_ptr(), + scalar_t(b2), + n, + nu_out.data_ptr()); + })); + } else { + const dim3 block(std::min(n / UNROLL_SIZE, size_t(BLOCK_SIZE))); + const dim3 grid((n / UNROLL_SIZE - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardNuCUDA", ([&] { + adamForwardNuCUDAKernel + <<>>(updates.data_ptr(), + nu.data_ptr(), + scalar_t(b2), + n, + nu_out.data_ptr()); + })); + } return nu_out; } -template +template __global__ void adamForwardUpdatesCUDAKernel(const scalar_t *__restrict__ new_mu_ptr, const scalar_t *__restrict__ new_nu_ptr, const other_t inv_one_minus_pow_b1, @@ -166,16 +227,20 @@ __global__ void adamForwardUpdatesCUDAKernel(const scalar_t *__restrict__ new_mu const other_t eps_root, const size_t n, scalar_t *__restrict__ updates_out_ptr) { - size_t tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= n) { - return; + const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size; +#pragma unroll + for (int i = 0; i < unroll_size; ++i) { + size_t tid = toffset + i; + if (tid >= n) { + return; + } + + const scalar_t new_mu = new_mu_ptr[tid]; + const scalar_t new_nu = new_nu_ptr[tid]; + const scalar_t mu_hat = new_mu * inv_one_minus_pow_b1; + const scalar_t nu_hat = new_nu * inv_one_minus_pow_b2; + updates_out_ptr[tid] = mu_hat / (sqrt(nu_hat + eps_root) + eps); } - - const scalar_t new_mu = new_mu_ptr[tid]; - const scalar_t new_nu = new_nu_ptr[tid]; - const scalar_t mu_hat = new_mu * inv_one_minus_pow_b1; - const scalar_t nu_hat = new_nu * inv_one_minus_pow_b2; - updates_out_ptr[tid] = mu_hat / (sqrt(nu_hat + eps_root) + eps); } torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu, @@ -186,46 +251,64 @@ torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu, const pyfloat_t eps_root, const pyuint_t count) { using other_t = pyfloat_t; + const other_t inv_one_minus_pow_b1 = 1 / (1 - std::pow(b1, count)); + const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count)); auto updates_out = torch::empty_like(new_mu); - const other_t one_minus_pow_b1 = 1 - std::pow(b1, count); - const other_t inv_one_minus_pow_b1 = 1 / one_minus_pow_b1; - const other_t one_minus_pow_b2 = 1 - std::pow(b2, count); - const other_t inv_one_minus_pow_b2 = 1 / one_minus_pow_b2; - const size_t n = getTensorPlainSize(new_mu); - const dim3 block(std::min(n, size_t(256))); - const dim3 grid((n - 1) / block.x + 1); - AT_DISPATCH_SCALAR_TYPES_CUDA(new_mu.scalar_type(), "adamForwardUpdatesCUDA", ([&] { - adamForwardUpdatesCUDAKernel - <<>>(new_mu.data_ptr(), - new_nu.data_ptr(), - scalar_t(inv_one_minus_pow_b1), - scalar_t(inv_one_minus_pow_b2), - scalar_t(eps), - scalar_t(eps_root), - n, - updates_out.data_ptr()); - })); + if (n < BLOCK_SIZE * UNROLL_SIZE) { + const dim3 block(std::min(n, size_t(BLOCK_SIZE))); + const dim3 grid((n - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(new_mu.scalar_type(), "adamForwardUpdatesCUDA", ([&] { + adamForwardUpdatesCUDAKernel + <<>>(new_mu.data_ptr(), + new_nu.data_ptr(), + scalar_t(inv_one_minus_pow_b1), + scalar_t(inv_one_minus_pow_b2), + scalar_t(eps), + scalar_t(eps_root), + n, + updates_out.data_ptr()); + })); + } else { + const dim3 block(std::min(n / UNROLL_SIZE, size_t(BLOCK_SIZE))); + const dim3 grid((n / UNROLL_SIZE - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(new_mu.scalar_type(), "adamForwardUpdatesCUDA", ([&] { + adamForwardUpdatesCUDAKernel + <<>>(new_mu.data_ptr(), + new_nu.data_ptr(), + scalar_t(inv_one_minus_pow_b1), + scalar_t(inv_one_minus_pow_b2), + scalar_t(eps), + scalar_t(eps_root), + n, + updates_out.data_ptr()); + })); + } + return updates_out; } -template +template __global__ void adamBackwardMuCUDAKernel(const scalar_t *__restrict__ dmu_ptr, const other_t b1, const size_t n, scalar_t *__restrict__ dupdates_out_ptr, scalar_t *__restrict__ dmu_out_ptr) { - size_t tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= n) { - return; + const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size; +#pragma unroll + for (int i = 0; i < unroll_size; ++i) { + size_t tid = toffset + i; + if (tid >= n) { + return; + } + + const scalar_t dmu = dmu_ptr[tid]; + + dupdates_out_ptr[tid] = (1 - b1) * dmu; + dmu_out_ptr[tid] = b1 * dmu; } - - const scalar_t dmu = dmu_ptr[tid]; - - dupdates_out_ptr[tid] = (1 - b1) * dmu; - dmu_out_ptr[tid] = b1 * dmu; } TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu, @@ -236,36 +319,53 @@ TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu, auto dmu_out = torch::empty_like(mu); const size_t n = getTensorPlainSize(dmu); - const dim3 block(std::min(n, size_t(256))); - const dim3 grid((n - 1) / block.x + 1); - AT_DISPATCH_SCALAR_TYPES_CUDA(dmu.scalar_type(), "adamBackwardMuCUDA", ([&] { - adamBackwardMuCUDAKernel - <<>>(dmu.data_ptr(), - scalar_t(b1), - n, - dupdates_out.data_ptr(), - dmu_out.data_ptr()); - })); + if (n < BLOCK_SIZE * UNROLL_SIZE) { + const dim3 block(std::min(n, size_t(BLOCK_SIZE))); + const dim3 grid((n - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(dmu.scalar_type(), "adamBackwardMuCUDA", ([&] { + adamBackwardMuCUDAKernel + <<>>(dmu.data_ptr(), + scalar_t(b1), + n, + dupdates_out.data_ptr(), + dmu_out.data_ptr()); + })); + } else { + const dim3 block(std::min(n / UNROLL_SIZE, size_t(BLOCK_SIZE))); + const dim3 grid((n / UNROLL_SIZE - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(dmu.scalar_type(), "adamBackwardMuCUDA", ([&] { + adamBackwardMuCUDAKernel + <<>>(dmu.data_ptr(), + scalar_t(b1), + n, + dupdates_out.data_ptr(), + dmu_out.data_ptr()); + })); + } return TensorArray<2>{std::move(dupdates_out), std::move(dmu_out)}; } -template +template __global__ void adamBackwardNuCUDAKernel(const scalar_t *__restrict__ dnu_ptr, const scalar_t *__restrict__ updates_ptr, const other_t b2, const size_t n, scalar_t *__restrict__ dupdates_out_ptr, scalar_t *__restrict__ dnu_out_ptr) { - size_t tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= n) { - return; + const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size; +#pragma unroll + for (int i = 0; i < unroll_size; ++i) { + size_t tid = toffset + i; + if (tid >= n) { + return; + } + + const scalar_t dnu = dnu_ptr[tid]; + const scalar_t updates = updates_ptr[tid]; + + dupdates_out_ptr[tid] = 2 * (1 - b2) * updates * dnu; + dnu_out_ptr[tid] = b2 * dnu; } - - const scalar_t dnu = dnu_ptr[tid]; - const scalar_t updates = updates_ptr[tid]; - - dupdates_out_ptr[tid] = 2 * (1 - b2) * updates * dnu; - dnu_out_ptr[tid] = b2 * dnu; } TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu, @@ -276,21 +376,35 @@ TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu, auto dnu_out = torch::empty_like(nu); const size_t n = getTensorPlainSize(dnu); - const dim3 block(std::min(n, size_t(256))); - const dim3 grid((n - 1) / block.x + 1); - AT_DISPATCH_SCALAR_TYPES_CUDA(dnu.scalar_type(), "adamForwardNuCUDA", ([&] { - adamBackwardNuCUDAKernel - <<>>(dnu.data_ptr(), - updates.data_ptr(), - scalar_t(b2), - n, - dupdates_out.data_ptr(), - dnu_out.data_ptr()); - })); + if (n < BLOCK_SIZE * UNROLL_SIZE) { + const dim3 block(std::min(n, size_t(BLOCK_SIZE))); + const dim3 grid((n - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(dnu.scalar_type(), "adamForwardNuCUDA", ([&] { + adamBackwardNuCUDAKernel + <<>>(dnu.data_ptr(), + updates.data_ptr(), + scalar_t(b2), + n, + dupdates_out.data_ptr(), + dnu_out.data_ptr()); + })); + } else { + const dim3 block(std::min(n / UNROLL_SIZE, size_t(BLOCK_SIZE))); + const dim3 grid((n / UNROLL_SIZE - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(dnu.scalar_type(), "adamForwardNuCUDA", ([&] { + adamBackwardNuCUDAKernel + <<>>(dnu.data_ptr(), + updates.data_ptr(), + scalar_t(b2), + n, + dupdates_out.data_ptr(), + dnu_out.data_ptr()); + })); + } return TensorArray<2>{std::move(dupdates_out), std::move(dnu_out)}; } -template +template __global__ void adamBackwardUpdatesCUDAKernel(const scalar_t *__restrict__ dupdates_ptr, const scalar_t *__restrict__ updates_ptr, const scalar_t *__restrict__ new_mu_ptr, @@ -299,28 +413,32 @@ __global__ void adamBackwardUpdatesCUDAKernel(const scalar_t *__restrict__ dupda const size_t n, scalar_t *__restrict__ dnew_mu_out_ptr, scalar_t *__restrict__ dnew_nu_out_ptr) { - size_t tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= n) { - return; - } - - const scalar_t dupdates = dupdates_ptr[tid]; - const scalar_t updates = updates_ptr[tid]; - const scalar_t new_mu = new_mu_ptr[tid]; - - if (new_mu == 0) { - dnew_mu_out_ptr[tid] = 0; - dnew_nu_out_ptr[tid] = 0; - return; + const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size; +#pragma unroll + for (int i = 0; i < unroll_size; ++i) { + size_t tid = toffset + i; + if (tid >= n) { + return; + } + + const scalar_t dupdates = dupdates_ptr[tid]; + const scalar_t updates = updates_ptr[tid]; + const scalar_t new_mu = new_mu_ptr[tid]; + + if (new_mu == 0) { + dnew_mu_out_ptr[tid] = 0; + dnew_nu_out_ptr[tid] = 0; + return; + } + + const scalar_t updates_div_new_mu = updates / new_mu; + + const scalar_t denominator = updates_div_new_mu * one_minus_pow_b1; + + dnew_mu_out_ptr[tid] = dupdates * updates_div_new_mu; + dnew_nu_out_ptr[tid] = + -dupdates * updates * denominator * 0.5 * inv_one_minus_pow_b2 * denominator; } - - const scalar_t updates_div_new_mu = updates / new_mu; - - const scalar_t denominator = updates_div_new_mu * one_minus_pow_b1; - - dnew_mu_out_ptr[tid] = dupdates * updates_div_new_mu; - dnew_nu_out_ptr[tid] = - -dupdates * updates * denominator * 0.5 * inv_one_minus_pow_b2 * denominator; } TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates, @@ -329,30 +447,45 @@ TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates, const torch::Tensor &new_nu, const pyfloat_t b1, const pyfloat_t b2, + const pyfloat_t eps_root, const pyuint_t count) { using other_t = pyfloat_t; + const other_t one_minus_pow_b1 = 1 - std::pow(b1, count); + const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count) + eps_root); auto dmu_out = torch::empty_like(new_mu); auto dnu_out = torch::empty_like(new_nu); - const other_t one_minus_pow_b1 = 1 - std::pow(b1, count); - const other_t one_minus_pow_b2 = 1 - std::pow(b2, count); - const other_t inv_one_minus_pow_b2 = 1 / one_minus_pow_b2; - const size_t n = getTensorPlainSize(dupdates); - const dim3 block(std::min(n, size_t(256))); - const dim3 grid((n - 1) / block.x + 1); - AT_DISPATCH_SCALAR_TYPES_CUDA(dupdates.scalar_type(), "adamBackwardUpdatesCUDA", ([&] { - adamBackwardUpdatesCUDAKernel - <<>>(dupdates.data_ptr(), - updates.data_ptr(), - new_mu.data_ptr(), - scalar_t(one_minus_pow_b1), - scalar_t(inv_one_minus_pow_b2), - n, - dmu_out.data_ptr(), - dnu_out.data_ptr()); - })); + if (n < BLOCK_SIZE * UNROLL_SIZE) { + const dim3 block(std::min(n, size_t(BLOCK_SIZE))); + const dim3 grid((n - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(dupdates.scalar_type(), "adamBackwardUpdatesCUDA", ([&] { + adamBackwardUpdatesCUDAKernel + <<>>(dupdates.data_ptr(), + updates.data_ptr(), + new_mu.data_ptr(), + scalar_t(one_minus_pow_b1), + scalar_t(inv_one_minus_pow_b2), + n, + dmu_out.data_ptr(), + dnu_out.data_ptr()); + })); + } else { + const dim3 block(std::min(n / UNROLL_SIZE, size_t(BLOCK_SIZE))); + const dim3 grid((n / UNROLL_SIZE - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(dupdates.scalar_type(), "adamBackwardUpdatesCUDA", ([&] { + adamBackwardUpdatesCUDAKernel + <<>>(dupdates.data_ptr(), + updates.data_ptr(), + new_mu.data_ptr(), + scalar_t(one_minus_pow_b1), + scalar_t(inv_one_minus_pow_b2), + n, + dmu_out.data_ptr(), + dnu_out.data_ptr()); + })); + } return TensorArray<2>{std::move(dmu_out), std::move(dnu_out)}; } diff --git a/src/extension.cpp b/src/extension.cpp index 45880bf6..3d3d52b3 100644 --- a/src/extension.cpp +++ b/src/extension.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2024 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/tests/.coveragerc b/tests/.coveragerc new file mode 100644 index 00000000..4238e71d --- /dev/null +++ b/tests/.coveragerc @@ -0,0 +1,17 @@ +[run] +omit = + ../torchopt/distributed/* + ../torchopt/visual.py + ../torchopt/version.py + ../docs/* + ../examples/* + ../tutorials/* + +[report] +exclude_lines = + pragma: no cover + raise NotImplementedError + class .*\bProtocol\): + @(abc\.)?abstractmethod + if __name__ == ('__main__'|"__main__"): + if TYPE_CHECKING: diff --git a/tests/conftest.py b/tests/conftest.py index 41b7db0b..bb2b1cf2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/helpers.py b/tests/helpers.py index d34ad41e..ca5aa443 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,18 +13,28 @@ # limitations under the License. # ============================================================================== +from __future__ import annotations + +import contextlib import copy import itertools import os import random -from typing import Iterable, Optional, Tuple, Union +from typing import TYPE_CHECKING, Iterable import numpy as np import pytest import torch import torch.nn as nn +import torch.types from torch.utils import data +from torchopt import pytree + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree + BATCH_SIZE = 64 NUM_UPDATES = 5 @@ -34,6 +44,14 @@ MODEL_HIDDEN_SIZE = 64 +def dtype_numpy2torch(dtype: np.dtype) -> torch.dtype: + return torch.tensor(np.zeros(1, dtype=dtype)).dtype + + +def dtype_torch2numpy(dtype: torch.dtype) -> np.dtype: + return torch.zeros(1, dtype=dtype).numpy().dtype + + def parametrize(**argvalues) -> pytest.mark.parametrize: arguments = list(argvalues) @@ -46,9 +64,11 @@ def parametrize(**argvalues) -> pytest.mark.parametrize: argvalues = list(itertools.product(*tuple(map(argvalues.get, arguments)))) first_product = argvalues[0] argvalues.extend((dtype,) + first_product[1:] for dtype in dtypes[1:]) + else: + argvalues = list(itertools.product(*tuple(map(argvalues.get, arguments)))) ids = tuple( - '-'.join(f'{arg}({val})' for arg, val in zip(arguments, values)) for values in argvalues + '-'.join(f'{arg}({val!r})' for arg, val in zip(arguments, values)) for values in argvalues ) return pytest.mark.parametrize(arguments, argvalues, ids=ids) @@ -63,51 +83,74 @@ def seed_everything(seed: int) -> None: torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) - try: + with contextlib.suppress(AttributeError): torch.use_deterministic_algorithms(True) - except AttributeError: - pass -@torch.no_grad() -def get_models( - device: Optional[Union[str, torch.device]] = None, dtype: torch.dtype = torch.float32 -) -> Tuple[nn.Module, nn.Module, nn.Module, data.DataLoader]: - seed_everything(seed=42) +class MyLinear(nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super().__init__() + self.linear = nn.Linear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + dtype=dtype, + ) + self.unused_module = nn.Linear(1, 1, bias=False) + self.unused_parameter = nn.Parameter(torch.zeros(1, 1), requires_grad=True) - model_base = nn.Sequential( - nn.Linear( + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +@torch.no_grad() +def get_model(): + return nn.Sequential( + MyLinear( in_features=MODEL_NUM_INPUTS, out_features=MODEL_HIDDEN_SIZE, bias=True, - dtype=dtype, ), nn.BatchNorm1d( num_features=MODEL_HIDDEN_SIZE, track_running_stats=True, - dtype=dtype, ), nn.ReLU(), nn.Linear( in_features=MODEL_HIDDEN_SIZE, out_features=MODEL_HIDDEN_SIZE, bias=True, - dtype=dtype, ), nn.BatchNorm1d( num_features=MODEL_HIDDEN_SIZE, track_running_stats=True, - dtype=dtype, ), nn.ReLU(), nn.Linear( in_features=MODEL_HIDDEN_SIZE, out_features=MODEL_NUM_CLASSES, - bias=True, - dtype=dtype, + bias=False, ), nn.Softmax(dim=-1), ) + + +@torch.no_grad() +def get_models( + device: torch.types.Device | None = None, + dtype: torch.dtype = torch.float32, +) -> tuple[nn.Module, nn.Module, nn.Module, data.DataLoader]: + seed_everything(seed=42) + + model_base = get_model().to(dtype=dtype) for name, param in model_base.named_parameters(recurse=True): if name.endswith('weight') and param.ndim >= 2: nn.init.orthogonal_(param) @@ -123,6 +166,7 @@ def get_models( dataset = data.TensorDataset( torch.randint(0, 1, (BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS)), + # torch.empty((BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS), dtype=dtype).uniform_(-1.0, +1.0), torch.randint(0, MODEL_NUM_CLASSES, (BATCH_SIZE * NUM_UPDATES,)), ) loader = data.DataLoader(dataset, BATCH_SIZE, shuffle=False) @@ -132,15 +176,14 @@ def get_models( @torch.no_grad() def assert_model_all_close( - model: Union[nn.Module, Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]], + model: nn.Module | tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]], model_ref: nn.Module, model_base: nn.Module, dtype: torch.dtype = torch.float32, - rtol: Optional[float] = None, - atol: Optional[float] = None, + rtol: float | None = None, + atol: float | None = None, equal_nan: bool = False, -): - +) -> None: if isinstance(model, tuple): params, buffers = model elif isinstance(model, nn.Module): @@ -160,12 +203,11 @@ def assert_model_all_close( def assert_all_close( actual: torch.Tensor, expected: torch.Tensor, - base: torch.Tensor = None, - rtol: Optional[float] = None, - atol: Optional[float] = None, + base: torch.Tensor | None = None, + rtol: float | None = None, + atol: float | None = None, equal_nan: bool = False, ) -> None: - if base is not None: actual = actual - base expected = expected - base @@ -174,8 +216,8 @@ def assert_all_close( from torch.testing._comparison import get_tolerances rtol, atol = get_tolerances(actual, expected, rtol=rtol, atol=atol) - rtol *= 4 * NUM_UPDATES - atol *= 4 * NUM_UPDATES + rtol *= 5 * NUM_UPDATES + atol *= 5 * NUM_UPDATES torch.testing.assert_close( actual, @@ -185,3 +227,32 @@ def assert_all_close( equal_nan=equal_nan, check_dtype=True, ) + + +@torch.no_grad() +def assert_pytree_all_close( + actual: TensorTree, + expected: TensorTree, + base: TensorTree | None = None, + rtol: float | None = None, + atol: float | None = None, + equal_nan: bool = False, +) -> None: + actual_leaves, actual_treespec = pytree.tree_flatten(actual) + expected_leaves, expected_treespec = pytree.tree_flatten(expected) + assert actual_treespec == expected_treespec + if base is not None: + base_leaves, base_treespec = pytree.tree_flatten(base) + assert base_treespec == expected_treespec + else: + base_leaves = [None] * len(actual_leaves) + + for actual_leaf, expected_leaf, base_leaf in zip(actual_leaves, expected_leaves, base_leaves): + assert_all_close( + actual_leaf, + expected_leaf, + base=base_leaf, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + ) diff --git a/tests/requirements.txt b/tests/requirements.txt index d02db980..ee54732b 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,20 +1,29 @@ ---extra-index-url https://download.pytorch.org/whl/cu116 -torch >= 1.12 -functorch >= 0.2 +--extra-index-url https://download.pytorch.org/whl/cu121 +torch >= 2.0 --requirement ../requirements.txt +jax[cpu] >= 0.4; platform_system != 'Windows' +jaxopt; platform_system != 'Windows' +optax < 0.1.8a0; platform_system != 'Windows' and python_version < '3.9' +optax >= 0.1.8; platform_system != 'Windows' and python_version >= '3.9' + pytest pytest-cov pytest-xdist isort -black >= 22.6.0 -pylint +black[jupyter] +pylint[spelling] mypy flake8 flake8-bugbear -doc8 < 1.0.0a0 -pydocstyle +flake8-comprehensions +flake8-docstrings +flake8-pyi +flake8-simplify +ruff +doc8 +pydocstyle[toml] pyenchant cpplint pre-commit diff --git a/tests/test_accelerated_op.py b/tests/test_accelerated_op.py new file mode 100644 index 00000000..668c9b9a --- /dev/null +++ b/tests/test_accelerated_op.py @@ -0,0 +1,208 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import functorch +import torch +import torch.nn.functional as F + +import helpers +import torchopt + + +try: + import torchopt._C.adam_op +except ImportError: + CXX_ACCELERATED_OP_AVAILABLE = False +else: + CXX_ACCELERATED_OP_AVAILABLE = True + + +def test_accelerated_op_is_available() -> None: + assert torchopt.accelerated_op_available('cpu') + assert torchopt.accelerated_op_available(torch.device('cpu')) + + if CXX_ACCELERATED_OP_AVAILABLE: + assert not torchopt.accelerated_op_available('meta') + assert not torchopt.accelerated_op_available(torch.device('meta')) + assert not torchopt.accelerated_op_available(['cpu', 'meta']) + assert not torchopt.accelerated_op_available([torch.device('cpu'), torch.device('meta')]) + else: + assert torchopt.accelerated_op_available('meta') + assert torchopt.accelerated_op_available(torch.device('meta')) + assert torchopt.accelerated_op_available(['cpu', 'meta']) + assert torchopt.accelerated_op_available([torch.device('cpu'), torch.device('meta')]) + + if torch.cuda.is_available(): + assert torchopt.accelerated_op_available() + assert torchopt.accelerated_op_available('cuda') + assert torchopt.accelerated_op_available('cuda:0') + assert torchopt.accelerated_op_available(0) + assert torchopt.accelerated_op_available(['cpu', 'cuda']) + assert torchopt.accelerated_op_available(['cpu', 'cuda:0']) + assert torchopt.accelerated_op_available(['cpu', 0]) + else: + assert not torchopt.accelerated_op_available() + assert not torchopt.accelerated_op_available('cuda') + assert not torchopt.accelerated_op_available('cuda:0') + assert not torchopt.accelerated_op_available(0) + assert not torchopt.accelerated_op_available(['cpu', 'cuda']) + assert not torchopt.accelerated_op_available(['cpu', 'cuda:0']) + assert not torchopt.accelerated_op_available(['cpu', 0]) + + +@helpers.parametrize( + dtype=[torch.float64, torch.float32], + lr=[1e-2, 1e-3, 1e-4], + inplace=[True, False], +) +def test_accelerated_op( + dtype: torch.dtype, + lr: float, + inplace: bool, +) -> None: + if dtype is torch.float32 and inplace: + return + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + fmodel, params, buffers = functorch.make_functional_with_buffers(model) + optim = torchopt.adam( + lr, + use_accelerated_op=True, + ) + optim_state = optim.init(params) + + fmodel_ref, params_ref, buffers_ref = functorch.make_functional_with_buffers(model_ref) + optim_ref = torchopt.adam( + lr, + use_accelerated_op=False, + ) + optim_state_ref = optim_ref.init(params_ref) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = fmodel(params, buffers, xs) + pred_ref = fmodel_ref(params_ref, buffers_ref, xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + grads = torch.autograd.grad(loss, params, allow_unused=True) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + params = torchopt.apply_updates(params, updates, inplace=inplace) + + grads = torch.autograd.grad(loss_ref, params_ref, allow_unused=True) + updates, optim_state_ref = optim_ref.update( + grads, + optim_state_ref, + params=params, + inplace=inplace, + ) + params_ref = torchopt.apply_updates(params_ref, updates, inplace=inplace) + + helpers.assert_pytree_all_close(params, params_ref) + + +@helpers.parametrize( + dtype=[torch.float64, torch.float32], + outer_lr=[1e-2, 1e-3, 1e-4], + inner_lr=[1e-2, 1e-3, 1e-4], + inner_update=[2, 3, 5], + inplace=[True, False], +) +def test_maml_accelerated_op( + dtype: torch.dtype, + outer_lr: float, + inner_lr: float, + inner_update: int, + inplace: bool, +) -> None: + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + fmodel, params, buffers = functorch.make_functional_with_buffers(model) + outer_optim = torchopt.adam( + outer_lr, + use_accelerated_op=True, + ) + outer_optim_state = outer_optim.init(params) + + fmodel_ref, params_ref, buffers_ref = functorch.make_functional_with_buffers(model_ref) + outer_optim_ref = torchopt.adam( + outer_lr, + use_accelerated_op=False, + ) + outer_optim_state_ref = outer_optim_ref.init(params_ref) + + def maml_inner_solver(params, data, use_accelerated_op): + # Initial functional optimizer based on TorchOpt + x, y, f, b = data + inner_optimizer = torchopt.adam( + inner_lr, + use_accelerated_op=use_accelerated_op, + ) + inner_opt_state = inner_optimizer.init(params) + with torch.enable_grad(): + # Temporarily enable gradient computation for conducting the optimization + for _ in range(inner_update): + pred = f(params, b, x) + inner_loss = F.cross_entropy(pred, y) # compute loss + grads = torch.autograd.grad( + inner_loss, + params, + allow_unused=True, + ) # compute gradients + updates, inner_opt_state = inner_optimizer.update( + grads, + inner_opt_state, + inplace=False, + ) # get updates + params = torchopt.apply_updates(params, updates, inplace=False) + return (params, b) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + data = (xs, ys, fmodel, buffers) + data_ref = (xs, ys, fmodel_ref, buffers_ref) + + params_prime, buffers_prime = maml_inner_solver(params, data, use_accelerated_op=True) + params_prime_ref, buffers_prime_ref = maml_inner_solver( + params_ref, + data_ref, + use_accelerated_op=False, + ) + + pred = fmodel(params_prime, buffers_prime, xs) + pred_ref = fmodel_ref(params_prime_ref, buffers_prime_ref, xs) + outer_loss = F.cross_entropy(pred, ys) + outer_loss_ref = F.cross_entropy(pred_ref, ys) + + grads = torch.autograd.grad(outer_loss, params, allow_unused=True) + updates, outer_optim_state = outer_optim.update( + grads, + outer_optim_state, + params=params, + inplace=inplace, + ) + params = torchopt.apply_updates(params, updates, inplace=inplace) + + grads = torch.autograd.grad(outer_loss_ref, params_ref, allow_unused=True) + updates, outer_optim_state_ref = outer_optim_ref.update( + grads, + outer_optim_state_ref, + params=params, + inplace=inplace, + ) + params_ref = torchopt.apply_updates(params_ref, updates, inplace=inplace) + + torchopt.stop_gradient(model) + torchopt.stop_gradient(model_ref) diff --git a/tests/test_alias.py b/tests/test_alias.py index 6f37e939..3c42d7c8 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,9 @@ # limitations under the License. # ============================================================================== -from typing import Tuple +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable import functorch import pytest @@ -22,6 +24,59 @@ import helpers import torchopt +from torchopt import pytree +from torchopt.alias.utils import _set_use_chain_flat + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree + + +@helpers.parametrize( + optimizer=[ + torchopt.sgd, + torchopt.adam, + torchopt.adamw, + torchopt.rmsprop, + ], + tensortree=[ + {}, + (), + [], + (None,), + {'a': (), 'b': {'c': []}, 'd': None}, + ], + maximize=[False, True], + inplace=[True, False], + use_chain_flat=[True, False], +) +def test_empty( + optimizer: Callable, + tensortree: TensorTree, + maximize: bool, + inplace: bool, + use_chain_flat: bool, +) -> None: + _set_use_chain_flat(use_chain_flat) + + params = pytree.tree_map(lambda x: x, tensortree) + grads = pytree.tree_map(lambda x: x, tensortree) + + optim = optimizer(1e-3, maximize=maximize) + optim_state = optim.init(params) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + _ = torchopt.apply_updates(params, updates) + + try: + optim = optimizer(1e-3, maximize=maximize, use_accelerated_op=True) + except TypeError: + pass + else: + optim_state = optim.init(params) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + _ = torchopt.apply_updates(params, updates) + + _set_use_chain_flat(True) @helpers.parametrize( @@ -32,7 +87,8 @@ nesterov=[False, True], inplace=[True, False], weight_decay=[0.0, 1e-2], - maximize=[False], # TODO: test maximize after PyTorch 1.13 + maximize=[False, True], + use_chain_flat=[True, False], ) def test_sgd( dtype: torch.dtype, @@ -43,10 +99,13 @@ def test_sgd( inplace: bool, weight_decay: float, maximize: bool, + use_chain_flat: bool, ) -> None: if nesterov and (momentum <= 0.0 or dampening != 0.0): pytest.skip('Nesterov momentum requires a momentum and zero dampening.') + _set_use_chain_flat(use_chain_flat) + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) fmodel, params, buffers = functorch.make_functional_with_buffers(model) @@ -76,7 +135,7 @@ def test_sgd( loss = F.cross_entropy(pred, ys) loss_ref = F.cross_entropy(pred_ref, ys) - grads = torch.autograd.grad(loss, params) + grads = torch.autograd.grad(loss, params, allow_unused=True) updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) params = torchopt.apply_updates(params, updates, inplace=inplace) @@ -85,6 +144,64 @@ def test_sgd( optim_ref.step() helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) + + +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + rho=[0.9, 0.95], + eps=[1e-8], + inplace=[True, False], + weight_decay=[0.0, 1e-2], + use_chain_flat=[True, False], +) +def test_adadelta( + dtype: torch.dtype, + lr: float, + rho: float, + eps: float, + inplace: bool, + weight_decay: float, + use_chain_flat: bool, +) -> None: + _set_use_chain_flat(use_chain_flat) + + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + fmodel, params, buffers = functorch.make_functional_with_buffers(model) + optim = torchopt.adadelta( + lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + ) + optim_state = optim.init(params) + optim_ref = torch.optim.Adadelta( + model_ref.parameters(), + lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = fmodel(params, buffers, xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + grads = torch.autograd.grad(loss, params, allow_unused=True) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + params = torchopt.apply_updates(params, updates, inplace=inplace) + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) @helpers.parametrize( @@ -95,16 +212,22 @@ def test_sgd( inplace=[True, False], weight_decay=[0.0, 1e-2], maximize=[False, True], + use_accelerated_op=[False, True], + use_chain_flat=[True, False], ) def test_adam( dtype: torch.dtype, lr: float, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, inplace: bool, weight_decay: float, maximize: bool, + use_accelerated_op: bool, + use_chain_flat: bool, ) -> None: + _set_use_chain_flat(use_chain_flat) + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) fmodel, params, buffers = functorch.make_functional_with_buffers(model) @@ -115,6 +238,7 @@ def test_adam( eps_root=0.0, weight_decay=weight_decay, maximize=maximize, + use_accelerated_op=use_accelerated_op, ) optim_state = optim.init(params) optim_ref = torch.optim.Adam( @@ -134,7 +258,7 @@ def test_adam( loss = F.cross_entropy(pred, ys) loss_ref = F.cross_entropy(pred_ref, ys) - grads = torch.autograd.grad(loss, params) + grads = torch.autograd.grad(loss, params, allow_unused=True) updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) params = torchopt.apply_updates(params, updates, inplace=inplace) @@ -143,6 +267,7 @@ def test_adam( optim_ref.step() helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) @helpers.parametrize( @@ -152,37 +277,35 @@ def test_adam( eps=[1e-8], inplace=[True, False], weight_decay=[0.0, 1e-2], - maximize=[False, True], + use_chain_flat=[True, False], ) -def test_adamw( +def test_radam( dtype: torch.dtype, lr: float, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, inplace: bool, weight_decay: float, - maximize: bool, + use_chain_flat: bool, ) -> None: + _set_use_chain_flat(use_chain_flat) + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) fmodel, params, buffers = functorch.make_functional_with_buffers(model) - optim = torchopt.adamw( + optim = torchopt.radam( lr, betas=betas, eps=eps, - eps_root=0.0, weight_decay=weight_decay, - maximize=maximize, ) optim_state = optim.init(params) - optim_ref = torch.optim.AdamW( + optim_ref = torch.optim.RAdam( model_ref.parameters(), lr, betas=betas, eps=eps, - amsgrad=False, weight_decay=weight_decay, - maximize=maximize, ) for xs, ys in loader: @@ -192,7 +315,64 @@ def test_adamw( loss = F.cross_entropy(pred, ys) loss_ref = F.cross_entropy(pred_ref, ys) - grads = torch.autograd.grad(loss, params) + grads = torch.autograd.grad(loss, params, allow_unused=True) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + params = torchopt.apply_updates(params, updates, inplace=inplace) + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) + + +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + betas=[(0.9, 0.999), (0.95, 0.9995)], + eps=[1e-8], + inplace=[True, False], + weight_decay=[0.0, 1e-2], + use_chain_flat=[True, False], +) +def test_adamax( + dtype: torch.dtype, + lr: float, + betas: tuple[float, float], + eps: float, + inplace: bool, + weight_decay: float, + use_chain_flat: bool, +) -> None: + _set_use_chain_flat(use_chain_flat) + + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + fmodel, params, buffers = functorch.make_functional_with_buffers(model) + optim = torchopt.adamax( + lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + optim_state = optim.init(params) + optim_ref = torch.optim.Adamax( + model_ref.parameters(), + lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = fmodel(params, buffers, xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + grads = torch.autograd.grad(loss, params, allow_unused=True) updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) params = torchopt.apply_updates(params, updates, inplace=inplace) @@ -201,6 +381,107 @@ def test_adamw( optim_ref.step() helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) + + +@helpers.parametrize( + dtype=[torch.float64], + outer_lr=[1e-2, 1e-3, 1e-4], + inner_lr=[1e-2, 1e-3, 1e-4], + inner_update=[2, 3, 5], + betas=[(0.9, 0.999), (0.95, 0.9995)], + eps=[1e-8], + inplace=[True, False], + weight_decay=[0.0, 1e-2], + maximize=[False, True], + use_accelerated_op=[False, True], + use_chain_flat=[True, False], +) +def test_maml_adam( + dtype: torch.dtype, + outer_lr: float, + inner_lr: float, + inner_update: int, + betas: tuple[float, float], + eps: float, + inplace: bool, + weight_decay: float, + maximize: bool, + use_accelerated_op: bool, + use_chain_flat: bool, +) -> None: + _set_use_chain_flat(use_chain_flat) + + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + fmodel, params, buffers = functorch.make_functional_with_buffers(model) + outer_optim = torchopt.adam( + outer_lr, + betas=betas, + eps=eps, + eps_root=0.0, + weight_decay=weight_decay, + maximize=maximize, + use_accelerated_op=use_accelerated_op, + ) + outer_optim_state = outer_optim.init(params) + + def maml_inner_solver_torchopt(params, data, use_accelerated_op): + # Initial functional optimizer based on TorchOpt + x, y, f, b = data + inner_optimizer = torchopt.adam( + inner_lr, + betas=betas, + eps=eps, + eps_root=0.0, + weight_decay=weight_decay, + maximize=maximize, + use_accelerated_op=use_accelerated_op, + ) + inner_opt_state = inner_optimizer.init(params) + with torch.enable_grad(): + # Temporarily enable gradient computation for conducting the optimization + for _ in range(inner_update): + pred = f(params, b, x) + inner_loss = F.cross_entropy(pred, y) # compute loss + grads = torch.autograd.grad( + inner_loss, + params, + allow_unused=True, + ) # compute gradients + updates, inner_opt_state = inner_optimizer.update( + grads, + inner_opt_state, + params=params, + inplace=False, + ) # get updates + params = torchopt.apply_updates(params, updates, inplace=False) + return (params, b) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + data = (xs, ys, fmodel, buffers) + + params_prime, buffers_prime = maml_inner_solver_torchopt( + params, + data, + use_accelerated_op=True, + ) + pred = fmodel(params_prime, buffers_prime, xs) + outer_loss = F.cross_entropy(pred, ys) + + grads = torch.autograd.grad(outer_loss, params, allow_unused=True) + updates, outer_optim_state = outer_optim.update( + grads, + outer_optim_state, + params=params, + inplace=inplace, + ) + params = torchopt.apply_updates(params, updates, inplace=inplace) + + torchopt.stop_gradient(model) + + _set_use_chain_flat(True) @helpers.parametrize( @@ -209,32 +490,38 @@ def test_adamw( betas=[(0.9, 0.999), (0.95, 0.9995)], eps=[1e-8], inplace=[True, False], - weight_decay=[1e-2, 1e-1], + weight_decay=[0.0, 1e-2], maximize=[False, True], + use_accelerated_op=[False, True], + use_chain_flat=[True, False], ) -def test_adam_accelerated_cpu( +def test_adamw( dtype: torch.dtype, lr: float, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, inplace: bool, weight_decay: float, maximize: bool, + use_accelerated_op: bool, + use_chain_flat: bool, ) -> None: + _set_use_chain_flat(use_chain_flat) + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) fmodel, params, buffers = functorch.make_functional_with_buffers(model) - optim = torchopt.adam( + optim = torchopt.adamw( lr, betas=betas, eps=eps, eps_root=0.0, weight_decay=weight_decay, maximize=maximize, - use_accelerated_op=True, + use_accelerated_op=use_accelerated_op, ) optim_state = optim.init(params) - optim_ref = torch.optim.Adam( + optim_ref = torch.optim.AdamW( model_ref.parameters(), lr, betas=betas, @@ -251,7 +538,7 @@ def test_adam_accelerated_cpu( loss = F.cross_entropy(pred, ys) loss_ref = F.cross_entropy(pred_ref, ys) - grads = torch.autograd.grad(loss, params) + grads = torch.autograd.grad(loss, params, allow_unused=True) updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) params = torchopt.apply_updates(params, updates, inplace=inplace) @@ -260,32 +547,44 @@ def test_adam_accelerated_cpu( optim_ref.step() helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) @pytest.mark.skipif(not torch.cuda.is_available(), reason='No CUDA device available.') @helpers.parametrize( dtype=[torch.float64], lr=[1e-2, 1e-3, 1e-4], + optimizers=[ + (torchopt.adam, torch.optim.Adam), + (torchopt.adamw, torch.optim.AdamW), + ], betas=[(0.9, 0.999), (0.95, 0.9995)], eps=[1e-8], inplace=[True, False], weight_decay=[0.0, 1e-2], maximize=[False, True], + use_chain_flat=[True, False], ) def test_adam_accelerated_cuda( dtype: torch.dtype, lr: float, - betas: Tuple[float, float], + optimizers: tuple[Callable, torch.optim.Optimizer], + betas: tuple[float, float], eps: float, inplace: bool, weight_decay: float, maximize: bool, + use_chain_flat: bool, ) -> None: + _set_use_chain_flat(use_chain_flat) + device = 'cuda' model, model_ref, model_base, loader = helpers.get_models(device=device, dtype=dtype) + torchopt_optimizer, torch_optimizer = optimizers + fmodel, params, buffers = functorch.make_functional_with_buffers(model) - optim = torchopt.adam( + optim = torchopt_optimizer( lr, betas=betas, eps=eps, @@ -295,7 +594,7 @@ def test_adam_accelerated_cuda( use_accelerated_op=True, ) optim_state = optim.init(params) - optim_ref = torch.optim.Adam( + optim_ref = torch_optimizer( model_ref.parameters(), lr, betas=betas, @@ -313,7 +612,7 @@ def test_adam_accelerated_cuda( loss = F.cross_entropy(pred, ys) loss_ref = F.cross_entropy(pred_ref, ys) - grads = torch.autograd.grad(loss, params) + grads = torch.autograd.grad(loss, params, allow_unused=True) updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) params = torchopt.apply_updates(params, updates, inplace=inplace) @@ -322,6 +621,71 @@ def test_adam_accelerated_cuda( optim_ref.step() helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) + + +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + lr_decay=[0.0, 1e-2], + initial_accumulator_value=[0.0, 1e-1], + eps=[1e-8], + inplace=[True, False], + weight_decay=[0.0, 1e-2], + maximize=[False, True], + use_chain_flat=[True, False], +) +def test_adagrad( + dtype: torch.dtype, + lr: float, + lr_decay: float, + initial_accumulator_value: float, + eps: float, + inplace: bool, + weight_decay: float, + maximize: bool, + use_chain_flat: bool, +) -> None: + _set_use_chain_flat(use_chain_flat) + + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + fmodel, params, buffers = functorch.make_functional_with_buffers(model) + optim = torchopt.adagrad( + lr=lr, + lr_decay=lr_decay, + weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + eps=eps, + maximize=maximize, + ) + optim_state = optim.init(params) + optim_ref = torch.optim.Adagrad( + model_ref.parameters(), + lr=lr, + lr_decay=lr_decay, + weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + eps=eps, + maximize=maximize, + ) + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = fmodel(params, buffers, xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + grads = torch.autograd.grad(loss, params, allow_unused=True) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + params = torchopt.apply_updates(params, updates, inplace=inplace) + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) @helpers.parametrize( @@ -333,6 +697,7 @@ def test_adam_accelerated_cuda( centered=[False, True], weight_decay=[0.0, 1e-2], inplace=[True, False], + use_chain_flat=[True, False], ) def test_rmsprop( dtype: torch.dtype, @@ -343,7 +708,10 @@ def test_rmsprop( centered: bool, weight_decay: float, inplace: bool, + use_chain_flat: bool, ) -> None: + _set_use_chain_flat(use_chain_flat) + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) fmodel, params, buffers = functorch.make_functional_with_buffers(model) @@ -374,7 +742,7 @@ def test_rmsprop( loss = F.cross_entropy(pred, ys) loss_ref = F.cross_entropy(pred_ref, ys) - grads = torch.autograd.grad(loss, params) + grads = torch.autograd.grad(loss, params, allow_unused=True) updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) params = torchopt.apply_updates(params, updates, inplace=inplace) @@ -383,3 +751,4 @@ def test_rmsprop( optim_ref.step() helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) diff --git a/tests/test_clip.py b/tests/test_clip.py index 420cfdaa..2614781e 100644 --- a/tests/test_clip.py +++ b/tests/test_clip.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import helpers import torchopt +from torchopt.alias.utils import _set_use_chain_flat @helpers.parametrize( @@ -30,7 +31,8 @@ dampening=[0.0, 0.5], nesterov=[False, True], weight_decay=[0.0, 1e-2], - maximize=[False], # TODO: test maximize after PyTorch 1.13 + maximize=[False, True], + use_chain_flat=[True, False], ) def test_sgd( dtype: torch.dtype, @@ -41,10 +43,13 @@ def test_sgd( nesterov: bool, weight_decay: float, maximize: bool, + use_chain_flat: bool, ) -> None: if nesterov and (momentum <= 0.0 or dampening != 0.0): pytest.skip('Nesterov momentum requires a momentum and zero dampening.') + _set_use_chain_flat(use_chain_flat) + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) chain = torchopt.chain( @@ -86,3 +91,4 @@ def test_sgd( optim_ref.step() helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) diff --git a/tests/test_combine.py b/tests/test_combine.py new file mode 100644 index 00000000..1a026b9e --- /dev/null +++ b/tests/test_combine.py @@ -0,0 +1,53 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import torchopt +from torchopt.alias.utils import _set_use_chain_flat + + +def test_chain() -> None: + assert torchopt.chain() == torchopt.base.identity() + assert torchopt.chain(torchopt.base.identity()) == torchopt.base.identity() + assert ( + torchopt.chain(torchopt.base.identity(), torchopt.base.identity()) + == torchopt.base.identity() + ) + assert torchopt.base.identity().chain(torchopt.base.identity()) == torchopt.base.identity() + assert isinstance(torchopt.base.identity(), torchopt.base.IdentityGradientTransformation) + assert isinstance( + torchopt.base.identity().chain(torchopt.base.identity()), + torchopt.base.ChainedGradientTransformation, + ) + + _set_use_chain_flat(False) + adam = torchopt.adam() + assert isinstance(adam, torchopt.base.ChainedGradientTransformation) + assert isinstance( + adam.chain(torchopt.base.identity()), + torchopt.base.ChainedGradientTransformation, + ) + assert adam.chain(torchopt.base.identity()) == adam + assert torchopt.base.identity().chain(adam) == adam + assert torchopt.chain(torchopt.base.identity(), adam, torchopt.base.identity()) == adam + _set_use_chain_flat(True) + + assert isinstance(adam, torchopt.base.GradientTransformation) + assert isinstance( + adam.chain(torchopt.base.identity()), + torchopt.base.ChainedGradientTransformation, + ) + assert adam.chain(torchopt.base.identity()) == adam + assert torchopt.base.identity().chain(adam) == adam + assert torchopt.chain(torchopt.base.identity(), adam, torchopt.base.identity()) == adam diff --git a/tests/test_hook.py b/tests/test_hook.py new file mode 100644 index 00000000..e89bb178 --- /dev/null +++ b/tests/test_hook.py @@ -0,0 +1,38 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import torch + +import torchopt +from torchopt import pytree + + +def test_nan_to_num_hook() -> None: + nan = torch.tensor(torch.nan) + inf = torch.tensor(torch.inf) + ninf = torch.tensor(-torch.inf) + hook = torchopt.hook.nan_to_num_hook(0.0, 1.0, -1.0) + result = pytree.tree_map(hook, [nan, inf, ninf]) + assert torch.equal(result[0], torch.tensor(0.0)) + assert torch.equal(result[1], torch.tensor(1.0)) + assert torch.equal(result[2], torch.tensor(-1.0)) + + +def test_zero_nan_hook() -> None: + tensor = torch.tensor(1.0, requires_grad=True) + hook = torchopt.hook.zero_nan_hook + fn = torchopt.register_hook(hook) + fn.update(tensor, None) + assert tensor._backward_hooks[0] is hook diff --git a/tests/test_implicit.py b/tests/test_implicit.py new file mode 100644 index 00000000..6cccb716 --- /dev/null +++ b/tests/test_implicit.py @@ -0,0 +1,873 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import annotations + +import copy +import re +from collections import OrderedDict +from typing import TYPE_CHECKING + +import functorch +import numpy as np +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.types +from torch.utils import data + +import helpers +import torchopt +from torchopt import pytree +from torchopt.diff.implicit import ImplicitMetaGradientModule + + +try: + import jax + import jax.numpy as jnp + import jaxopt + import optax + + HAS_JAX = True +except ImportError: + jax = jnp = jaxopt = optax = None + HAS_JAX = False + + +if TYPE_CHECKING: + from types import FunctionType + + +BATCH_SIZE = 8 +NUM_UPDATES = 3 + +MODEL_NUM_INPUTS = 10 +MODEL_NUM_CLASSES = 10 + + +class FcNet(nn.Module): + def __init__(self, dim, out): + super().__init__() + self.fc = nn.Linear(in_features=dim, out_features=out, bias=True) + nn.init.ones_(self.fc.weight) + nn.init.zeros_(self.fc.bias) + + def forward(self, x): + return self.fc(x) + + +def get_model_jax(dtype: np.dtype = np.float32) -> tuple[FunctionType, OrderedDict]: + helpers.seed_everything(seed=42) + + def func(params, x): + return x @ params['weight'] + params['bias'] + + params = OrderedDict( + [ + ('weight', jnp.ones((MODEL_NUM_INPUTS, MODEL_NUM_CLASSES), dtype=dtype)), + ('bias', jnp.zeros((MODEL_NUM_CLASSES,), dtype=dtype)), + ], + ) + return func, params + + +@torch.no_grad() +def get_model_torch( + device: torch.types.Device | None = None, + dtype: torch.dtype = torch.float32, +) -> tuple[nn.Module, data.DataLoader]: + helpers.seed_everything(seed=42) + + model = FcNet(MODEL_NUM_INPUTS, MODEL_NUM_CLASSES).to(dtype=dtype) + + if device is not None: + model = model.to(device=torch.device(device)) + + dataset = data.TensorDataset( + torch.randint(0, 1, (BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS)), + torch.randint(0, MODEL_NUM_CLASSES, (BATCH_SIZE * NUM_UPDATES,)), + ) + loader = data.DataLoader(dataset, BATCH_SIZE, shuffle=False) + + return model, loader + + +@torch.no_grad() +def get_rr_dataset_torch() -> data.DataLoader: + helpers.seed_everything(seed=42) + + BATCH_SIZE = 1024 + NUM_UPDATES = 4 + dataset = data.TensorDataset( + torch.randn((BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS)), + torch.randn((BATCH_SIZE * NUM_UPDATES,)), + torch.randn((BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS)), + torch.randn((BATCH_SIZE * NUM_UPDATES,)), + ) + return data.DataLoader(dataset, BATCH_SIZE, shuffle=False) + + +@pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed') +@helpers.parametrize( + dtype=[torch.float64, torch.float32], + lr=[1e-3, 1e-4], + inner_lr=[2e-2, 2e-3], + inner_update=[20, 50, 100], +) +def test_imaml_solve_normal_cg( # noqa: C901 + dtype: torch.dtype, + lr: float, + inner_lr: float, + inner_update: int, +) -> None: + np_dtype = helpers.dtype_torch2numpy(dtype) + + jax_model, jax_params = get_model_jax(dtype=np_dtype) + model, loader = get_model_torch(device='cpu', dtype=dtype) + + fmodel, params = functorch.make_functional(model) + optim = torchopt.sgd(lr) + optim_state = optim.init(params) + + optim_jax = optax.sgd(lr) + optim_state_jax = optim_jax.init(jax_params) + + def imaml_objective_torchopt(params, meta_params, data): + x, y, f = data + y_pred = f(params, x) + regularization_loss = 0 + for p1, p2 in zip(params, meta_params): + regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2)) + return F.cross_entropy(y_pred, y) + regularization_loss + + @torchopt.diff.implicit.custom_root( + functorch.grad(imaml_objective_torchopt, argnums=0), + argnums=1, + has_aux=True, + solve=torchopt.linear_solve.solve_normal_cg(), + ) + def inner_solver_torchopt(params, meta_params, data): + # Initial functional optimizer based on TorchOpt + x, y, f = data + optimizer = torchopt.sgd(lr=inner_lr) + opt_state = optimizer.init(params) + with torch.enable_grad(): + # Temporarily enable gradient computation for conducting the optimization + for _ in range(inner_update): + pred = f(params, x) + loss = F.cross_entropy(pred, y) # compute loss + # Compute regularization loss + regularization_loss = 0 + for p1, p2 in zip(params, meta_params): + regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2)) + final_loss = loss + regularization_loss + grads = torch.autograd.grad(final_loss, params) # compute gradients + updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates + params = torchopt.apply_updates(params, updates, inplace=True) + return params, (0, {'a': 1, 'b': 2}) + + def imaml_objective_jax(params, meta_params, x, y): + y_pred = jax_model(params, x) + loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(y_pred, y)) + regularization_loss = 0 + for p1, p2 in zip(params.values(), meta_params.values()): + regularization_loss += 0.5 * jnp.sum(jnp.square(p1 - p2)) + return loss + regularization_loss + + @jaxopt.implicit_diff.custom_root( + jax.grad(imaml_objective_jax, argnums=0), + has_aux=True, + solve=jaxopt.linear_solve.solve_normal_cg, + ) + def inner_solver_jax(params, meta_params, x, y): + """Solve ridge regression by conjugate gradient.""" + # Initial functional optimizer based on torchopt + optimizer = optax.sgd(inner_lr) + opt_state = optimizer.init(params) + + def compute_loss(params, meta_params, x, y): + pred = jax_model(params, x) + loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(pred, y)) + # Compute regularization loss + regularization_loss = 0 + for p1, p2 in zip(params.values(), meta_params.values()): + regularization_loss += 0.5 * jnp.sum(jnp.square(p1 - p2)) + return loss + regularization_loss + + for _ in range(inner_update): + grads = jax.grad(compute_loss)(params, meta_params, x, y) # compute gradients + updates, opt_state = optimizer.update(grads, opt_state) # get updates + params = optax.apply_updates(params, updates) + return params, (0, {'a': 1, 'b': 2}) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + data = (xs, ys, fmodel) + inner_params = pytree.tree_map( + lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), + params, + ) + optimal_inner_params, aux = inner_solver_torchopt(inner_params, params, data) + assert aux == (0, {'a': 1, 'b': 2}) + outer_loss = fmodel(optimal_inner_params, xs).mean() + + grads = torch.autograd.grad(outer_loss, params) + updates, optim_state = optim.update(grads, optim_state) + params = torchopt.apply_updates(params, updates) + + xs = xs.numpy() + ys = ys.numpy() + + def outer_level(p, xs, ys): + optimal_params, aux = inner_solver_jax(copy.deepcopy(p), p, xs, ys) + assert aux == (0, {'a': 1, 'b': 2}) + return jax_model(optimal_params, xs).mean() + + grads_jax = jax.grad(outer_level, argnums=0)(jax_params, xs, ys) + updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax) # get updates + jax_params = optax.apply_updates(jax_params, updates_jax) + + jax_params_as_tensor = tuple( + nn.Parameter(torch.tensor(np.asarray(jax_params[j]), dtype=dtype)) for j in jax_params + ) + + helpers.assert_pytree_all_close(params, jax_params_as_tensor) + + +@pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed') +@helpers.parametrize( + dtype=[torch.float64, torch.float32], + lr=[1e-3, 1e-4], + inner_lr=[2e-2, 2e-3], + inner_update=[20, 50, 100], + ns=[False, True], +) +def test_imaml_solve_inv( # noqa: C901 + dtype: torch.dtype, + lr: float, + inner_lr: float, + inner_update: int, + ns: bool, +) -> None: + np_dtype = helpers.dtype_torch2numpy(dtype) + + jax_model, jax_params = get_model_jax(dtype=np_dtype) + model, loader = get_model_torch(device='cpu', dtype=dtype) + + fmodel, params = functorch.make_functional(model) + optim = torchopt.sgd(lr) + optim_state = optim.init(params) + + optim_jax = optax.sgd(lr) + optim_state_jax = optim_jax.init(jax_params) + + def imaml_objective_torchopt(params, meta_params, data): + x, y, f = data + y_pred = f(params, x) + regularization_loss = 0 + for p1, p2 in zip(params, meta_params): + regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2)) + return F.cross_entropy(y_pred, y) + regularization_loss + + @torchopt.diff.implicit.custom_root( + functorch.grad(imaml_objective_torchopt, argnums=0), + argnums=1, + solve=torchopt.linear_solve.solve_inv(ns=ns), + ) + def inner_solver_torchopt(params, meta_params, data): + # Initial functional optimizer based on TorchOpt + x, y, f = data + optimizer = torchopt.sgd(lr=inner_lr) + opt_state = optimizer.init(params) + with torch.enable_grad(): + # Temporarily enable gradient computation for conducting the optimization + for _ in range(inner_update): + pred = f(params, x) + loss = F.cross_entropy(pred, y) # compute loss + # Compute regularization loss + regularization_loss = 0 + for p1, p2 in zip(params, meta_params): + regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2)) + final_loss = loss + regularization_loss + grads = torch.autograd.grad(final_loss, params) # compute gradients + updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates + params = torchopt.apply_updates(params, updates, inplace=True) + return params + + def imaml_objective_jax(params, meta_params, x, y): + y_pred = jax_model(params, x) + loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(y_pred, y)) + regularization_loss = 0 + for p1, p2 in zip(params.values(), meta_params.values()): + regularization_loss += 0.5 * jnp.sum(jnp.square(p1 - p2)) + return loss + regularization_loss + + @jaxopt.implicit_diff.custom_root( + jax.grad(imaml_objective_jax, argnums=0), + solve=jaxopt.linear_solve.solve_normal_cg, + ) + def inner_solver_jax(params, meta_params, x, y): + """Solve ridge regression by conjugate gradient.""" + # Initial functional optimizer based on torchopt + optimizer = optax.sgd(inner_lr) + opt_state = optimizer.init(params) + + def compute_loss(params, meta_params, x, y): + pred = jax_model(params, x) + loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(pred, y)) + # Compute regularization loss + regularization_loss = 0 + for p1, p2 in zip(params.values(), meta_params.values()): + regularization_loss += 0.5 * jnp.sum(jnp.square(p1 - p2)) + return loss + regularization_loss + + for _ in range(inner_update): + grads = jax.grad(compute_loss)(params, meta_params, x, y) # compute gradients + updates, opt_state = optimizer.update(grads, opt_state) # get updates + params = optax.apply_updates(params, updates) + return params + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + data = (xs, ys, fmodel) + inner_params = pytree.tree_map( + lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), + params, + ) + optimal_inner_params = inner_solver_torchopt(inner_params, params, data) + outer_loss = fmodel(optimal_inner_params, xs).mean() + + grads = torch.autograd.grad(outer_loss, params) + updates, optim_state = optim.update(grads, optim_state) + params = torchopt.apply_updates(params, updates) + + xs = xs.numpy() + ys = ys.numpy() + + def outer_level(p, xs, ys): + optimal_params = inner_solver_jax(copy.deepcopy(p), p, xs, ys) + return jax_model(optimal_params, xs).mean() + + grads_jax = jax.grad(outer_level, argnums=0)(jax_params, xs, ys) + updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax) # get updates + jax_params = optax.apply_updates(jax_params, updates_jax) + + jax_params_as_tensor = tuple( + nn.Parameter(torch.tensor(np.asarray(jax_params[j]), dtype=dtype)) for j in jax_params + ) + + helpers.assert_pytree_all_close(params, jax_params_as_tensor) + + +@pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed') +@helpers.parametrize( + dtype=[torch.float64, torch.float32], + lr=[1e-3, 1e-4], + inner_lr=[2e-2, 2e-3], + inner_update=[20, 50, 100], +) +def test_imaml_module( # noqa: C901 + dtype: torch.dtype, + lr: float, + inner_lr: float, + inner_update: int, +) -> None: + np_dtype = helpers.dtype_torch2numpy(dtype) + + jax_model, jax_params = get_model_jax(dtype=np_dtype) + model, loader = get_model_torch(device='cpu', dtype=dtype) + + class InnerNet(ImplicitMetaGradientModule): + def __init__(self, meta_model): + super().__init__() + self.meta_model = meta_model + self.model = torchopt.module_clone(meta_model, by='deepcopy', detach_buffers=True) + + def forward(self, x): + return self.model(x) + + def objective(self, x, y): + y_pred = self.model(x) + loss = F.cross_entropy(y_pred, y) + regularization_loss = 0 + for p1, p2 in zip(self.parameters(), self.meta_parameters()): + regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2)) + return loss + regularization_loss + + def solve(self, x, y): + params = tuple(self.parameters()) + optim_inner = torchopt.SGD(params, lr=inner_lr) + with torch.enable_grad(): + # Temporarily enable gradient computation for conducting the optimization + for _ in range(inner_update): + loss = self.objective(x, y) + optim_inner.zero_grad() + loss.backward(inputs=params) + optim_inner.step() + + return self, (0, {'a': 1, 'b': 2}) + + outer_optim = torchopt.SGD(model.parameters(), lr) + + optim_jax = optax.sgd(lr) + optim_state_jax = optim_jax.init(jax_params) + + def imaml_objective_jax(params, meta_params, x, y): + y_pred = jax_model(params, x) + loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(y_pred, y)) + regularization_loss = 0 + for p1, p2 in zip(params.values(), meta_params.values()): + regularization_loss += 0.5 * jnp.sum(jnp.square(p1 - p2)) + return loss + regularization_loss + + @jaxopt.implicit_diff.custom_root(jax.grad(imaml_objective_jax, argnums=0), has_aux=True) + def inner_solver_jax(params, meta_params, x, y): + """Solve ridge regression by conjugate gradient.""" + # Initial functional optimizer based on torchopt + optimizer = optax.sgd(inner_lr) + opt_state = optimizer.init(params) + + def compute_loss(params, meta_params, x, y): + pred = jax_model(params, x) + loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(pred, y)) + # Compute regularization loss + regularization_loss = 0 + for p1, p2 in zip(params.values(), meta_params.values()): + regularization_loss += 0.5 * jnp.sum(jnp.square(p1 - p2)) + return loss + regularization_loss + + for _ in range(inner_update): + grads = jax.grad(compute_loss)(params, meta_params, x, y) # compute gradients + updates, opt_state = optimizer.update(grads, opt_state) # get updates + params = optax.apply_updates(params, updates) + return params, (0, {'a': 1, 'b': 2}) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + inner_model = InnerNet(model) + optimal_inner_model, aux = inner_model.solve(xs, ys) + assert aux == (0, {'a': 1, 'b': 2}) + outer_loss = optimal_inner_model(xs).mean() + + outer_optim.zero_grad() + outer_loss.backward() + outer_optim.step() + + xs = xs.numpy() + ys = ys.numpy() + + def outer_level(p, xs, ys): + optimal_params, aux = inner_solver_jax(copy.deepcopy(p), p, xs, ys) + assert aux == (0, {'a': 1, 'b': 2}) + return jax_model(optimal_params, xs).mean() + + grads_jax = jax.grad(outer_level, argnums=0)(jax_params, xs, ys) + updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax) # get updates + jax_params = optax.apply_updates(jax_params, updates_jax) + + jax_params_as_tensor = tuple( + nn.Parameter(torch.tensor(np.asarray(jax_params[j]), dtype=dtype)) for j in jax_params + ) + + helpers.assert_pytree_all_close(tuple(model.parameters()), jax_params_as_tensor) + + +@pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed') +@helpers.parametrize( + dtype=[torch.float64, torch.float32], + lr=[1e-3, 1e-4], +) +def test_rr_solve_cg( + dtype: torch.dtype, + lr: float, +) -> None: + helpers.seed_everything(42) + np_dtype = helpers.dtype_torch2numpy(dtype) + input_size = 10 + + init_params_torch = torch.randn(input_size, dtype=dtype) + l2reg_torch = torch.rand(1, dtype=dtype).squeeze_().requires_grad_(True) + + init_params_jax = jnp.array(init_params_torch.detach().numpy(), dtype=np_dtype) + l2reg_jax = jnp.array(l2reg_torch.detach().numpy(), dtype=np_dtype) + + loader = get_rr_dataset_torch() + + optim = torchopt.sgd(lr) + optim_state = optim.init(l2reg_torch) + + optim_jax = optax.sgd(lr) + optim_state_jax = optim_jax.init(l2reg_jax) + + def ridge_objective_torch(params, l2reg, data): + """Ridge objective function.""" + X_tr, y_tr = data + residuals = X_tr @ params - y_tr + regularization_loss = 0.5 * l2reg * torch.sum(torch.square(params)) + return 0.5 * torch.mean(torch.square(residuals)) + regularization_loss + + @torchopt.diff.implicit.custom_root(functorch.grad(ridge_objective_torch, argnums=0), argnums=1) + def ridge_solver_torch_cg(params, l2reg, data): + """Solve ridge regression by conjugate gradient.""" + X_tr, y_tr = data + + def matvec(u): + return X_tr.T @ (X_tr @ u) + + solve = torchopt.linear_solve.solve_cg( + ridge=len(y_tr) * l2reg.item(), + init=params, + maxiter=20, + ) + + return solve(matvec=matvec, b=X_tr.T @ y_tr) + + def ridge_objective_jax(params, l2reg, X_tr, y_tr): + """Ridge objective function.""" + residuals = X_tr @ params - y_tr + regularization_loss = 0.5 * l2reg * jnp.sum(jnp.square(params)) + return 0.5 * jnp.mean(jnp.square(residuals)) + regularization_loss + + @jaxopt.implicit_diff.custom_root(jax.grad(ridge_objective_jax, argnums=0)) + def ridge_solver_jax_cg(params, l2reg, X_tr, y_tr): + """Solve ridge regression by conjugate gradient.""" + + def matvec(u): + return X_tr.T @ (X_tr @ u) + + return jaxopt.linear_solve.solve_cg( + matvec=matvec, + b=X_tr.T @ y_tr, + ridge=len(y_tr) * l2reg.item(), + init=params, + maxiter=20, + ) + + for xs, ys, xq, yq in loader: + xs = xs.to(dtype=dtype) + ys = ys.to(dtype=dtype) + xq = xq.to(dtype=dtype) + yq = yq.to(dtype=dtype) + + w_fit = ridge_solver_torch_cg(init_params_torch, l2reg_torch, (xs, ys)) + outer_loss = F.mse_loss(xq @ w_fit, yq) + + grads, *_ = torch.autograd.grad(outer_loss, l2reg_torch) + updates, optim_state = optim.update(grads, optim_state) + l2reg_torch = torchopt.apply_updates(l2reg_torch, updates) + + xs = jnp.array(xs.numpy(), dtype=np_dtype) + ys = jnp.array(ys.numpy(), dtype=np_dtype) + xq = jnp.array(xq.numpy(), dtype=np_dtype) + yq = jnp.array(yq.numpy(), dtype=np_dtype) + + def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq): + w_fit = ridge_solver_jax_cg(params_jax, l2reg_jax, xs, ys) + y_pred = xq @ w_fit + return jnp.mean(jnp.square(y_pred - yq)) + + grads_jax = jax.grad(outer_level, argnums=1)(init_params_jax, l2reg_jax, xs, ys, xq, yq) + updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax) # get updates + l2reg_jax = optax.apply_updates(l2reg_jax, updates_jax) + + l2reg_jax_as_tensor = torch.tensor(np.asarray(l2reg_jax), dtype=dtype) + helpers.assert_all_close(l2reg_torch, l2reg_jax_as_tensor) + + +@pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed') +@helpers.parametrize( + dtype=[torch.float64, torch.float32], + lr=[1e-3, 1e-4], + ns=[True, False], +) +def test_rr_solve_inv( + dtype: torch.dtype, + lr: float, + ns: bool, +) -> None: + if dtype == torch.float64 and ns: + pytest.skip('Neumann Series test skips torch.float64 due to numerical stability.') + helpers.seed_everything(42) + np_dtype = helpers.dtype_torch2numpy(dtype) + input_size = 10 + + init_params_torch = torch.randn(input_size, dtype=dtype) + l2reg_torch = torch.rand(1, dtype=dtype).squeeze_().requires_grad_(True) + + init_params_jax = jnp.array(init_params_torch.detach().numpy(), dtype=np_dtype) + l2reg_jax = jnp.array(l2reg_torch.detach().numpy(), dtype=np_dtype) + + loader = get_rr_dataset_torch() + + optim = torchopt.sgd(lr) + optim_state = optim.init(l2reg_torch) + + optim_jax = optax.sgd(lr) + optim_state_jax = optim_jax.init(l2reg_jax) + + def ridge_objective_torch(params, l2reg, data): + """Ridge objective function.""" + X_tr, y_tr = data + residuals = X_tr @ params - y_tr + regularization_loss = 0.5 * l2reg * torch.sum(torch.square(params)) + return 0.5 * torch.mean(torch.square(residuals)) + regularization_loss + + @torchopt.diff.implicit.custom_root(functorch.grad(ridge_objective_torch, argnums=0), argnums=1) + def ridge_solver_torch_inv(params, l2reg, data): + """Solve ridge regression by conjugate gradient.""" + X_tr, y_tr = data + + def matvec(u): + return X_tr.T @ (X_tr @ u) + + solve = torchopt.linear_solve.solve_inv( + matvec=matvec, + b=X_tr.T @ y_tr, + ridge=len(y_tr) * l2reg.item(), + ns=ns, + ) + + return solve(matvec=matvec, b=X_tr.T @ y_tr) + + def ridge_objective_jax(params, l2reg, X_tr, y_tr): + """Ridge objective function.""" + residuals = X_tr @ params - y_tr + regularization_loss = 0.5 * l2reg * jnp.sum(jnp.square(params)) + return 0.5 * jnp.mean(jnp.square(residuals)) + regularization_loss + + @jaxopt.implicit_diff.custom_root(jax.grad(ridge_objective_jax, argnums=0)) + def ridge_solver_jax_inv(params, l2reg, X_tr, y_tr): + """Solve ridge regression by conjugate gradient.""" + + def matvec(u): + return X_tr.T @ (X_tr @ u) + + return jaxopt.linear_solve.solve_inv( + matvec=matvec, + b=X_tr.T @ y_tr, + ridge=len(y_tr) * l2reg.item(), + ) + + for xs, ys, xq, yq in loader: + xs = xs.to(dtype=dtype) + ys = ys.to(dtype=dtype) + xq = xq.to(dtype=dtype) + yq = yq.to(dtype=dtype) + + w_fit = ridge_solver_torch_inv(init_params_torch, l2reg_torch, (xs, ys)) + outer_loss = F.mse_loss(xq @ w_fit, yq) + + grads, *_ = torch.autograd.grad(outer_loss, l2reg_torch) + updates, optim_state = optim.update(grads, optim_state) + l2reg_torch = torchopt.apply_updates(l2reg_torch, updates) + + xs = jnp.array(xs.numpy(), dtype=np_dtype) + ys = jnp.array(ys.numpy(), dtype=np_dtype) + xq = jnp.array(xq.numpy(), dtype=np_dtype) + yq = jnp.array(yq.numpy(), dtype=np_dtype) + + def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq): + w_fit = ridge_solver_jax_inv(params_jax, l2reg_jax, xs, ys) + y_pred = xq @ w_fit + return jnp.mean(jnp.square(y_pred - yq)) + + grads_jax = jax.grad(outer_level, argnums=1)(init_params_jax, l2reg_jax, xs, ys, xq, yq) + updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax) # get updates + l2reg_jax = optax.apply_updates(l2reg_jax, updates_jax) + + l2reg_jax_as_tensor = torch.tensor(np.asarray(l2reg_jax), dtype=dtype) + helpers.assert_all_close(l2reg_torch, l2reg_jax_as_tensor) + + +def test_module_empty_parameters() -> None: + class EmptyParameters(ImplicitMetaGradientModule): + def __init__(self, x): + super().__init__() + self.x = x + + def objective(self): + return self.x.mean() + + def solve(self): + pass + + model = EmptyParameters(torch.zeros(8)) + with pytest.raises(RuntimeError, match='The module has no parameters.'): + model.solve() + + model = EmptyParameters(torch.zeros(8)) + model.register_parameter('y', torch.zeros(8, requires_grad=True)) + with pytest.raises(RuntimeError, match='The module has no meta-parameters.'): + model.solve() + + model = EmptyParameters(torch.zeros(8, requires_grad=True)) + with pytest.raises(RuntimeError, match='The module has no parameters.'): + model.solve() + + model = EmptyParameters(torch.zeros(8, requires_grad=True)) + with pytest.raises(RuntimeError, match='The module has no parameters.'): + model.optimality() + + model = EmptyParameters(torch.zeros(8)) + model.register_parameter('y', torch.zeros(8, requires_grad=True)) + with pytest.raises(RuntimeError, match='The module has no meta-parameters.'): + model.optimality() + + model = EmptyParameters(torch.zeros(8, requires_grad=True)) + model.register_parameter('y', torch.zeros(8, requires_grad=True)) + model.solve() + + model = EmptyParameters(nn.Linear(8, 8).eval()) + with pytest.raises(RuntimeError, match='The module has no meta-parameters.'): + model.solve() + + model = EmptyParameters(nn.Linear(8, 8)) + model.register_parameter('y', torch.zeros(8, requires_grad=True)) + model.solve() + + +def test_module_enable_implicit_gradients_twice() -> None: + class MyModule1(torchopt.nn.ImplicitMetaGradientModule): + def objective(self): + return torch.tensor(0.0) + + def solve(self): + pass + + from torchopt.diff.implicit.nn.module import ( + enable_implicit_gradients, + make_optimality_from_objective, + ) + + with pytest.raises( + TypeError, + match='Implicit gradients are already enabled for the `solve` method.', + ): + enable_implicit_gradients(MyModule1) + + class MyModule2(torchopt.nn.ImplicitMetaGradientModule): + def optimality(self): + return torch.tensor(0.0) + + def solve(self): + pass + + with pytest.raises( + TypeError, + match='The objective function is not defined.', + ): + make_optimality_from_objective(MyModule2) + + +def test_module_abstract_methods() -> None: # noqa: C901 + class MyModule1(torchopt.nn.ImplicitMetaGradientModule): + def objective(self): + return torch.tensor(0.0) + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + MyModule1() + + with pytest.raises( + TypeError, + match=re.escape( + 'ImplicitMetaGradientModule requires either an optimality() method or an objective() method', + ), + ): + + class MyModule2(torchopt.nn.ImplicitMetaGradientModule): + def solve(self): + pass + + class MyModule3(torchopt.nn.ImplicitMetaGradientModule): + def optimality(self): + return () + + def solve(self): + pass + + with pytest.raises( + TypeError, + match=re.escape('method optimality() must not be a staticmethod.'), + ): + + class MyModule4(torchopt.nn.ImplicitMetaGradientModule): + @staticmethod + def optimality(): + return () + + def solve(self): + pass + + with pytest.raises( + TypeError, + match=re.escape('method optimality() must not be a classmethod.'), + ): + + class MyModule5(torchopt.nn.ImplicitMetaGradientModule): + @classmethod + def optimality(cls): + return () + + def solve(self): + pass + + with pytest.raises( + TypeError, + match=re.escape('method optimality() must be callable.'), + ): + + class MyModule6(torchopt.nn.ImplicitMetaGradientModule): + optimality = 0 + + def solve(self): + pass + + with pytest.raises( + TypeError, + match=re.escape('method objective() must not be a staticmethod.'), + ): + + class MyModule7(torchopt.nn.ImplicitMetaGradientModule): + @staticmethod + def objective(): + return () + + def solve(self): + pass + + with pytest.raises( + TypeError, + match=re.escape('method objective() must not be a classmethod.'), + ): + + class MyModule8(torchopt.nn.ImplicitMetaGradientModule): + @classmethod + def objective(cls): + return () + + def solve(self): + pass + + with pytest.raises( + TypeError, + match=re.escape('method objective() must be callable.'), + ): + + class MyModule9(torchopt.nn.ImplicitMetaGradientModule): + objective = 0 + + def solve(self): + pass diff --git a/tests/test_import.py b/tests/test_import.py new file mode 100644 index 00000000..04d0ebbb --- /dev/null +++ b/tests/test_import.py @@ -0,0 +1,411 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import torchopt + + +def test_accelerated_op_import() -> None: + torchopt.accelerated_op.adam_op.AdamOp + torchopt.accelerated_op.is_available + torchopt.accelerated_op_available + from torchopt.accelerated_op import is_available + from torchopt.accelerated_op.adam_op import AdamOp + + +def test_alias_import() -> None: + torchopt.adadelta + torchopt.adagrad + torchopt.adam + torchopt.adamw + torchopt.adamax + torchopt.radam + torchopt.rmsprop + torchopt.sgd + torchopt.alias.adadelta + torchopt.alias.adagrad + torchopt.alias.adam + torchopt.alias.adamw + torchopt.alias.adamax + torchopt.alias.radam + torchopt.alias.rmsprop + torchopt.alias.sgd + from torchopt import adadelta, adagrad, adam, adamax, adamw, radam, rmsprop, sgd + from torchopt.alias import adadelta, adagrad, adam, adamax, adamw, radam, rmsprop, sgd + + +def test_diff_import() -> None: + torchopt.diff.implicit + torchopt.diff.implicit.custom_root + torchopt.diff.implicit.ImplicitMetaGradientModule + torchopt.diff.implicit.nn.ImplicitMetaGradientModule + torchopt.diff.zero_order + torchopt.diff.zero_order.zero_order + torchopt.diff.zero_order.ZeroOrderGradientModule + torchopt.diff.zero_order.nn.ZeroOrderGradientModule + from torchopt.diff import implicit, zero_order + from torchopt.diff.implicit import ImplicitMetaGradientModule, custom_root + from torchopt.diff.zero_order import ZeroOrderGradientModule, zero_order + + +def test_distributed_import() -> None: + torchopt.distributed.api + torchopt.distributed.autograd + torchopt.distributed.world + torchopt.distributed.is_available + torchopt.distributed.TensorDimensionPartitioner + torchopt.distributed.dim_partitioner + torchopt.distributed.batch_partitioner + torchopt.distributed.mean_reducer + torchopt.distributed.sum_reducer + torchopt.distributed.remote_async_call + torchopt.distributed.remote_sync_call + torchopt.distributed.parallelize + torchopt.distributed.parallelize_async + torchopt.distributed.parallelize_sync + torchopt.distributed.get_world_info + torchopt.distributed.get_world_rank + torchopt.distributed.get_rank + torchopt.distributed.get_world_size + torchopt.distributed.get_local_rank + torchopt.distributed.get_local_world_size + torchopt.distributed.get_worker_id + torchopt.distributed.barrier + torchopt.distributed.auto_init_rpc + torchopt.distributed.on_rank + torchopt.distributed.not_on_rank + torchopt.distributed.rank_zero_only + torchopt.distributed.rank_non_zero_only + torchopt.distributed.autograd.is_available + torchopt.distributed.autograd.context + from torchopt.distributed import api, autograd, world + + +def test_linalg_import() -> None: + torchopt.linalg.cg + torchopt.linalg.ns + torchopt.linalg.ns_inv + from torchopt.linalg import cg, ns, ns_inv + + +def test_linear_solve_import() -> None: + torchopt.linear_solve.solve_cg + torchopt.linear_solve.solve_inv + torchopt.linear_solve.solve_normal_cg + from torchopt.linear_solve import solve_cg, solve_inv, solve_normal_cg + + +def test_nn_import() -> None: + torchopt.nn.MetaGradientModule + torchopt.nn.ImplicitMetaGradientModule + torchopt.nn.ZeroOrderGradientModule + from torchopt.nn import ImplicitMetaGradientModule, MetaGradientModule, ZeroOrderGradientModule + + +def test_optim_import() -> None: + torchopt.FuncOptimizer + torchopt.MetaAdaDelta + torchopt.MetaAdadelta + torchopt.MetaAdaGrad + torchopt.MetaAdagrad + torchopt.MetaAdam + torchopt.MetaAdamW + torchopt.MetaAdaMax + torchopt.MetaAdamax + torchopt.MetaRAdam + torchopt.MetaRMSProp + torchopt.MetaRMSprop + torchopt.MetaSGD + torchopt.AdaDelta + torchopt.Adadelta + torchopt.AdaGrad + torchopt.Adagrad + torchopt.Adam + torchopt.AdamW + torchopt.AdaMax + torchopt.Adamax + torchopt.Optimizer + torchopt.RMSProp + torchopt.RMSprop + torchopt.SGD + torchopt.optim.meta.MetaAdaDelta + torchopt.optim.meta.MetaAdadelta + torchopt.optim.meta.MetaAdaGrad + torchopt.optim.meta.MetaAdagrad + torchopt.optim.meta.MetaAdam + torchopt.optim.meta.MetaAdamW + torchopt.optim.meta.MetaAdaMax + torchopt.optim.meta.MetaAdamax + torchopt.optim.meta.MetaRMSProp + torchopt.optim.meta.MetaRMSprop + torchopt.optim.meta.MetaSGD + torchopt.optim.Adam + torchopt.optim.AdamW + torchopt.optim.Optimizer + torchopt.optim.RMSProp + torchopt.optim.RMSprop + torchopt.optim.SGD + torchopt.optim.func.FuncOptimizer + from torchopt import ( + SGD, + AdaDelta, + Adadelta, + AdaGrad, + Adagrad, + Adam, + AdaMax, + Adamax, + AdamW, + FuncOptimizer, + MetaAdaDelta, + MetaAdadelta, + MetaAdaGrad, + MetaAdagrad, + MetaAdam, + MetaAdaMax, + MetaAdamax, + MetaAdamW, + MetaOptimizer, + MetaRMSprop, + MetaRMSProp, + MetaSGD, + Optimizer, + RMSProp, + ) + from torchopt.optim import SGD, Adam, AdamW, FuncOptimizer, Optimizer, RMSProp + from torchopt.optim.func import FuncOptimizer + from torchopt.optim.meta import ( + MetaAdaDelta, + MetaAdadelta, + MetaAdaGrad, + MetaAdagrad, + MetaAdam, + MetaAdaMax, + MetaAdamax, + MetaAdamW, + MetaOptimizer, + MetaRAdam, + MetaRMSProp, + MetaRMSprop, + MetaSGD, + ) + + +def test_schedule_import() -> None: + torchopt.schedule.linear_schedule + torchopt.schedule.polynomial_schedule + from torchopt.schedule import linear_schedule, polynomial_schedule + + +def test_transform_import() -> None: + torchopt.transform.add_decayed_weights + torchopt.transform.scale + torchopt.transform.scale_by_accelerated_adam + torchopt.transform.scale_by_adam + torchopt.transform.scale_by_rms + torchopt.transform.scale_by_schedule + torchopt.transform.scale_by_stddev + torchopt.transform.trace + torchopt.transform.nan_to_num + torchopt.nan_to_num + from torchopt import nan_to_num + from torchopt.transform import ( + add_decayed_weights, + nan_to_num, + scale, + scale_by_accelerated_adam, + scale_by_adam, + scale_by_rms, + scale_by_schedule, + scale_by_stddev, + trace, + ) + + +def test_base_import() -> None: + torchopt.base.EmptyState + torchopt.base.GradientTransformation + torchopt.base.ChainedGradientTransformation + torchopt.base.identity + from torchopt.base import ( + ChainedGradientTransformation, + EmptyState, + GradientTransformation, + identity, + ) + + +def test_clip_import() -> None: + torchopt.clip_grad_norm + torchopt.clip.clip_grad_norm + from torchopt import clip_grad_norm + from torchopt.clip import clip_grad_norm + + +def test_combine_import() -> None: + torchopt.chain + torchopt.chain.flat + torchopt.combine.chain + torchopt.combine.chain.flat + torchopt.combine.chain_flat + from torchopt import chain + from torchopt.combine import chain, chain_flat + + +def test_hook_import() -> None: + torchopt.register_hook + torchopt.hook.register_hook + torchopt.hook.zero_nan_hook + torchopt.hook.nan_to_num_hook + from torchopt import register_hook + from torchopt.hook import nan_to_num_hook, register_hook, zero_nan_hook + + +def test_pytree_import() -> None: + torchopt.pytree.tree_flatten_as_tuple + torchopt.pytree.tree_pos + torchopt.pytree.tree_neg + torchopt.pytree.tree_add + torchopt.pytree.tree_add_scalar_mul + torchopt.pytree.tree_sub + torchopt.pytree.tree_sub_scalar_mul + torchopt.pytree.tree_mul + torchopt.pytree.tree_matmul + torchopt.pytree.tree_scalar_mul + torchopt.pytree.tree_truediv + torchopt.pytree.tree_vdot_real + torchopt.pytree.tree_wait + from torchopt.pytree import ( + tree_add, + tree_add_scalar_mul, + tree_flatten_as_tuple, + tree_matmul, + tree_mul, + tree_neg, + tree_pos, + tree_scalar_mul, + tree_sub, + tree_sub_scalar_mul, + tree_truediv, + tree_vdot_real, + tree_wait, + ) + + +def test_typing_import() -> None: + torchopt.typing.GradientTransformation + torchopt.typing.ChainedGradientTransformation + torchopt.typing.EmptyState + torchopt.typing.UninitializedState + torchopt.typing.Params + torchopt.typing.Updates + torchopt.typing.OptState + torchopt.typing.Scalar + torchopt.typing.Numeric + torchopt.typing.Schedule + torchopt.typing.ScalarOrSchedule + torchopt.typing.PyTree + torchopt.typing.Tensor + torchopt.typing.OptionalTensor + torchopt.typing.ListOfTensors + torchopt.typing.TupleOfTensors + torchopt.typing.SequenceOfTensors + torchopt.typing.TensorOrTensors + torchopt.typing.TensorTree + torchopt.typing.ListOfOptionalTensors + torchopt.typing.TupleOfOptionalTensors + torchopt.typing.SequenceOfOptionalTensors + torchopt.typing.OptionalTensorOrOptionalTensors + torchopt.typing.OptionalTensorTree + torchopt.typing.TensorContainer + torchopt.typing.ModuleTensorContainers + torchopt.typing.Future + torchopt.typing.LinearSolver + torchopt.typing.Device + torchopt.typing.Size + torchopt.typing.Distribution + torchopt.typing.SampleFunc + torchopt.typing.Samplable + from torchopt.typing import ( + ChainedGradientTransformation, + Device, + Distribution, + EmptyState, + Future, + GradientTransformation, + LinearSolver, + ListOfOptionalTensors, + ListOfTensors, + ModuleTensorContainers, + Numeric, + OptionalTensor, + OptionalTensorOrOptionalTensors, + OptionalTensorTree, + OptState, + Params, + PyTree, + Samplable, + SampleFunc, + Scalar, + ScalarOrSchedule, + Schedule, + SequenceOfOptionalTensors, + SequenceOfTensors, + Size, + Tensor, + TensorContainer, + TensorOrTensors, + TensorTree, + TupleOfOptionalTensors, + TupleOfTensors, + UninitializedState, + Updates, + ) + + +def test_update_import() -> None: + torchopt.apply_updates + torchopt.update.apply_updates + from torchopt import apply_updates + from torchopt.update import apply_updates + + +def test_utils_import() -> None: + torchopt.utils.ModuleState + torchopt.utils.stop_gradient + torchopt.utils.extract_state_dict + torchopt.utils.recover_state_dict + torchopt.utils.module_clone + torchopt.utils.module_detach_ + from torchopt.utils import ( + ModuleState, + extract_state_dict, + module_clone, + module_detach_, + recover_state_dict, + stop_gradient, + ) + + +def test_version_import() -> None: + torchopt.__version__ + torchopt.version.__version__ + from torchopt import __version__ + from torchopt.version import __version__ + + +def test_visual_import() -> None: + torchopt.visual.make_dot + torchopt.visual.resize_graph + from torchopt.visual import make_dot, resize_graph diff --git a/torchopt/_src/typing.py b/tests/test_linalg.py similarity index 61% rename from torchopt/_src/typing.py rename to tests/test_linalg.py index b2104682..c5b07618 100644 --- a/torchopt/_src/typing.py +++ b/tests/test_linalg.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,16 +13,15 @@ # limitations under the License. # ============================================================================== -from typing import Any, Callable, Iterable, Mapping, TypeVar, Union +import torch -from torch import Tensor +import torchopt -Scalar = TypeVar('Scalar', float, int) -Numeric = Union[Tensor, Scalar] - -Schedule = Callable[[Numeric], Numeric] -ScalarOrSchedule = Union[float, Schedule] - -# mypy: ignore-errors -TensorTree = Union[Tensor, Iterable['TensorTree'], Mapping[Any, 'TensorTree']] +def test_normalize_matvec() -> None: + A = [torch.rand(10, 10) for _ in range(10)] + x = [torch.rand(10, 1) for _ in range(10)] + AxFn = torchopt.linalg.utils.normalize_matvec(A) + Ax = AxFn(x) + for Ax_item, A_item, x_item in zip(Ax, A, x): + assert torch.equal(Ax_item, A_item @ x_item) diff --git a/tests/test_meta_optim.py b/tests/test_meta_optim.py new file mode 100644 index 00000000..55712bdf --- /dev/null +++ b/tests/test_meta_optim.py @@ -0,0 +1,90 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import annotations + +import torch +import torch.nn.functional as F + +import helpers +import torchopt + + +@helpers.parametrize( + dtype=[torch.float64], + outer_lr=[1e-2, 1e-3, 1e-4], + inner_lr=[1e-2, 1e-3, 1e-4], + inner_update=[2, 3, 5], + betas=[(0.9, 0.999), (0.95, 0.9995)], + eps=[1e-8], + eps_root=[0.0, 1e-8], + weight_decay=[0.0, 1e-2], + maximize=[False, True], + use_accelerated_op=[False, True], + moment_requires_grad=[True, False], +) +def test_maml_meta_adam( + dtype: torch.dtype, + outer_lr: float, + inner_lr: float, + inner_update: int, + betas: tuple[float, float], + eps: float, + eps_root: float, + weight_decay: float, + maximize: bool, + use_accelerated_op: bool, + moment_requires_grad: bool, +) -> None: + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + outer_optim = torchopt.Adam( + model.parameters(), + outer_lr, + betas=betas, + eps=eps, + eps_root=0.0, + weight_decay=weight_decay, + maximize=maximize, + use_accelerated_op=use_accelerated_op, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + + inner_optim = torchopt.MetaAdam( + module=model, + lr=inner_lr, + betas=betas, + eps=eps, + eps_root=eps_root, + moment_requires_grad=moment_requires_grad, + weight_decay=weight_decay, + maximize=maximize, + use_accelerated_op=use_accelerated_op, + ) + + for _ in range(inner_update): + pred = model(xs) + inner_loss = F.cross_entropy(pred, ys) # compute loss + inner_optim.step(inner_loss) + + pred = model(xs) + outer_loss = F.cross_entropy(pred, ys) + outer_optim.zero_grad() + outer_loss.backward() + outer_optim.step() + + torchopt.stop_gradient(model) diff --git a/tests/test_nn.py b/tests/test_nn.py new file mode 100644 index 00000000..f77c20ec --- /dev/null +++ b/tests/test_nn.py @@ -0,0 +1,252 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import re + +import pytest +import torch +import torch.nn as nn + +import helpers +import torchopt + + +def test_property() -> None: + m = torchopt.nn.MetaGradientModule() + x = helpers.get_model() + m.add_module('x', x) + assert m.x is x + delattr(m, 'x') + assert not hasattr(m, 'x') + m.add_meta_module('x', x) + assert m.x is x + delattr(m, 'x') + assert not hasattr(m, 'x') + x = torch.tensor(1.0, requires_grad=True) + m.register_parameter('x', x) + assert m.x is x + delattr(m, 'x') + assert not hasattr(m, 'x') + x = torch.tensor(1.0, requires_grad=True) + m.register_meta_parameter('x', x) + assert m.x is x + delattr(m, 'x') + assert not hasattr(m, 'x') + m.register_buffer('x', x) + assert len(m._buffers) == 1 + assert m.x is x + delattr(m, 'x') + assert len(m._buffers) == 0 + assert not hasattr(m, 'x') + + +def test_register_tensors() -> None: + x = torch.tensor(1.0, requires_grad=True) + y = torch.tensor(1.0, requires_grad=True) + z = torch.tensor(1.0, requires_grad=False) + b = torch.tensor(1.0, requires_grad=False) + + m = torchopt.nn.MetaGradientModule() + m.register_meta_parameter('x', x) + assert m.x is x + + m = torchopt.nn.MetaGradientModule(x) + m.x = x + m.y = y + m.z = z + + assert m._meta_parameters['x'] is x + assert m._parameters['y'] is y + assert ( + hasattr(m, 'z') + and m.z is z + and 'z' not in m._meta_parameters + and 'z' not in m._parameters + and 'z' not in m._buffers + ) + + del m.x + object.__setattr__(m, 'x', x) + assert hasattr(m, 'x') and m.x is x and 'x' not in m._meta_parameters + m.x = x + assert m._meta_parameters['x'] is x + + m.register_buffer('b', None) + assert m.b is None + m.b = b + assert m.b is b and 'b' in m._buffers + + m = torchopt.nn.MetaGradientModule(x, b) + + with pytest.raises( + TypeError, + match=re.escape('parameter name should be a string. Got bytes'), + ): + m.register_meta_parameter(b'x', x) + + with pytest.raises( + KeyError, + match=re.escape("parameter name can't contain '.'"), + ): + m.register_meta_parameter('x.x', x) + + with pytest.raises( + KeyError, + match=re.escape("parameter name can't be empty string ''"), + ): + m.register_meta_parameter('', x) + + m.register_buffer('z', None) + with pytest.raises( + KeyError, + match=re.escape("attribute 'z' already exists"), + ): + m.register_meta_parameter('z', x) + + with pytest.raises( + ValueError, + match=re.escape( + "cannot assign Tensor that is a meta-parameter to parameter 'x'. " + 'Use self.register_meta_parameter() instead.', + ), + ): + m.register_parameter('x', x) + + m.x = x + with pytest.raises( + KeyError, + match=re.escape("attribute 'x' already exists"), + ): + m.register_parameter('x', x) + + with pytest.raises( + TypeError, + match=re.escape('parameter name should be a string. Got bytes'), + ): + m.register_parameter(b'y', y) + + with pytest.raises( + KeyError, + match=re.escape("parameter name can't contain '.'"), + ): + m.register_parameter('y.x', y) + + with pytest.raises( + KeyError, + match=re.escape("parameter name can't be empty string ''"), + ): + m.register_parameter('', y) + + +def test_no_super_init() -> None: + class NoSuper1(torchopt.nn.MetaGradientModule): + def __init__(self, x) -> None: + self.x = x + + with pytest.raises( + AttributeError, + match=re.escape('cannot assign parameters before Module.__init__() call'), + ): + NoSuper1(torch.tensor(1.0, requires_grad=True)) + + class NoSuper2(torchopt.nn.MetaGradientModule): + def __init__(self) -> None: + self.x = torch.tensor(1.0, requires_grad=True) + + with pytest.raises( + AttributeError, + match=re.escape('cannot assign parameters before Module.__init__() call'), + ): + NoSuper2() + + class NoSuper3(torchopt.nn.MetaGradientModule): + def __init__(self) -> None: + self.register_buffer('x', torch.tensor(1.0)) + + with pytest.raises( + AttributeError, + match=re.escape('cannot assign buffer before Module.__init__() call'), + ): + NoSuper3() + + class NoSuper4(torchopt.nn.MetaGradientModule): + def __init__(self) -> None: + self.x = torch.tensor(1.0, requires_grad=False) + + NoSuper4() # no error + + class NoSuper5(torchopt.nn.MetaGradientModule): + def __init__(self, x) -> None: + self.x = x + + with pytest.raises( + AttributeError, + match=re.escape('cannot assign module before Module.__init__() call'), + ): + NoSuper5(nn.Linear(1, 1)) + + class NoSuper6(torchopt.nn.MetaGradientModule): + def __init__(self) -> None: + self.x = nn.Linear(1, 1) + + with pytest.raises( + AttributeError, + match=re.escape('cannot assign module before Module.__init__() call'), + ): + NoSuper6() + + +def test_add_meta_module() -> None: + meta_module = helpers.get_model() + fc = nn.Linear(1, 1) + + m = torchopt.nn.MetaGradientModule(meta_module) + m.fc = fc + assert m.fc is fc + assert m._modules['fc'] is fc + + m.meta = meta_module + assert m.meta is meta_module + assert m._meta_modules['meta'] is meta_module + + assert all(p1 is p2 for p1, p2 in zip(m.parameters(), fc.parameters())) + assert all(p1 is p2 for p1, p2 in zip(m.meta_parameters(), meta_module.parameters())) + + m = torchopt.nn.MetaGradientModule(meta_module) + m.add_meta_module('fc', fc) + assert m.fc is fc + assert all(p1 is p2 for p1, p2 in zip(m.meta_parameters(), fc.parameters())) + + +def test_meta_module() -> None: + m = torchopt.nn.MetaGradientModule() + meta_module = torch.nn.Linear(1, 1) + m.add_meta_module('m', meta_module) + assert next(m.named_meta_modules())[1] is meta_module + assert next(m.named_meta_children())[1] is meta_module + assert next(m.meta_children()) is meta_module + assert next(m.meta_modules()) is meta_module + + +def test_add_meta_parameters() -> None: + m = torchopt.nn.MetaGradientModule() + x = torch.tensor(1.0, requires_grad=True) + m.register_meta_parameter('x', x) + assert next(m.named_meta_parameters())[1] is x + + +def test_named_modules() -> None: + m = torchopt.nn.MetaGradientModule() + assert next(m.named_modules())[1] is m diff --git a/tests/test_optimizer.py b/tests/test_optim.py similarity index 58% rename from tests/test_optimizer.py rename to tests/test_optim.py index c0db3e34..1257054f 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optim.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,11 @@ # limitations under the License. # ============================================================================== -from typing import Tuple +from __future__ import annotations +from typing import Callable + +import functorch import pytest import torch import torch.nn.functional as F @@ -30,7 +33,7 @@ dampening=[0.0, 0.5], nesterov=[False, True], weight_decay=[0.0, 1e-2], - maximize=[False], # TODO: test maximize after PyTorch 1.13 + maximize=[False, True], ) def test_SGD( dtype: torch.dtype, @@ -90,14 +93,16 @@ def test_SGD( eps=[1e-8], weight_decay=[0.0, 1e-2], maximize=[False, True], + use_accelerated_op=[False, True], ) def test_Adam( dtype: torch.dtype, lr: float, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, weight_decay: float, maximize: bool, + use_accelerated_op: bool, ) -> None: model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) @@ -109,6 +114,7 @@ def test_Adam( eps_root=0.0, weight_decay=weight_decay, maximize=maximize, + use_accelerated_op=use_accelerated_op, ) optim_ref = torch.optim.Adam( model_ref.parameters(), @@ -138,41 +144,84 @@ def test_Adam( helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + rho=[0.9, 0.95], + eps=[1e-8], + weight_decay=[0.0, 1e-2], +) +def test_Adadelta( + dtype: torch.dtype, + lr: float, + rho: float, + eps: float, + weight_decay: float, +) -> None: + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + optim = torchopt.Adadelta( + model.parameters(), + lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + ) + optim_ref = torch.optim.Adadelta( + model_ref.parameters(), + lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = model(xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + optim.zero_grad() + loss.backward() + optim.step() + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) + + @helpers.parametrize( dtype=[torch.float64], lr=[1e-2, 1e-3, 1e-4], betas=[(0.9, 0.999), (0.95, 0.9995)], eps=[1e-8], - weight_decay=[1e-2, 1e-1], - maximize=[False, True], + weight_decay=[0.0, 1e-2], ) -def test_AdamW( +def test_RAdam( dtype: torch.dtype, lr: float, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, weight_decay: float, - maximize: bool, ) -> None: model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) - optim = torchopt.AdamW( + optim = torchopt.RAdam( model.parameters(), lr, betas=betas, eps=eps, - eps_root=0.0, weight_decay=weight_decay, - maximize=maximize, ) - optim_ref = torch.optim.AdamW( + optim_ref = torch.optim.RAdam( model_ref.parameters(), lr, betas=betas, eps=eps, - amsgrad=False, weight_decay=weight_decay, - maximize=maximize, ) for xs, ys in loader: @@ -199,19 +248,70 @@ def test_AdamW( betas=[(0.9, 0.999), (0.95, 0.9995)], eps=[1e-8], weight_decay=[0.0, 1e-2], +) +def test_Adamax( + dtype: torch.dtype, + lr: float, + betas: tuple[float, float], + eps: float, + weight_decay: float, +) -> None: + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + optim = torchopt.Adamax( + model.parameters(), + lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + optim_ref = torch.optim.Adamax( + model_ref.parameters(), + lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = model(xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + optim.zero_grad() + loss.backward() + optim.step() + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) + + +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + betas=[(0.9, 0.999), (0.95, 0.9995)], + eps=[1e-8], + weight_decay=[1e-2, 1e-1], maximize=[False, True], + use_accelerated_op=[False, True], ) -def test_Adam_accelerated_cpu( +def test_AdamW( dtype: torch.dtype, lr: float, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, weight_decay: float, maximize: bool, + use_accelerated_op: bool, ) -> None: model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) - optim = torchopt.Adam( + optim = torchopt.AdamW( model.parameters(), lr, betas=betas, @@ -219,9 +319,9 @@ def test_Adam_accelerated_cpu( eps_root=0.0, weight_decay=weight_decay, maximize=maximize, - use_accelerated_op=True, + use_accelerated_op=use_accelerated_op, ) - optim_ref = torch.optim.Adam( + optim_ref = torch.optim.AdamW( model_ref.parameters(), lr, betas=betas, @@ -253,6 +353,10 @@ def test_Adam_accelerated_cpu( @helpers.parametrize( dtype=[torch.float64], lr=[1e-2, 1e-3, 1e-4], + optimizers=[ + (torchopt.Adam, torch.optim.Adam), + (torchopt.AdamW, torch.optim.AdamW), + ], betas=[(0.9, 0.999), (0.95, 0.9995)], eps=[1e-8], weight_decay=[0.0, 1e-2], @@ -261,7 +365,8 @@ def test_Adam_accelerated_cpu( def test_Adam_accelerated_cuda( dtype: torch.dtype, lr: float, - betas: Tuple[float, float], + optimizers: tuple[torchopt.Optimizer, torch.optim.Optimizer], + betas: tuple[float, float], eps: float, weight_decay: float, maximize: bool, @@ -269,7 +374,9 @@ def test_Adam_accelerated_cuda( device = 'cuda' model, model_ref, model_base, loader = helpers.get_models(device=device, dtype=dtype) - optim = torchopt.Adam( + torchopt_optimizer, torch_optimizer = optimizers + + optim = torchopt_optimizer( model.parameters(), lr, betas=betas, @@ -279,7 +386,7 @@ def test_Adam_accelerated_cuda( maximize=maximize, use_accelerated_op=True, ) - optim_ref = torch.optim.Adam( + optim_ref = torch_optimizer( model_ref.parameters(), lr, betas=betas, @@ -308,6 +415,63 @@ def test_Adam_accelerated_cuda( helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + lr_decay=[0.0, 1e-2], + initial_accumulator_value=[0.0, 1e-1], + eps=[1e-8], + weight_decay=[0.0, 1e-2], + maximize=[False, True], +) +def test_AdaGrad( + dtype: torch.dtype, + lr: float, + lr_decay: float, + initial_accumulator_value: float, + eps: float, + weight_decay: float, + maximize: bool, +) -> None: + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + optim = torchopt.AdaGrad( + model.parameters(), + lr=lr, + lr_decay=lr_decay, + weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + eps=eps, + maximize=maximize, + ) + optim_ref = torch.optim.Adagrad( + model_ref.parameters(), + lr=lr, + lr_decay=lr_decay, + weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + eps=eps, + maximize=maximize, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = model(xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + optim.zero_grad() + loss.backward() + optim.step() + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) + + @helpers.parametrize( dtype=[torch.float64], lr=[1e-2, 1e-3, 1e-4], @@ -364,3 +528,59 @@ def test_RMSProp( optim_ref.step() helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) + + +@helpers.parametrize( + dtype=[torch.float64, torch.float32], + lr=[1e-2, 1e-3], + optimizers=[ + (torchopt.sgd, torch.optim.SGD, {}), + (torchopt.adam, torch.optim.Adam, {}), + (torchopt.adamw, torch.optim.AdamW, {}), + (torchopt.adagrad, torch.optim.Adagrad, {'eps': 1e-8}), + (torchopt.rmsprop, torch.optim.RMSprop, {}), + ], + inplace=[True, False], + weight_decay=[0.0, 1e-2], +) +def test_FuncOptimizer( + dtype: torch.dtype, + lr: float, + optimizers: tuple[Callable, torch.optim.Optimizer], + inplace: bool, + weight_decay: float, +) -> None: + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + torchopt_optimizer, torch_optimizer, optimizer_kwargs = optimizers + + fmodel, params, buffers = functorch.make_functional_with_buffers(model) + optim = torchopt.FuncOptimizer( + torchopt_optimizer( + lr=lr, + weight_decay=weight_decay, + **optimizer_kwargs, + ), + inplace=inplace, + ) + optim_ref = torch_optimizer( + model_ref.parameters(), + lr, + weight_decay=weight_decay, + **optimizer_kwargs, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = fmodel(params, buffers, xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + params = optim.step(loss, params) + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) diff --git a/tests/test_pytree.py b/tests/test_pytree.py new file mode 100644 index 00000000..6ee2939b --- /dev/null +++ b/tests/test_pytree.py @@ -0,0 +1,217 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import torch + +import helpers +from torchopt import pytree + + +tree_a = (torch.randn(20, 10), torch.randn(20)) +tree_b = (torch.randn(20, 10), torch.randn(20)) + +tree_a_dict = ( + torch.tensor(1.0), + {'k1': torch.tensor(1.0), 'k2': (torch.tensor(1.0), torch.tensor(1.0))}, + torch.tensor(1.0), +) +tree_b_dict = ( + torch.tensor(1.0), + {'k1': torch.tensor(2.0), 'k2': (torch.tensor(3.0), torch.tensor(4.0))}, + torch.tensor(5.0), +) + +tensor_a = torch.randn(20) +tensor_b = torch.randn(20) + + +def test_tree_flatten_as_tuple() -> None: + expected_leaves, expected_treespec = (tensor_a,), pytree.tree_structure(tensor_a) + actual_leaves, actual_treespec = pytree.tree_flatten_as_tuple(tensor_a) + assert actual_leaves == expected_leaves + assert actual_treespec == expected_treespec + + leaves_a, treespec_a = pytree.tree_flatten(tree_a) + expected_leaves, expected_treespec = tuple(leaves_a), treespec_a + actual_leaves, actual_treespec = pytree.tree_flatten_as_tuple(tree_a) + assert actual_leaves == expected_leaves + assert actual_treespec == expected_treespec + + +def test_tree_pos() -> None: + expected = +tensor_a + actual = pytree.tree_pos(tensor_a) + helpers.assert_pytree_all_close(actual, expected) + + expected = (+tree_a[0], +tree_a[1]) + actual = pytree.tree_pos(tree_a) + helpers.assert_pytree_all_close(actual, expected) + + +def test_tree_neg() -> None: + expected = -tensor_a + actual = pytree.tree_neg(tensor_a) + helpers.assert_pytree_all_close(actual, expected) + + expected = (-tree_a[0], -tree_a[1]) + actual = pytree.tree_neg(tree_a) + helpers.assert_pytree_all_close(actual, expected) + + +def test_tree_add() -> None: + expected = tensor_a + tensor_b + actual = pytree.tree_add(tensor_a, tensor_b) + helpers.assert_pytree_all_close(actual, expected) + + expected = (tree_a[0] + tree_b[0], tree_a[1] + tree_b[1]) + actual = pytree.tree_add(tree_a, tree_b) + helpers.assert_pytree_all_close(actual, expected) + + +def test_tree_add_scalar_mul() -> None: + expected = (tree_a[0] + tree_b[0], tree_a[1] + tree_b[1]) + actual = pytree.tree_add_scalar_mul(tree_a, tree_b) + helpers.assert_pytree_all_close(actual, expected) + + expected = (tree_a[0] + 0.5 * tree_b[0], tree_a[1] + 0.5 * tree_b[1]) + actual = pytree.tree_add_scalar_mul(tree_a, tree_b, 0.5) + helpers.assert_pytree_all_close(actual, expected) + + +def test_tree_sub() -> None: + expected = tensor_a - tensor_b + actual = pytree.tree_sub(tensor_a, tensor_b) + helpers.assert_pytree_all_close(actual, expected) + + expected = (tree_a[0] - tree_b[0], tree_a[1] - tree_b[1]) + actual = pytree.tree_sub(tree_a, tree_b) + helpers.assert_pytree_all_close(actual, expected) + + +def test_tree_sub_scalar_mul() -> None: + expected = (tree_a[0] - tree_b[0], tree_a[1] - tree_b[1]) + actual = pytree.tree_sub_scalar_mul(tree_a, tree_b) + helpers.assert_pytree_all_close(actual, expected) + + expected = (tree_a[0] - 0.5 * tree_b[0], tree_a[1] - 0.5 * tree_b[1]) + actual = pytree.tree_sub_scalar_mul(tree_a, tree_b, 0.5) + helpers.assert_pytree_all_close(actual, expected) + + +def test_tree_mul() -> None: + expected = tensor_a * tensor_b + actual = pytree.tree_mul(tensor_a, tensor_b) + helpers.assert_pytree_all_close(actual, expected) + + expected = (tree_a[0] * tree_b[0], tree_a[1] * tree_b[1]) + actual = pytree.tree_mul(tree_a, tree_b) + helpers.assert_pytree_all_close(actual, expected) + + +def test_tree_matmul() -> None: + tree_a = (torch.randn(20, 10), torch.randn(20, 1)) + tree_b = (torch.randn(10, 20), torch.randn(1, 20)) + tensor_a = torch.randn(10, 20) + tensor_b = torch.randn(20) + expected = tensor_a @ tensor_b + actual = pytree.tree_matmul(tensor_a, tensor_b) + helpers.assert_pytree_all_close(actual, expected) + + expected = (tree_a[0] @ tree_b[0], tree_a[1] @ tree_b[1]) + actual = pytree.tree_matmul(tree_a, tree_b) + helpers.assert_pytree_all_close(actual, expected) + + +def test_tree_scalar_mul() -> None: + expected = 0.5 * tensor_a + actual = pytree.tree_scalar_mul(0.5, tensor_a) + helpers.assert_pytree_all_close(actual, expected) + + expected = (0.5 * tree_a[0], 0.5 * tree_a[1]) + actual = pytree.tree_scalar_mul(0.5, tree_a) + helpers.assert_pytree_all_close(actual, expected) + + +def test_tree_truediv() -> None: + expected = (tree_a[0] / tree_b[0], tree_a[1] / tree_b[1]) + actual = pytree.tree_truediv(tree_a, tree_b) + helpers.assert_pytree_all_close(actual, expected) + + actual = pytree.tree_truediv(tree_a_dict, tree_b_dict) + expected = ( + torch.tensor(1.0), + {'k1': torch.tensor(0.5), 'k2': (torch.tensor(1.0 / 3.0), torch.tensor(0.25))}, + torch.tensor(0.2), + ) + helpers.assert_pytree_all_close(actual, expected) + + +def test_tree_vdot_real() -> None: + expected = torch.vdot(tensor_a, tensor_b).real + actual = torch.tensor(pytree.tree_vdot_real(tensor_a, tensor_b)) + helpers.assert_pytree_all_close(actual, expected) + + expected = ( + torch.vdot(tree_a[0].contiguous().view(-1), tree_b[0].contiguous().view(-1)) + + torch.vdot(tree_a[1].contiguous().view(-1), tree_b[1].contiguous().view(-1)) + ).real + actual = torch.tensor(pytree.tree_vdot_real(tree_a, tree_b)) + helpers.assert_all_close(actual, expected) + + tensor_a_complex = torch.randn(20, dtype=torch.cfloat) + tensor_b_complex = torch.randn(20, dtype=torch.cfloat) + expected = torch.vdot(tensor_a_complex, tensor_b_complex).real + actual = torch.tensor(pytree.tree_vdot_real(tensor_a_complex, tensor_b_complex)) + helpers.assert_pytree_all_close(actual, expected) + + tree_a_complex, tree_b_complex = pytree.tree_map( + lambda x: torch.randn(x.size(), dtype=torch.cfloat), + (tree_a, tree_b), + ) + expected = ( + torch.vdot(tree_a_complex[0].contiguous().view(-1), tree_b_complex[0].contiguous().view(-1)) + + torch.vdot( + tree_a_complex[1].contiguous().view(-1), + tree_b_complex[1].contiguous().view(-1), + ) + ).real + actual = torch.tensor(pytree.tree_vdot_real(tree_a_complex, tree_b_complex)) + helpers.assert_all_close(actual, expected) + + +@helpers.parametrize( + tree_name=[ + 'tree_a', + 'tree_b', + 'tree_a_dict', + 'tree_b_dict', + 'tensor_a', + 'tensor_b', + ], +) +def test_tree_wait(tree_name: str) -> None: + tree = globals()[tree_name] + + future_tree = pytree.tree_map(lambda x: torch.futures.Future(), tree) + new_future_tree = pytree.tree_map( + lambda fut: fut.then(lambda f: torch.square(f.wait()) + 1.0), + future_tree, + ) + pytree.tree_map_(lambda fut, x: fut.set_result(x), future_tree, tree) + + expected = pytree.tree_map(lambda x: torch.square(x) + 1.0, tree) + actual = pytree.tree_wait(new_future_tree) + assert all(fut.done() for fut in pytree.tree_leaves(new_future_tree)) + helpers.assert_pytree_all_close(actual, expected) diff --git a/tests/test_schedule.py b/tests/test_schedule.py index 971c0de4..e4c0ac0a 100644 --- a/tests/test_schedule.py +++ b/tests/test_schedule.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,9 +13,55 @@ # limitations under the License. # ============================================================================== +from __future__ import annotations + +from typing import Any, Callable + +import functorch import numpy as np +import torch +import torch.nn.functional as F +import helpers import torchopt +from torchopt.alias.utils import _set_use_chain_flat + + +@helpers.parametrize( + init_value=[1.0, 1e-1], + decay_rate=[1e-2, 1e-3], + transition_begin=[1, 5], + transition_steps=[10, 100], + staircase=[False, True], + end_value=[0.0, None, 8e-1], +) +def test_exponential_decay( + init_value: float, + decay_rate: float, + transition_begin: int, + transition_steps: int | None, + staircase: bool, + end_value: float | None, +) -> None: + schedule = torchopt.schedule.exponential_decay( + init_value=init_value, + decay_rate=decay_rate, + transition_steps=transition_steps, + transition_begin=transition_begin, + staircase=staircase, + end_value=end_value, + ) + if end_value is not None: + clip_fn = max if decay_rate < 1.0 else min + for i in range(transition_begin, transition_steps): + lr = schedule(i) + if staircase: + lr_gt = init_value * (decay_rate ** np.floor((i - transition_begin) / transition_steps)) + else: + lr_gt = init_value * (decay_rate ** ((i - transition_begin) / transition_steps)) + if end_value is not None: + lr_gt = clip_fn(lr_gt, end_value) + assert np.allclose(lr, lr_gt) def test_linear_schedule() -> None: @@ -35,3 +81,78 @@ def test_linear_schedule() -> None: lr = schedule(i) lr_gt = init_value - gap_value * (i - transition_begin) / transition_steps assert np.allclose(lr, lr_gt) + + +@helpers.parametrize( + dtype=[torch.float64, torch.float32], + lr=[1e-2, 1e-3], + total_iters=[helpers.NUM_UPDATES, helpers.NUM_UPDATES * 2], + optimizers=[ + (torchopt.sgd, torch.optim.SGD, {}), + (torchopt.adam, torch.optim.Adam, {}), + (torchopt.adamw, torch.optim.AdamW, {}), + (torchopt.adagrad, torch.optim.Adagrad, {'eps': 1e-8}), + (torchopt.rmsprop, torch.optim.RMSprop, {}), + ], + inplace=[True, False], + weight_decay=[0.0, 1e-2], + use_chain_flat=[True, False], +) +def test_lr_linear_schedule( + dtype: torch.dtype, + lr: float, + total_iters: int, + optimizers: tuple[Callable, torch.optim.Optimizer, dict[str, Any]], + inplace: bool, + weight_decay: float, + use_chain_flat: bool, +) -> None: + _set_use_chain_flat(use_chain_flat) + + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + torchopt_optimizer, torch_optimizer, optimizer_kwargs = optimizers + + fmodel, params, buffers = functorch.make_functional_with_buffers(model) + optim = torchopt_optimizer( + torchopt.schedule.linear_schedule( + init_value=lr, + end_value=0.1 * lr, + transition_steps=total_iters, + transition_begin=0, + ), + weight_decay=weight_decay, + **optimizer_kwargs, + ) + optim_state = optim.init(params) + optim_ref = torch_optimizer( + model_ref.parameters(), + lr, + weight_decay=weight_decay, + **optimizer_kwargs, + ) + torch_scheduler = torch.optim.lr_scheduler.LinearLR( + optim_ref, + start_factor=1.0, + end_factor=0.1, + total_iters=total_iters, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = fmodel(params, buffers, xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + grads = torch.autograd.grad(loss, params, allow_unused=True) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + params = torchopt.apply_updates(params, updates, inplace=inplace) + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + torch_scheduler.step() + + helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) diff --git a/tests/test_transform.py b/tests/test_transform.py new file mode 100644 index 00000000..0a7bd498 --- /dev/null +++ b/tests/test_transform.py @@ -0,0 +1,60 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import torch + +import torchopt + + +def test_nan_to_num() -> None: + fn = torchopt.nan_to_num(0.0, 1.0, -1.0) + nan = torch.tensor(torch.nan) + inf = torch.tensor(torch.inf) + ninf = torch.tensor(-torch.inf) + updated, _ = fn.update(nan, None, inplace=False) + assert torch.equal(updated, torch.tensor(0.0)) + assert updated is not nan + + updated, _ = fn.update(inf, None, inplace=False) + assert torch.equal(updated, torch.tensor(1.0)) + assert updated is not inf + + updated, _ = fn.update(ninf, None, inplace=False) + assert torch.equal(updated, torch.tensor(-1.0)) + assert updated is not ninf + + updated, _ = fn.update(nan, None, inplace=True) + assert torch.equal(updated, torch.tensor(0.0)) + assert updated is nan + + updated, _ = fn.update(inf, None, inplace=True) + assert torch.equal(updated, torch.tensor(1.0)) + assert updated is inf + + updated, _ = fn.update(ninf, None, inplace=True) + assert torch.equal(updated, torch.tensor(-1.0)) + assert updated is ninf + + +def test_masked() -> None: + fn = torchopt.nan_to_num(0.0, 1.0, -1.0) + nan = torch.tensor(torch.nan) + updates = [nan, nan, nan] + + masked_fn = torchopt.transform.masked(fn, [True, False, True]) + state = masked_fn.init(updates) + + updates, _ = masked_fn.update(updates, state) + assert nan is updates[1] diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..57c35e47 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,142 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import operator + +import torch + +import torchopt +from torchopt import pytree + + +def test_stop_gradient() -> None: + x = torch.tensor(1.0, requires_grad=True) + y = 2 * x + assert y.grad_fn is not None + torchopt.stop_gradient(y) + assert y.grad_fn is None + fc = torch.nn.Linear(1, 1, False) + fc._parameters['weight'] = fc.weight * 2 + assert fc.weight.grad_fn is not None + torchopt.stop_gradient(fc) + assert fc.weight.grad_fn is None + + +def test_module_clone() -> None: + x = torch.tensor(1.0, requires_grad=True) + y = 2 * x + assert y.grad_fn is not None + z = torchopt.module_clone(y, by='reference') + assert z is y + z = torchopt.module_clone(x, by='copy') + assert z is not x + assert z.grad_fn.next_functions[0][0].variable is x + + z = torchopt.module_clone(y, by='deepcopy') + assert z is not y + assert z.grad_fn is None + assert torch.equal(z, y) + + x = torch.tensor(1.0, requires_grad=True) + y = torchopt.module_clone(x, by='reference', device='meta') + assert y.grad_fn.next_functions[0][0].variable is x + assert y.is_meta + + y = torchopt.module_clone(x, by='copy', device='meta') + assert y is not x + assert y.grad_fn.next_functions[0][0].next_functions[0][0].variable is x + assert y.is_meta + + y = torchopt.module_clone(x, by='deepcopy', device='meta') + assert y is not x + assert y.grad_fn is None + assert y.is_meta + + if torch.cuda.is_available(): + x = torch.tensor(1.0, requires_grad=True) + y = torchopt.module_clone(x, by='reference', device='cuda') + assert y.grad_fn.next_functions[0][0].variable is x + assert y.is_cuda + + y = torchopt.module_clone(x, by='copy', device='cuda') + assert y is not x + assert y.grad_fn.next_functions[0][0].next_functions[0][0].variable is x + assert y.is_cuda + + y = torchopt.module_clone(x, by='deepcopy', device='cuda') + assert y is not x + assert y.grad_fn is None + assert torch.equal(y.to(x.device), x) + assert y.is_cuda + + +def test_extract_state_dict(): # noqa: C901 + fc = torch.nn.Linear(1, 1) + state_dict = torchopt.extract_state_dict(fc, by='reference', device=torch.device('meta')) + for param_dict in state_dict.params: + for k, v in param_dict.items(): + assert v.is_meta + assert v.grad_fn.next_functions[0][0].variable is fc._parameters[k] + + state_dict = torchopt.extract_state_dict(fc, by='copy', device=torch.device('meta')) + for param_dict in state_dict.params: + for k, v in param_dict.items(): + assert v.is_meta + assert v.grad_fn.next_functions[0][0].next_functions[0][0].variable is fc._parameters[k] + + state_dict = torchopt.extract_state_dict(fc, by='deepcopy', device=torch.device('meta')) + for param_dict in state_dict.params: + for v in param_dict.values(): + assert v.is_meta + assert v.grad_fn is None + + state_dict = torchopt.extract_state_dict(fc, by='reference') + for param_dict in state_dict.params: + for k, v in param_dict.items(): + assert v is fc._parameters[k] + + state_dict = torchopt.extract_state_dict(fc, by='copy') + for param_dict in state_dict.params: + for k, v in param_dict.items(): + assert torch.equal(v, fc._parameters[k]) + assert v.grad_fn.next_functions[0][0].variable is fc._parameters[k] + + state_dict = torchopt.extract_state_dict(fc, by='deepcopy') + for param_dict in state_dict.params: + for k, v in param_dict.items(): + assert torch.equal(v, fc._parameters[k]) + assert v.grad_fn is None + + optim = torchopt.MetaAdam(fc, 1.0) + loss = fc(torch.ones(1, 1)).sum() + optim.step(loss) + state_dict = torchopt.extract_state_dict(optim) + same = pytree.tree_map(operator.is_, state_dict, tuple(optim.state_groups)) + assert all(pytree.tree_flatten(same)[0]) + + +def test_stop_gradient_for_state_dict() -> None: + fc = torch.nn.Linear(1, 1) + + state_dict = torchopt.extract_state_dict(fc, by='copy') + for param_dict in state_dict.params: + for k, v in param_dict.items(): + assert v.grad_fn.next_functions[0][0].variable is fc._parameters[k] + + torchopt.stop_gradient(state_dict) + for param_dict in state_dict.params: + for k, v in param_dict.items(): + assert v.grad_fn is None + assert torch.equal(v, fc._parameters[k]) diff --git a/tests/test_zero_order.py b/tests/test_zero_order.py new file mode 100644 index 00000000..65642559 --- /dev/null +++ b/tests/test_zero_order.py @@ -0,0 +1,172 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import functorch +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.types + +import helpers +import torchopt + + +BATCH_SIZE = 8 +NUM_UPDATES = 5 + + +class FcNet(nn.Module): + def __init__(self, dim, out): + super().__init__() + self.fc = nn.Linear(in_features=dim, out_features=out, bias=True) + + def forward(self, x): + return self.fc(x) + + +@helpers.parametrize( + lr=[1e-2, 1e-3], + method=['naive', 'forward', 'antithetic'], + sigma=[0.01, 0.1, 1], +) +def test_zero_order(lr: float, method: str, sigma: float) -> None: + helpers.seed_everything(42) + input_size = 32 + output_size = 1 + batch_size = BATCH_SIZE + coef = 0.1 + num_iterations = NUM_UPDATES + num_samples = 500 + + model = FcNet(input_size, output_size) + + fmodel, params = functorch.make_functional(model) + x = torch.randn(batch_size, input_size) * coef + y = torch.randn(batch_size, 1) * coef + distribution = torch.distributions.Normal(loc=0, scale=1) + + @torchopt.diff.zero_order( + distribution=distribution, + method=method, + argnums=0, + sigma=sigma, + num_samples=num_samples, + ) + def forward_process(params, fn, x, y): + y_pred = fn(params, x) + return F.mse_loss(y_pred, y) + + optimizer = torchopt.adam(lr=lr) + opt_state = optimizer.init(params) # init optimizer + + for _ in range(num_iterations): + loss = forward_process(params, fmodel, x, y) # compute loss + + grads = torch.autograd.grad(loss, params) # compute gradients + updates, opt_state = optimizer.update(grads, opt_state) # get updates + params = torchopt.apply_updates(params, updates) # update network parameters + + +@helpers.parametrize( + lr=[1e-2, 1e-3], + method=['naive', 'forward', 'antithetic'], + sigma=[0.01, 0.1, 1], +) +def test_zero_order_module(lr: float, method: str, sigma: float) -> None: + helpers.seed_everything(42) + input_size = 32 + output_size = 1 + batch_size = BATCH_SIZE + coef = 0.1 + num_iterations = NUM_UPDATES + num_samples = 500 + + class FcNetWithLoss( + torchopt.nn.ZeroOrderGradientModule, + method=method, + sigma=sigma, + num_samples=num_samples, + ): + def __init__(self, dim, out): + super().__init__() + self.net = FcNet(dim, out) + self.loss = nn.MSELoss() + self.distribution = torch.distributions.Normal(loc=0, scale=1) + + def forward(self, x, y): + return self.loss(self.net(x), y) + + def sample(self, sample_shape=torch.Size()): # noqa: B008 + return self.distribution.sample(sample_shape) + + x = torch.randn(batch_size, input_size) * coef + y = torch.randn(batch_size, 1) * coef + model_with_loss = FcNetWithLoss(input_size, output_size) + + optimizer = torchopt.Adam(model_with_loss.parameters(), lr=lr) + + for _ in range(num_iterations): + loss = model_with_loss(x, y) # compute loss + + optimizer.zero_grad() + loss.backward() # compute gradients + optimizer.step() # update network parameters + + +def test_module_enable_zero_order_gradients_twice() -> None: + class MyModule(torchopt.nn.ZeroOrderGradientModule): + def forward(self): + return torch.tensor(0.0) + + def sample(self, sample_shape): + return torch.tensor(0.0) + + from torchopt.diff.zero_order.nn.module import enable_zero_order_gradients + + with pytest.raises( + TypeError, + match='Zero-order gradient estimation is already enabled for the `forward` method.', + ): + enable_zero_order_gradients(MyModule) + + +def test_module_empty_parameters() -> None: + class MyModule(torchopt.nn.ZeroOrderGradientModule): + def forward(self): + return torch.tensor(0.0) + + def sample(self, sample_shape): + return torch.tensor(0.0) + + m = MyModule() + with pytest.raises(RuntimeError, match='The module has no parameters.'): + m() + + +def test_module_abstract_methods() -> None: + class MyModule1(torchopt.nn.ZeroOrderGradientModule): + def forward(self): + return torch.tensor(0.0) + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + MyModule1() + + class MyModule2(torchopt.nn.ZeroOrderGradientModule): + def sample(self, sample_shape): + return torch.tensor(0.0) + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + MyModule2() diff --git a/torchopt/_C/adam_op.pyi b/torchopt/_C/adam_op.pyi index 7b98a576..5ef572aa 100644 --- a/torchopt/_C/adam_op.pyi +++ b/torchopt/_C/adam_op.pyi @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,9 +13,7 @@ # limitations under the License. # ============================================================================== -# isort: off - -from typing import Tuple +# pylint: disable=all import torch @@ -28,10 +26,10 @@ def forward_( eps: float, eps_root: float, count: int, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... -def forwardMu(updates: torch.Tensor, mu: torch.Tensor, b1: float) -> torch.Tensor: ... -def forwardNu(updates: torch.Tensor, nu: torch.Tensor, b2: float) -> torch.Tensor: ... -def forwardUpdates( +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... +def forward_mu(updates: torch.Tensor, mu: torch.Tensor, b1: float) -> torch.Tensor: ... +def forward_nu(updates: torch.Tensor, nu: torch.Tensor, b2: float) -> torch.Tensor: ... +def forward_updates( new_mu: torch.Tensor, new_nu: torch.Tensor, b1: float, @@ -40,18 +38,25 @@ def forwardUpdates( eps_root: float, count: int, ) -> torch.Tensor: ... -def backwardMu( - dmu: torch.Tensor, updates: torch.Tensor, mu: torch.Tensor, b1: float -) -> Tuple[torch.Tensor, torch.Tensor]: ... -def backwardNu( - dnu: torch.Tensor, updates: torch.Tensor, nu: torch.Tensor, b2: float -) -> Tuple[torch.Tensor, torch.Tensor]: ... -def backwardUpdates( +def backward_mu( + dmu: torch.Tensor, + updates: torch.Tensor, + mu: torch.Tensor, + b1: float, +) -> tuple[torch.Tensor, torch.Tensor]: ... +def backward_nu( + dnu: torch.Tensor, + updates: torch.Tensor, + nu: torch.Tensor, + b2: float, +) -> tuple[torch.Tensor, torch.Tensor]: ... +def backward_updates( dupdates: torch.Tensor, updates: torch.Tensor, new_mu: torch.Tensor, new_nu: torch.Tensor, b1: float, b2: float, + eps_root: float, count: int, -) -> Tuple[torch.Tensor, torch.Tensor]: ... +) -> tuple[torch.Tensor, torch.Tensor]: ... diff --git a/torchopt/__init__.py b/torchopt/__init__.py index ab7a5a4d..830072e3 100644 --- a/torchopt/__init__.py +++ b/torchopt/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,51 +14,117 @@ # ============================================================================== """TorchOpt: a high-performance optimizer library built upon PyTorch.""" -from torchopt._src import accelerated_op_available, clip, combine, hook, schedule, visual -from torchopt._src.alias import adam, adamw, rmsprop, sgd -from torchopt._src.clip import clip_grad_norm -from torchopt._src.combine import chain -from torchopt._src.optimizer import SGD, Adam, AdamW, Optimizer, RMSProp, RMSprop, meta -from torchopt._src.optimizer.meta import ( +from torchopt import ( + accelerated_op, + alias, + base, + clip, + combine, + diff, + distributed, + hook, + linalg, + linear_solve, + nn, + optim, + pytree, + schedule, + typing, + visual, +) +from torchopt.accelerated_op import is_available as accelerated_op_available +from torchopt.alias import adadelta, adagrad, adam, adamax, adamw, radam, rmsprop, sgd +from torchopt.clip import clip_grad_norm +from torchopt.combine import chain +from torchopt.hook import register_hook +from torchopt.optim import ( + SGD, + AdaDelta, + Adadelta, + AdaGrad, + Adagrad, + Adam, + AdaMax, + Adamax, + AdamW, + Optimizer, + RAdam, + RMSProp, + RMSprop, +) +from torchopt.optim.func import FuncOptimizer +from torchopt.optim.meta import ( + MetaAdaDelta, + MetaAdadelta, + MetaAdaGrad, + MetaAdagrad, MetaAdam, + MetaAdaMax, + MetaAdamax, MetaAdamW, MetaOptimizer, + MetaRAdam, MetaRMSProp, MetaRMSprop, MetaSGD, ) -from torchopt._src.update import apply_updates -from torchopt._src.utils import extract_state_dict, recover_state_dict, stop_gradient +from torchopt.transform import nan_to_num +from torchopt.update import apply_updates +from torchopt.utils import ( + extract_state_dict, + module_clone, + module_detach_, + recover_state_dict, + stop_gradient, +) from torchopt.version import __version__ __all__ = [ - 'accelerated_op_available', - 'clip', - 'combine', - 'hook', - 'schedule', - 'visual', - 'adam', - 'adamw', - 'rmsprop', - 'sgd', - 'clip_grad_norm', - 'chain', - 'Optimizer', 'SGD', + 'AdaDelta', + 'AdaGrad', + 'AdaMax', + 'Adadelta', + 'Adagrad', 'Adam', 'AdamW', - 'RMSProp', - 'RMSprop', - 'MetaOptimizer', - 'MetaSGD', + 'Adamax', + 'FuncOptimizer', + 'MetaAdaDelta', + 'MetaAdaGrad', + 'MetaAdaMax', + 'MetaAdadelta', + 'MetaAdagrad', 'MetaAdam', 'MetaAdamW', + 'MetaAdamax', + 'MetaOptimizer', + 'MetaRAdam', 'MetaRMSProp', 'MetaRMSprop', + 'MetaSGD', + 'Optimizer', + 'RAdam', + 'RMSProp', + 'RMSprop', + 'accelerated_op_available', + 'adadelta', + 'adagrad', + 'adam', + 'adamax', + 'adamw', 'apply_updates', + 'chain', + 'clip_grad_norm', 'extract_state_dict', + 'module_clone', + 'module_detach_', + 'nan_to_num', + 'radam', 'recover_state_dict', + 'register_hook', + 'rmsprop', + 'sgd', 'stop_gradient', ] diff --git a/torchopt/_src/alias.py b/torchopt/_src/alias.py deleted file mode 100644 index 40b2e92d..00000000 --- a/torchopt/_src/alias.py +++ /dev/null @@ -1,468 +0,0 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -# This file is modified from: -# https://github.com/deepmind/optax/blob/master/optax/_src/alias.py -# ============================================================================== -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -# pylint: disable=invalid-name - -from typing import Any, Callable, Optional, Tuple, Union - -from torchopt._src import base, combine, transform -from torchopt._src.typing import ScalarOrSchedule - - -def _flip_sign_and_weight_decay(weight_decay: float = 0.0, maximize=False): - if not 0.0 <= weight_decay: # pylint: disable=unneeded-not - raise ValueError(f'Invalid weight_decay value: {weight_decay}') - - if not maximize and weight_decay == 0.0: - return base.identity() - - def init_fn(params): # pylint: disable=unused-argument - return base.EmptyState() - - if not maximize: # gradient descent - - def update_fn(updates, state, *, params=None, inplace=True): - assert params is not None, ( - 'Parameters are required for weight decay. ' - 'Call `update(updates, state, params=params)` instead.' - ) - - if inplace: - - def f(g, p): - if g is not None: - if g.requires_grad: - return g.add_(p, alpha=weight_decay) - return g.add_(p.data, alpha=weight_decay) - return None - - else: - - def f(g, p): - return g.add(p, alpha=weight_decay) if g is not None else None - - updates = transform.map_flattened(f, updates, params) - return updates, state - - else: # gradient ascent - - if weight_decay == 0.0: - # pylint: disable-next=unused-argument - def update_fn(updates, state, *, params=None, inplace=True): - if inplace: - - def f(g): - return g.neg_() if g is not None else None - - else: - - def f(g): - return g.neg() if g is not None else None - - updates = transform.map_flattened(f, updates) - return updates, state - - else: - - def update_fn(updates, state, *, params=None, inplace=True): - assert params is not None, ( - 'Parameters are required for weight decay. ' - 'Call `update(updates, state, params=params)` instead.' - ) - - if inplace: - - def f(g, p): - if g is not None: - if g.requires_grad: - return g.neg_().add_(p, alpha=weight_decay) - return g.neg_().add_(p.data, alpha=weight_decay) - return None - - else: - - def f(g, p): - return g.neg().add_(p, alpha=weight_decay) if g is not None else None - - updates = transform.map_flattened(f, updates, params) - return updates, state - - return base.GradientTransformation(init_fn, update_fn) - - -def _scale_by_neg_lr(lr: ScalarOrSchedule): - if not (callable(lr) or 0.0 <= lr): - raise ValueError(f'Invalid learning rate: {lr}') - - if callable(lr): - - def schedule_wrapper(count): - def f(scaled_lr): - return -scaled_lr - - return transform.map_flattened(f, lr(count)) # type: ignore[operator] - - return transform._scale_by_schedule( # pylint: disable=protected-access - schedule_wrapper, already_flattened=True - ) - return transform._scale(-lr, already_flattened=True) # pylint: disable=protected-access - - -# pylint: disable-next=too-many-arguments -def adam( - lr: ScalarOrSchedule = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), - eps: float = 1e-8, - weight_decay: float = 0.0, - *, - eps_root: float = 0.0, - moment_requires_grad: bool = False, - maximize: bool = False, - use_accelerated_op: bool = False, -) -> base.GradientTransformation: - """The functional Adam optimizer. - - Adam is an SGD variant with learning rate adaptation. The *learning rate* used for each weight - is computed from estimates of first- and second-order moments of the gradients (using suitable - exponential moving averages). - - References: - - Kingma et al, 2014: https://arxiv.org/abs/1412.6980 - - Args: - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to avoid - dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - moment_requires_grad: (default: :data:`False`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. - - Returns: - The corresponding :class:`GradientTransformation` instance. - """ - b1, b2 = betas - # pylint: disable=unneeded-not - if not (callable(lr) or 0.0 <= lr): - raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= b1 < 1.0: - raise ValueError(f'Invalid beta parameter at index 0: {b1}') - if not 0.0 <= b2 < 1.0: - raise ValueError(f'Invalid beta parameter at index 1: {b2}') - if not 0.0 <= weight_decay: - raise ValueError(f'Invalid weight_decay value: {weight_decay}') - # pylint: enable=unneeded-not - - if use_accelerated_op: - adam_scaler = transform._scale_by_accelerated_adam # pylint: disable=protected-access - else: - adam_scaler = transform._scale_by_adam # pylint: disable=protected-access - - return transform.with_flattened_tree( - combine.chain( - _flip_sign_and_weight_decay(weight_decay=weight_decay, maximize=maximize), - adam_scaler( - b1=b1, - b2=b2, - eps=eps, - eps_root=eps_root, - moment_requires_grad=moment_requires_grad, - already_flattened=True, - ), - _scale_by_neg_lr(lr), - ) - ) - - -# pylint: disable-next=too-many-arguments -def adamw( - lr: ScalarOrSchedule = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), - eps: float = 1e-8, - weight_decay: float = 1e-2, - *, - eps_root: float = 0.0, - mask: Optional[Union[Any, Callable[['base.Params'], Any]]] = None, - moment_requires_grad: bool = False, - maximize: bool = False, - use_accelerated_op: bool = False, -) -> base.GradientTransformation: - """Adam with weight decay regularization. - - AdamW uses weight decay to regularize learning towards small weights, as - this leads to better generalization. In SGD you can also use L2 regularization - to implement this as an additive loss term, however L2 regularization - does not behave as intended for adaptive gradient algorithms such as Adam. - - References: - - Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101 - - Args: - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`1e-2`) - Strength of the weight decay regularization. Note that this weight decay is multiplied - with the learning rate. This is consistent with other frameworks such as PyTorch, but - different from (Loshchilov et al, 2019) where the weight decay is only multiplied with - the "schedule multiplier", but not the base learning rate. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to avoid - dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - mask: (default: :data:`None`) - A tree with same structure as (or a prefix of) the params PyTree, or a Callable that - returns such a pytree given the params/updates. The leaves should be booleans, - :data:`True` for leaves/subtrees you want to apply the weight decay to, and - :data:`False` for those you want to skip. Note that the Adam gradient - transformations are applied to all parameters. - moment_requires_grad: (default: :data:`False`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. - - Returns: - The corresponding :class:`GradientTransformation` instance. - """ - b1, b2 = betas - # pylint: disable=unneeded-not - if not (callable(lr) or 0.0 <= lr): - raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= b1 < 1.0: - raise ValueError(f'Invalid beta parameter at index 0: {b1}') - if not 0.0 <= b2 < 1.0: - raise ValueError(f'Invalid beta parameter at index 1: {b2}') - if not 0.0 <= weight_decay: - raise ValueError(f'Invalid weight_decay value: {weight_decay}') - # pylint: enable=unneeded-not - - if use_accelerated_op: - adam_scaler = transform._scale_by_accelerated_adam # pylint: disable=protected-access - else: - adam_scaler = transform._scale_by_adam # pylint: disable=protected-access - - return transform.with_flattened_tree( - combine.chain( - _flip_sign_and_weight_decay(weight_decay=0.0, maximize=maximize), - adam_scaler( - b1=b1, - b2=b2, - eps=eps, - eps_root=eps_root, - moment_requires_grad=moment_requires_grad, - already_flattened=True, - ), - transform._add_decayed_weights( # pylint: disable=protected-access - weight_decay=weight_decay, - mask=mask, - already_flattened=True, - ), - _scale_by_neg_lr(lr), - ) - ) - - -# pylint: disable-next=too-many-arguments -def rmsprop( - lr: ScalarOrSchedule = 1e-2, - alpha: float = 0.9, - eps: float = 1e-8, - weight_decay: float = 0.0, - momentum: float = 0.0, - centered: bool = False, - *, - initial_scale: float = 0.0, - nesterov: bool = False, - maximize: bool = False, -) -> base.GradientTransformation: - """The functional version of the RMSProp optimizer. - - RMSProp is an SGD variant with learning rate adaptation. The *learning rate* used for each - weight is scaled by a suitable estimate of the magnitude of the gradients on previous steps. - Several variants of RMSProp can be found in the literature. This alias provides an easy to - configure RMSProp optimizer that can be used to switch between several of these variants. - - References: - - Tieleman and Hinton, 2012: http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf - - Graves, 2013: https://arxiv.org/abs/1308.0850 - - Args: - lr: (default: :const:`1e-2`) - This is a fixed global scaling factor. - alpha: (default: :const:`0.99`) - Smoothing constant, the decay used to track the magnitude of previous gradients. - eps: (default: :const:`1e-8`) - A small numerical constant to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - centered: (default: :data:`False`) - If :data:`True`, use the variance of the past gradients to rescale the latest - gradients. - initial_scale: (default: :data:`0.0`) - Initialization of accumulators tracking the magnitude of previous updates. PyTorch - uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When reproducing results from a - paper, verify the value used by the authors. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - - Returns: - The corresponding :class:`GradientTransformation` instance. - """ - # pylint: disable=unneeded-not - if not (callable(lr) or 0.0 <= lr): - raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= alpha: - raise ValueError(f'Invalid alpha value: {alpha}') - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= momentum: - raise ValueError(f'Invalid momentum value: {momentum}') - if not 0.0 <= weight_decay: - raise ValueError(f'Invalid weight_decay value: {weight_decay}') - # pylint: enable=unneeded-not - - if centered: - rmsprop_scaler = transform._scale_by_stddev # pylint: disable=protected-access - else: - rmsprop_scaler = transform._scale_by_rms # pylint: disable=protected-access - - return transform.with_flattened_tree( - combine.chain( - _flip_sign_and_weight_decay(weight_decay=weight_decay, maximize=maximize), - rmsprop_scaler( - alpha=alpha, - eps=eps, - initial_scale=initial_scale, - already_flattened=True, - ), - transform._trace( # pylint: disable=protected-access - momentum=momentum, - nesterov=nesterov, - already_flattened=True, - ), - _scale_by_neg_lr(lr), - ) - ) - - -def sgd( - lr: ScalarOrSchedule, - momentum: float = 0.0, - dampening: float = 0.0, - weight_decay: float = 0.0, - nesterov: bool = False, - *, - moment_requires_grad: bool = False, - maximize: bool = False, -) -> base.GradientTransformation: - """The functional version of the canonical Stochastic Gradient Descent optimizer. - - This implements stochastic gradient descent. It also includes support for momentum, and nesterov - acceleration, as these are standard practice when using stochastic gradient descent to train - deep neural networks. - - References: - - Sutskever et al, 2013: http://proceedings.mlr.press/v28/sutskever13.pdf - - Args: - lr: This is a fixed global scaling factor. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - dampening: (default: :const:`0.0`) - Dampening for momentum. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - moment_requires_grad: (default: :data:`False`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - - Returns: - The corresponding :class:`GradientTransformation` instance. - """ - # pylint: disable=unneeded-not - if not (callable(lr) or 0.0 <= lr): - raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= momentum: - raise ValueError(f'Invalid momentum value: {momentum}') - if not 0.0 <= weight_decay: - raise ValueError(f'Invalid weight_decay value: {weight_decay}') - if nesterov and (momentum <= 0.0 or dampening != 0.0): - raise ValueError('Nesterov momentum requires a momentum and zero dampening') - # pylint: enable=unneeded-not - - return transform.with_flattened_tree( - combine.chain( - _flip_sign_and_weight_decay(weight_decay=weight_decay, maximize=maximize), - transform._trace( # pylint: disable=protected-access - momentum=momentum, - dampening=dampening, - nesterov=nesterov, - moment_requires_grad=moment_requires_grad, - already_flattened=True, - ), - _scale_by_neg_lr(lr), - ) - ) diff --git a/torchopt/_src/clip.py b/torchopt/_src/clip.py deleted file mode 100644 index 31d54797..00000000 --- a/torchopt/_src/clip.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -# This file is modified from: -# https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/clip_grad.py -# ============================================================================== - -import torch -from torch._six import inf - -from torchopt._src import base -from torchopt._src.utils import pytree - - -ClipState = base.EmptyState - - -def clip_grad_norm( - max_norm: float, norm_type: float = 2.0, error_if_nonfinite: bool = False -) -> base.GradientTransformation: - """Clips gradient norm of an iterable of parameters. - - Args: - max_delta: The maximum absolute value for each element in the update. - - Returns: - An ``(init_fn, update_fn)`` tuple. - """ - - def init_fn(params): # pylint: disable=unused-argument - return ClipState() - - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument - available_updates = [] - for g in updates: - if g is not None: - available_updates.append(g) - if len(available_updates) == 0: - return torch.tensor(0.0) - device = available_updates[0].device - with torch.no_grad(): - if norm_type == inf: - norms = [p.abs().max().to(device) for p in available_updates] - total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms)) - else: - total_norm = torch.norm( - torch.stack([torch.norm(p, norm_type).to(device) for p in available_updates]), - norm_type, - ) - if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): - raise RuntimeError( - f'The total norm of order {norm_type} for gradients from `parameters` is ' - f'non-finite, so it cannot be clipped. To disable this error and scale the ' - f'gradients by the non-finite norm anyway, set `error_if_nonfinite=False`' - ) - clip_coef = max_norm / (float(total_norm) + 1e-6) - # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but - # doing so avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device - # synchronization when the gradients do not reside in CPU memory. - clip_coef_clamped = min(clip_coef, 1.0) - if inplace: - - def f(g): - return g.mul_(clip_coef_clamped) if g is not None else None - - else: - - def f(g): - return g.mul(clip_coef_clamped) if g is not None else None - - new_updates = pytree.tree_map(f, updates) - return new_updates, state - - return base.GradientTransformation(init_fn, update_fn) diff --git a/torchopt/_src/hook.py b/torchopt/_src/hook.py deleted file mode 100644 index 305c34ca..00000000 --- a/torchopt/_src/hook.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import torch - -from torchopt._src.base import EmptyState, GradientTransformation -from torchopt._src.utils import pytree - - -def zero_nan_hook(g: torch.Tensor) -> torch.Tensor: - """Registers a zero nan hook to replace nan with zero.""" - return torch.where(torch.isnan(g), torch.zeros_like(g), g) - - -def register_hook(hook) -> GradientTransformation: - """Stateless identity transformation that leaves input gradients untouched. - - This function passes through the *gradient updates* unchanged. - - Returns: - An ``(init_fn, update_fn)`` tuple. - """ - - def init_fn(params): # pylint: disable=unused-argument - return EmptyState() - - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument - def f(g): - return g.register_hook(hook) if g is not None else None - - pytree.tree_map(f, updates) - return updates, state - - return GradientTransformation(init_fn, update_fn) diff --git a/torchopt/_src/optimizer/adam.py b/torchopt/_src/optimizer/adam.py deleted file mode 100644 index 6776408e..00000000 --- a/torchopt/_src/optimizer/adam.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from typing import Iterable, Tuple - -import torch - -from torchopt._src.alias import adam -from torchopt._src.optimizer.base import Optimizer -from torchopt._src.typing import ScalarOrSchedule - - -class Adam(Optimizer): - """The classic Adam optimizer. - - See Also: - - The functional Adam optimizer: :func:`torchopt.adam`. - - The differentiable meta-Adam optimizer: :class:`torchopt.MetaAdam`. - """ - - # pylint: disable-next=too-many-arguments - def __init__( - self, - params: Iterable[torch.Tensor], - lr: ScalarOrSchedule, - betas: Tuple[float, float] = (0.9, 0.999), - eps: float = 1e-8, - weight_decay: float = 0.0, - *, - eps_root: float = 0.0, - maximize: bool = False, - use_accelerated_op: bool = False, - ): - r"""The :meth:`init` function. - - Args: - params: (iterable of torch.Tensor) - An iterable of :class:`torch.Tensor`\s. Specifies what tensors should be optimized. - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to - avoid dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. - """ - super().__init__( - params, - adam( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - eps_root=eps_root, - moment_requires_grad=False, - maximize=maximize, - use_accelerated_op=use_accelerated_op, - ), - ) diff --git a/torchopt/_src/optimizer/adamw.py b/torchopt/_src/optimizer/adamw.py deleted file mode 100644 index 886cd77a..00000000 --- a/torchopt/_src/optimizer/adamw.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from typing import Any, Callable, Iterable, Optional, Tuple, Union - -import torch - -from torchopt._src import base # pylint: disable=unused-import -from torchopt._src.alias import adamw -from torchopt._src.optimizer.base import Optimizer -from torchopt._src.typing import ScalarOrSchedule - - -class AdamW(Optimizer): - """The classic AdamW optimizer. - - See Also: - - The functional AdamW optimizer: :func:`torchopt.adamw`. - - The differentiable meta-AdamW optimizer: :class:`torchopt.MetaAdamW`. - """ - - # pylint: disable-next=too-many-arguments - def __init__( - self, - params: Iterable[torch.Tensor], - lr: ScalarOrSchedule = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), - eps: float = 1e-8, - weight_decay: float = 1e-2, - *, - eps_root: float = 0.0, - mask: Optional[Union[Any, Callable[['base.Params'], Any]]] = None, - maximize: bool = False, - use_accelerated_op: bool = False, - ): - r"""The :meth:`init` function. - - Args: - params: (iterable of torch.Tensor) - An iterable of :class:`torch.Tensor`\s. Specifies what tensors should be optimized. - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`1e-2`) - Strength of the weight decay regularization. Note that this weight decay is - multiplied with the learning rate. This is consistent with other frameworks such as - PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only - multiplied with the "schedule multiplier", but not the base learning rate. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to - avoid dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - mask: (default: :data:`None`) - A tree with same structure as (or a prefix of) the params PyTree, or a Callable that - returns such a pytree given the params/updates. The leaves should be booleans, - :data:`True` for leaves/subtrees you want to apply the weight decay to, and - :data:`False` for those you want to skip. Note that the Adam gradient - transformations are applied to all parameters. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. - """ - super().__init__( - params, - adamw( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - eps_root=eps_root, - mask=mask, - moment_requires_grad=False, - maximize=maximize, - use_accelerated_op=use_accelerated_op, - ), - ) diff --git a/torchopt/_src/optimizer/base.py b/torchopt/_src/optimizer/base.py deleted file mode 100644 index 99e18b36..00000000 --- a/torchopt/_src/optimizer/base.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from typing import Iterable - -import torch - -from torchopt._src.base import GradientTransformation -from torchopt._src.update import apply_updates -from torchopt._src.utils import pytree - - -class Optimizer: - """A base class for classic optimizers that similar to :class:`torch.optim.Optimizer`.""" - - def __init__(self, params: Iterable[torch.Tensor], impl: GradientTransformation): - r"""The :meth:`init` function. - - Args: - params (iterable of torch.Tensor): An iterable of :class:`torch.Tensor`\s. Specifies - what tensors should be optimized. - impl (GradientTransformation): A low level optimizer function, it could be a optimizer - function provided by ``alias.py`` or a customized ``chain`` provided by - ``combine.py``. - Note that using ``Optimizer(sgd())`` or ``Optimizer(chain(sgd()))`` is equivalent to - :class:`torchopt.SGD`. - """ - self.impl = impl - self.param_groups = [] # type: ignore - self.param_tree_groups = [] # type: ignore - self.state_groups = [] # type: ignore - - if not isinstance(params, list): - params = list(params) - self.add_param_group(params) - - def zero_grad(self, set_to_none: bool = False): - r"""Sets the gradients of all optimized :class:`torch.Tensor`\s to zero. - - The behavior is similar to :meth:`torch.optim.Optimizer.zero_grad`. - - Args: - set_to_none (bool): Instead of setting to zero, set the ``grads`` to :data:`None`. - """ - for group in self.param_groups: - if set_to_none: - - def f(p): - p.grad = None - - else: - - def f(p): - if p.grad is None: - return - if p.grad.grad_fn is not None: - p.grad.detach_() - else: - p.grad.requires_grad_(False) - p.grad.zero_() - - pytree.tree_map(f, group) - - def state_dict(self): - """Returns the state of the optimizer.""" - return self.state_groups - - def load_state_dict(self, state_dict): - """Loads the optimizer state. - - Args: - state_dict (dict): Optimizer state. Should be an object returned from a call to - :meth:`state_dict`. - """ - self.state_groups = state_dict - - def step(self, closure=None): - """Performs a single optimization step. - - The behavior is similar to :meth:`torch.optim.Optimizer.step`. - - Args: - closure (callable, optional): A closure that reevaluates the model and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - def f(p): - return p.grad - - for i, (params, state) in enumerate(zip(self.param_groups, self.state_groups)): - grads = pytree.tree_map(f, params) - updates, new_state = self.impl.update(grads, state, params=params, inplace=True) - self.param_groups[i] = apply_updates(params, updates, inplace=True) - self.state_groups[i] = new_state - - return loss - - def add_param_group(self, params): - """Add a param group to the optimizer's :attr:`param_groups`.""" - params, params_tree = pytree.tree_flatten(params) - params = tuple(params) - self.param_groups.append(params) - self.param_tree_groups.append(params_tree) - self.state_groups.append(self.impl.init(params)) diff --git a/torchopt/_src/optimizer/meta/adam.py b/torchopt/_src/optimizer/meta/adam.py deleted file mode 100644 index 6b76f959..00000000 --- a/torchopt/_src/optimizer/meta/adam.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from typing import Tuple - -import torch.nn as nn - -from torchopt._src.alias import adam -from torchopt._src.optimizer.meta.base import MetaOptimizer -from torchopt._src.typing import ScalarOrSchedule - - -class MetaAdam(MetaOptimizer): - """The differentiable Adam optimizer. - - See Also: - - The functional Adam optimizer: :func:`torchopt.adam`. - - The classic Adam optimizer: :class:`torchopt.Adam`. - """ - - # pylint: disable-next=too-many-arguments - def __init__( - self, - net: nn.Module, - lr: ScalarOrSchedule = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), - eps: float = 1e-8, - weight_decay: float = 0.0, - *, - eps_root: float = 0.0, - moment_requires_grad: bool = True, - maximize: bool = False, - use_accelerated_op: bool = False, - ): - """The :meth:`init` function. - - Args: - net: (nn.Module) - A network whose parameters should be optimized. - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to - avoid dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - moment_requires_grad: (default: :data:`True`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. - """ - super().__init__( - net, - adam( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - eps_root=eps_root, - moment_requires_grad=moment_requires_grad, - maximize=maximize, - use_accelerated_op=use_accelerated_op, - ), - ) diff --git a/torchopt/_src/optimizer/meta/adamw.py b/torchopt/_src/optimizer/meta/adamw.py deleted file mode 100644 index c38f3c5c..00000000 --- a/torchopt/_src/optimizer/meta/adamw.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from typing import Any, Callable, Optional, Tuple, Union - -import torch.nn as nn - -from torchopt._src import base # pylint: disable=unused-import -from torchopt._src.alias import adamw -from torchopt._src.optimizer.meta.base import MetaOptimizer -from torchopt._src.typing import ScalarOrSchedule - - -class MetaAdamW(MetaOptimizer): - """The differentiable AdamW optimizer. - - See Also: - - The functional AdamW optimizer: :func:`torchopt.adamw`. - - The classic AdamW optimizer: :class:`torchopt.AdamW`. - """ - - # pylint: disable-next=too-many-arguments - def __init__( - self, - net: nn.Module, - lr: ScalarOrSchedule = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), - eps: float = 1e-8, - weight_decay: float = 1e-2, - *, - eps_root: float = 0.0, - mask: Optional[Union[Any, Callable[['base.Params'], Any]]] = None, - moment_requires_grad: bool = False, - maximize: bool = False, - use_accelerated_op: bool = False, - ): - """The :meth:`init` function. - - Args: - net: (nn.Module) - A network whose parameters should be optimized. - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`1e-2`) - Strength of the weight decay regularization. Note that this weight decay is - multiplied with the learning rate. This is consistent with other frameworks such as - PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only - multiplied with the "schedule multiplier", but not the base learning rate. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to - avoid dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - mask: (default: :data:`None`) - A tree with same structure as (or a prefix of) the params PyTree, or a Callable that - returns such a pytree given the params/updates. The leaves should be booleans, - :data:`True` for leaves/subtrees you want to apply the weight decay to, and - :data:`False` for those you want to skip. Note that the Adam gradient - transformations are applied to all parameters. - moment_requires_grad: (default: :data:`False`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. - """ - super().__init__( - net, - adamw( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - eps_root=eps_root, - mask=mask, - moment_requires_grad=moment_requires_grad, - maximize=maximize, - use_accelerated_op=use_accelerated_op, - ), - ) diff --git a/torchopt/_src/optimizer/meta/base.py b/torchopt/_src/optimizer/meta/base.py deleted file mode 100644 index eb5a70b1..00000000 --- a/torchopt/_src/optimizer/meta/base.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import torch -import torch.nn as nn - -from torchopt._src.base import GradientTransformation -from torchopt._src.update import apply_updates -from torchopt._src.utils import pytree - - -class MetaOptimizer: - """The base class for high-level differentiable optimizers.""" - - def __init__(self, net: nn.Module, impl: GradientTransformation): - """The :meth:`init` function. - - Args: - net: (nn.Module) - A network whose parameters should be optimized. - impl: (GradientTransformation) - A low level optimizer function, it could be a optimizer function provided by - ``alias.py`` or a customized ``chain`` provided by ``combine.py``. - Note that using ``MetaOptimizer(sgd(moment_requires_grad=True))`` or - ``MetaOptimizer(chain(sgd(moment_requires_grad=True)))`` is equivalent to - :class:`torchopt.MetaSGD`. - """ - self.impl = impl - self.param_containers_groups = [] # type: ignore - self.state_groups = [] # type: ignore - - self.add_param_group(net) - - def step(self, loss: torch.Tensor): - """Compute the gradients of the loss to the network parameters and update network parameters. - - Graph of the derivative will be constructed, allowing to compute higher order derivative - products. We use the differentiable optimizer (pass argument ``inplace=False``) to scale the - gradients and update the network parameters without modifying tensors in-place. - - Args: - loss: (torch.Tensor) - The loss that is used to compute the gradients to the network parameters. - """ # pylint: disable=line-too-long - # Step parameter only - for i, (param_container, new_state) in enumerate( - zip(self.param_containers_groups, self.state_groups) - ): - flattened_params, container_treedef = pytree.tree_flatten(param_container) - flattened_params = tuple(flattened_params) - grads = torch.autograd.grad( - loss, flattened_params, create_graph=True, allow_unused=True - ) - updates, new_state = self.impl.update( - grads, - new_state, - params=flattened_params, - inplace=False, - ) - self.state_groups[i] = new_state - flattened_new_params = apply_updates(flattened_params, updates, inplace=False) - new_params = pytree.tree_unflatten(container_treedef, flattened_new_params) - for container, new_param in zip(param_container, new_params): - container.update(new_param) - - def add_param_group(self, net): - """Add a param group to the optimizer's :attr:`state_groups`.""" - # pylint: disable-next=import-outside-toplevel,cyclic-import - from torchopt._src.utils import _extract_container - - net_container = _extract_container(net, with_buffer=False) - flattened_params = tuple(pytree.tree_leaves(net_container)) - optimizer_state = self.impl.init(flattened_params) - self.param_containers_groups.append(net_container) - self.state_groups.append(optimizer_state) - - def state_dict(self): - """Extract the references of the optimizer states. - - Note that the states are references, so any in-place operations will change the states - inside :class:`MetaOptimizer` at the same time. - """ - return tuple(self.state_groups) - - def load_state_dict(self, state_dict): - """Load the references of the optimizer states.""" - self.state_groups[:] = list(state_dict) diff --git a/torchopt/_src/optimizer/meta/rmsprop.py b/torchopt/_src/optimizer/meta/rmsprop.py deleted file mode 100644 index 20183236..00000000 --- a/torchopt/_src/optimizer/meta/rmsprop.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import torch.nn as nn - -from torchopt._src.alias import rmsprop -from torchopt._src.optimizer.meta.base import MetaOptimizer -from torchopt._src.typing import ScalarOrSchedule - - -class MetaRMSProp(MetaOptimizer): - """The differentiable RMSProp optimizer. - - See Also: - - The functional RMSProp optimizer: :func:`torchopt.rmsprop`. - - The classic RMSProp optimizer: :class:`torchopt.RMSProp`. - """ - - # pylint: disable-next=too-many-arguments - def __init__( - self, - net: nn.Module, - lr: ScalarOrSchedule = 1e-2, - alpha: float = 0.99, - eps: float = 1e-8, - weight_decay: float = 0.0, - momentum: float = 0.0, - centered: bool = False, - *, - initial_scale: float = 0.0, - nesterov: bool = False, - maximize: bool = False, - ): - """The :meth:`init` function. - - Args: - net: (nn.Module) - A network whose parameters should be optimized. - lr: (default: :const:`1e-2`) - This is a fixed global scaling factor. - alpha: (default: :const:`0.99`) - Smoothing constant, the decay used to track the magnitude of previous gradients. - eps: (default: :const:`1e-8`) - A small numerical constant to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - centered: (default: :data:`False`) - If :data:`True`, use the variance of the past gradients to rescale the latest - gradients. - initial_scale: (default: :data:`0.0`) - Initialization of accumulators tracking the magnitude of previous updates. PyTorch - uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When reproducing results from a - paper, verify the value used by the authors. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - """ - super().__init__( - net, - rmsprop( - lr=lr, - alpha=alpha, - eps=eps, - weight_decay=weight_decay, - momentum=momentum, - centered=centered, - initial_scale=initial_scale, - nesterov=nesterov, - maximize=maximize, - ), - ) - - -MetaRMSprop = MetaRMSProp # alias for PyTorch compatibility diff --git a/torchopt/_src/optimizer/rmsprop.py b/torchopt/_src/optimizer/rmsprop.py deleted file mode 100644 index 3b8634f3..00000000 --- a/torchopt/_src/optimizer/rmsprop.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from typing import Iterable - -import torch - -from torchopt._src.alias import rmsprop -from torchopt._src.optimizer.base import Optimizer -from torchopt._src.typing import ScalarOrSchedule - - -class RMSProp(Optimizer): - """The classic RMSProp optimizer. - - See Also: - - The functional RMSProp optimizer: :func:`torchopt.rmsprop`. - - The differentiable meta-RMSProp optimizer: :class:`torchopt.MetaRMSProp`. - """ - - # pylint: disable-next=too-many-arguments - def __init__( - self, - params: Iterable[torch.Tensor], - lr: ScalarOrSchedule = 1e-2, - alpha: float = 0.99, - eps: float = 1e-8, - weight_decay: float = 0.0, - momentum: float = 0.0, - centered: bool = False, - *, - initial_scale: float = 0.0, - nesterov: bool = False, - maximize: bool = False, - ): - r"""The `init` function. - - Args: - params: (iterable of torch.Tensor) - An iterable of :class:`torch.Tensor`\s. Specifies what Tensors should be optimized. - lr: (default: :const:`1e-2`) - This is a fixed global scaling factor. - alpha: (default: :const:`0.99`) - Smoothing constant, the decay used to track the magnitude of previous gradients. - eps: (default: :const:`1e-8`) - A small numerical constant to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - centered: (default: :data:`False`) - If :data:`True`, use the variance of the past gradients to rescale the latest - gradients. - initial_scale: (default: :data:`0.0`) - Initialization of accumulators tracking the magnitude of previous updates. PyTorch - uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When reproducing results from a - paper, verify the value used by the authors. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - """ - super().__init__( - params, - rmsprop( - lr=lr, - alpha=alpha, - eps=eps, - weight_decay=weight_decay, - momentum=momentum, - centered=centered, - initial_scale=initial_scale, - nesterov=nesterov, - maximize=maximize, - ), - ) - - -RMSprop = RMSProp # alias for PyTorch compatibility diff --git a/torchopt/_src/transform.py b/torchopt/_src/transform.py deleted file mode 100644 index 15bf11ed..00000000 --- a/torchopt/_src/transform.py +++ /dev/null @@ -1,897 +0,0 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -# This file is modified from: -# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py -# ============================================================================== -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -# pylint: disable=invalid-name - -from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Union - -import torch - -from torchopt._src import base -from torchopt._src.typing import Schedule -from torchopt._src.utils import pytree - - -ScaleState = base.EmptyState -INT32_MAX = torch.iinfo(torch.int32).max -TRIPLE_PYTREEDEF = pytree.tree_structure((0, 1, 2)) - - -def map_flattened(func: Callable, *args: Any) -> List[Any]: - """Apply a function to each element of a flattened list.""" - return list(map(func, *args)) - - -def with_flattened_tree(inner: base.GradientTransformation) -> base.GradientTransformation: - # pylint: disable-next=line-too-long - """Wraps around the inner transformation that manipulates the flattened tree structure (:class:``list``).""" - - def init_fn(params): - return inner.init(pytree.tree_leaves(params)) - - def update_fn(updates, state, *, params=None, inplace=True): - flattened_updates, treedef = pytree.tree_flatten(updates) - if params is not None: - params = pytree.tree_leaves(params) - - flattened_updates, state = inner.update( - flattened_updates, state, params=params, inplace=inplace - ) - updates = pytree.tree_unflatten(treedef, flattened_updates) - - return updates, state - - return base.GradientTransformation(init_fn, update_fn) - - -def inc_count(updates: base.Updates, count: Sequence[torch.Tensor]) -> Sequence[torch.Tensor]: - """Increments int counter by one. - - Returns: - A counter incremeted by one, or max_int if the maximum precision is reached. - """ - return _inc_count(updates=updates, count=count, already_flattened=False) - - -def _inc_count( - updates: base.Updates, count: Sequence[torch.Tensor], *, already_flattened: bool = False -) -> Sequence[torch.Tensor]: - def f(c, g): - return c + (c != INT32_MAX).to(torch.int32) if g is not None else c - - if already_flattened: - return map_flattened(f, count, updates) - return pytree.tree_map(f, count, updates) - - -def scale(step_size: float) -> base.GradientTransformation: - """Scale updates by some fixed scalar ``step_size``. - - Args: - step_size: A scalar corresponding to a fixed scaling factor for updates. - - Returns: - An ``(init_fn, update_fn)`` tuple. - """ - return _scale(step_size=step_size, already_flattened=False) - - -def _scale(step_size: float, *, already_flattened: bool = False) -> base.GradientTransformation: - if already_flattened: - tree_map = map_flattened - else: - tree_map = pytree.tree_map - - def init_fn(params): # pylint: disable=unused-argument - return ScaleState() - - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument - if inplace: - - def f(g): - return g.mul_(step_size) if g is not None else None - - else: - - def f(g): - return g.mul(step_size) if g is not None else None - - updates = tree_map(f, updates) - return updates, state - - return base.GradientTransformation(init_fn, update_fn) - - -class ScaleByScheduleState(NamedTuple): - """Maintains count for scale scheduling.""" - - count: Sequence[torch.Tensor] # type: ignore - - -def scale_by_schedule(step_size_fn: Schedule) -> base.GradientTransformation: - """Scale updates using a custom schedule for the ``step_size``. - - Args: - step_size_fn: - A function that takes an update count as input and proposes the ``step_size`` to - multiply the updates by. - - Returns: - An ``(init_fn, update_fn)`` tuple. - """ - return _scale_by_schedule(step_size_fn=step_size_fn, already_flattened=False) - - -def _scale_by_schedule( - step_size_fn: Schedule, *, already_flattened: bool = False -) -> base.GradientTransformation: - if already_flattened: - tree_map = map_flattened - else: - tree_map = pytree.tree_map - - def init_fn(params): - zero = tree_map( # count init - lambda t: torch.zeros(1, dtype=torch.int32, device=t.device), params - ) - return ScaleByScheduleState(count=zero) - - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument - step_size = step_size_fn(state.count) - - if inplace: - - def f(g): - return g.mul_(step_size) if g is not None else None - - else: - - def f(g): - return g.mul(step_size) if g is not None else None - - updates = tree_map(f, updates) - return updates, ScaleByScheduleState(count=inc_count(updates, state.count)) - - return base.GradientTransformation(init_fn, update_fn) - - -def _update_moment(updates, moments, decay, *, order, inplace=True, already_flattened=False): - """Compute the exponential moving average of the ``order``-th moment.""" - assert order in (1, 2) - - if inplace: - - if order == 2: - - def f(g, t): - return t.mul_(decay).addcmul_(g, g, value=1 - decay) if g is not None else t - - else: - - def f(g, t): - return t.mul_(decay).add_(g, alpha=1 - decay) if g is not None else t - - else: - - if order == 2: - - def f(g, t): - return t.mul(decay).addcmul_(g, g, value=1 - decay) if g is not None else t - - else: - - def f(g, t): - return t.mul(decay).add_(g, alpha=1 - decay) if g is not None else t - - if already_flattened: - return map_flattened(f, updates, moments) - return pytree.tree_map(f, updates, moments) - - -class ScaleByAdamState(NamedTuple): - """State for the Adam algorithm.""" - - mu: base.Updates - nu: base.Updates - count: Sequence[torch.Tensor] # type: ignore - - -def _bias_correction(moment, decay, count, *, already_flattened=False): - """Perform bias correction. This becomes a no-op as count goes to infinity.""" - - def f(t, c): - return t.div(1 - decay**c) - - if already_flattened: - return map_flattened(f, moment, count) - return pytree.tree_map(f, moment, count) - - -def scale_by_adam( - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - moment_requires_grad: bool = False, -) -> base.GradientTransformation: - """Rescale updates according to the Adam algorithm. - - References: - [Kingma et al, 2014](https://arxiv.org/abs/1412.6980) - - Args: - b1: (default: :const:`0.9`) - Decay rate for the exponentially weighted average of grads. - b2: (default: :const:`0.999`) - Decay rate for the exponentially weighted average of squared grads. - eps: (default: :const:`1e-8`) - Term added to the denominator to improve numerical stability. - eps_root: (default: :const:`0.0`) - Term added to the denominator inside the square-root to improve - numerical stability when back-propagating gradients through the rescaling. - moment_requires_grad: (default: :data:`False`) - if :data:`True`, states will be created with flag `requires_grad = True`. - - Returns: - An (init_fn, update_fn) tuple. - """ - return _scale_by_adam( - b1=b1, - b2=b2, - eps=eps, - eps_root=eps_root, - moment_requires_grad=moment_requires_grad, - already_flattened=False, - ) - - -def _scale_by_adam( - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - moment_requires_grad: bool = False, - *, - already_flattened: bool = False, -) -> base.GradientTransformation: - # pylint: disable=unneeded-not - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= b1 < 1.0: - raise ValueError(f'Invalid beta parameter at index 0: {b1}') - if not 0.0 <= b2 < 1.0: - raise ValueError(f'Invalid beta parameter at index 1: {b2}') - # pylint: enable=unneeded-not - - if already_flattened: - tree_map = map_flattened - else: - tree_map = pytree.tree_map - - def init_fn(params): - zero = tree_map( # count init - lambda t: torch.zeros(1, dtype=torch.int32, device=t.device), params - ) - mu = tree_map( # first moment - lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params - ) - nu = tree_map( # second moment - lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params - ) - return ScaleByAdamState(mu=mu, nu=nu, count=zero) - - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument - mu = _update_moment( - updates, state.mu, b1, order=1, inplace=inplace, already_flattened=already_flattened - ) - nu = _update_moment( - updates, state.nu, b2, order=2, inplace=inplace, already_flattened=already_flattened - ) - count_inc = _inc_count(updates, state.count, already_flattened=already_flattened) - mu_hat = _bias_correction(mu, b1, count_inc, already_flattened=already_flattened) - nu_hat = _bias_correction(nu, b2, count_inc, already_flattened=already_flattened) - - if inplace: - - def f(g, m, v): - return m.div_(v.add_(eps_root).sqrt_().add_(eps)) if g is not None else None - - else: - - def f(g, m, v): - return m.div(v.add(eps_root).sqrt_().add_(eps)) if g is not None else None - - updates = tree_map(f, updates, mu_hat, nu_hat) - return updates, ScaleByAdamState(mu=mu, nu=nu, count=count_inc) - - return base.GradientTransformation(init_fn, update_fn) - - -def scale_by_accelerated_adam( - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - moment_requires_grad: bool = False, -) -> base.GradientTransformation: - """Rescale updates according to the Adam algorithm. - - This function is accelerated by using some fused accelerated operators. - - References: - [Kingma et al, 2014](https://arxiv.org/abs/1412.6980) - - Args: - b1: (default: :const:`0.9`) - Decay rate for the exponentially weighted average of grads. - b2: (default: :const:`0.999`) - Decay rate for the exponentially weighted average of squared grads. - eps: (default: :const:`1e-8`) - Term added to the denominator to improve numerical stability. - eps_root: (default: :const:`0.0`) - Term added to the denominator inside the square-root to improve - numerical stability when back-propagating gradients through the rescaling. - moment_requires_grad: (default: :data:`False`) - if :data:`True`, states will be created with flag `requires_grad = True`. - - Returns: - An (init_fn, update_fn) tuple. - """ - return _scale_by_accelerated_adam( - b1=b1, - b2=b2, - eps=eps, - eps_root=eps_root, - moment_requires_grad=moment_requires_grad, - already_flattened=False, - ) - - -def _scale_by_accelerated_adam( - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - moment_requires_grad: bool = False, - *, - already_flattened: bool = False, -) -> base.GradientTransformation: - # pylint: disable=unneeded-not - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= b1 < 1.0: - raise ValueError(f'Invalid beta parameter at index 0: {b1}') - if not 0.0 <= b2 < 1.0: - raise ValueError(f'Invalid beta parameter at index 1: {b2}') - # pylint: enable=unneeded-not - - from torchopt._src.accelerated_op import AdamOp # pylint: disable=import-outside-toplevel - - if already_flattened: - tree_map = map_flattened - - # pylint: disable-next=unused-argument - def update_fn(updates, state, *, params=None, inplace=True): - count_inc = _inc_count(updates, state.count, already_flattened=True) - - op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace) - out = map_flattened(op, state.mu, state.nu, updates, count_inc) - - new_mu, new_nu, new_updates = tuple(zip(*out)) # transpose - return new_updates, ScaleByAdamState(mu=new_mu, nu=new_nu, count=count_inc) - - else: - tree_map = pytree.tree_map - - # pylint: disable-next=unused-argument - def update_fn(updates, state, *, params=None, inplace=True): - count_inc = _inc_count(updates, state.count, already_flattened=False) - - treedef = pytree.tree_structure(updates) - - op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace) - out = pytree.tree_map(op, state.mu, state.nu, updates, count_inc) - - new_mu, new_nu, new_updates = pytree.tree_transpose(treedef, TRIPLE_PYTREEDEF, out) - return new_updates, ScaleByAdamState(mu=new_mu, nu=new_nu, count=count_inc) - - def init_fn(params): - zero = tree_map( # count init - lambda t: torch.zeros(1, dtype=torch.int32, device=t.device), params - ) - mu = tree_map( # first moment - lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params - ) - nu = tree_map( # second moment - lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params - ) - return ScaleByAdamState(mu=mu, nu=nu, count=zero) - - return base.GradientTransformation(init_fn, update_fn) - - -class TraceState(NamedTuple): - """Holds an aggregation of past updates.""" - - trace: base.Params - - -def trace( - momentum: float = 0.9, - dampening: float = 0.0, - nesterov: bool = False, - moment_requires_grad: bool = False, -) -> base.GradientTransformation: - """Compute a trace of past updates. - - Note: `trace` and `ema` have very similar but distinct updates; - `trace = decay * trace + t`, while `ema = decay * ema + (1 - decay) * t`. - Both are frequently found in the optimization literature. - - Args: - momentum: (default: :const:`0.9`) - The decay rate for the trace of past updates. - dampening: (default: :const:`0.0`) - Dampening for momentum. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - moment_requires_grad: (default: :data:`False`) - if :data:`True`, states will be created with flag `requires_grad = True`. - - Returns: - An (init_fn, update_fn) tuple. - """ - return _trace( - momentum=momentum, - dampening=dampening, - nesterov=nesterov, - moment_requires_grad=moment_requires_grad, - already_flattened=False, - ) - - -def _trace( - momentum: float = 0.9, - dampening: float = 0.0, - nesterov: bool = False, - moment_requires_grad: bool = False, - *, - already_flattened: bool = False, -) -> base.GradientTransformation: - # pylint: disable=unneeded-not - if not 0.0 <= momentum: - raise ValueError(f'Invalid momentum value: {momentum}') - if nesterov and (momentum <= 0.0 or dampening != 0.0): - raise ValueError('Nesterov momentum requires a momentum and zero dampening') - # pylint: enable=unneeded-not - - if momentum == 0.0: - return base.identity() - - if already_flattened: - tree_map = map_flattened - else: - tree_map = pytree.tree_map - - def init_fn(params): - return TraceState( - trace=tree_map( - lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params - ) - ) - - first_call = True - - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument - nonlocal first_call - - if nesterov: - if inplace: - - def f1(g, t): - if first_call: - return t.add_(g) - return t.mul_(momentum).add_(g) - - def f2(g, t): - return g.add_(t, alpha=momentum) - - new_trace = tree_map(f1, updates, state.trace) - updates = tree_map(f2, updates, new_trace) - else: - - def f1(g, t): - if first_call: - return t.add(g) - return t.mul(momentum).add_(g) - - def f2(g, t): - return g.add(t, alpha=momentum) - - new_trace = tree_map(f1, updates, state.trace) - updates = tree_map(f2, updates, new_trace) - else: - if inplace: - - def f(g, t): - if first_call: - return t.add(g) - return t.mul_(momentum).add_(g, alpha=1.0 - dampening) - - def copy_(g, t): - return g.copy_(t) - - new_trace = tree_map(f, updates, state.trace) - updates = tree_map(copy_, updates, new_trace) - else: - - def f(g, t): - if first_call: - return t.add(g) - return t.mul(momentum).add_(g, alpha=1.0 - dampening) - - new_trace = tree_map(f, updates, state.trace) - updates = tree_map(torch.clone, new_trace) - - first_call = False - return updates, TraceState(trace=new_trace) - - return base.GradientTransformation(init_fn, update_fn) - - -class ScaleByRmsState(NamedTuple): - """State for exponential root mean-squared (RMS)-normalized updates.""" - - nu: base.Updates - - -def scale_by_rms( - alpha: float = 0.9, eps: float = 1e-8, initial_scale: float = 0.0 -) -> base.GradientTransformation: - """Rescale updates by the root of the exp. moving avg of the square. - - References: - [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) - - Args: - alpha: (default: :const:`0.9`) - Decay rate for the exponentially weighted average of squared grads. - eps: (default: :const:`1e-8`) - Term added to the denominator to improve numerical stability. - initial_scale: (default: :const:`0.0`) - Initial value for second moment - - Returns: - An (init_fn, update_fn) tuple. - """ - return _scale_by_rms( - alpha=alpha, - eps=eps, - initial_scale=initial_scale, - already_flattened=False, - ) - - -def _scale_by_rms( - alpha: float = 0.9, - eps: float = 1e-8, - initial_scale: float = 0.0, - *, - already_flattened: bool = False, -) -> base.GradientTransformation: - # pylint: disable=unneeded-not - if not 0.0 <= alpha: - raise ValueError(f'Invalid alpha value: {alpha}') - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - # pylint: enable=unneeded-not - - if already_flattened: - tree_map = map_flattened - else: - tree_map = pytree.tree_map - - def init_fn(params): - nu = tree_map(lambda n: torch.full_like(n, initial_scale), params) # second moment - return ScaleByRmsState(nu=nu) - - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument - nu = _update_moment( - updates, state.nu, alpha, order=2, inplace=inplace, already_flattened=already_flattened - ) - - if inplace: - - def f(g, n): - return g.div_(n.sqrt().add_(eps)) - - else: - - def f(g, n): - return g.div(n.sqrt().add_(eps)) - - updates = tree_map(f, updates, nu) - return updates, ScaleByRmsState(nu=nu) - - return base.GradientTransformation(init_fn, update_fn) - - -class ScaleByRStdDevState(NamedTuple): - """State for centered exponential moving average of squares of updates.""" - - mu: base.Updates - nu: base.Updates - - -def scale_by_stddev( - alpha: float = 0.9, eps: float = 1e-8, initial_scale: float = 0.0 -) -> base.GradientTransformation: - """Rescale updates by the root of the centered exp. moving average of squares. - - References: - [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) - - Args: - alpha: (default: :const:`0.9`) - Decay rate for the exponentially weighted average of squared grads. - eps: (default: :const:`1e-8`) - Term added to the denominator to improve numerical stability. - initial_scale: (default: :const:`0.0`) - Initial value for second moment - - Returns: - An (init_fn, update_fn) tuple. - """ - return _scale_by_stddev( - alpha=alpha, - eps=eps, - initial_scale=initial_scale, - already_flattened=False, - ) - - -def _scale_by_stddev( - alpha: float = 0.9, - eps: float = 1e-8, - initial_scale: float = 0.0, - *, - already_flattened: bool = False, -) -> base.GradientTransformation: - # pylint: disable=unneeded-not - if not 0.0 <= alpha: - raise ValueError(f'Invalid alpha value: {alpha}') - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - # pylint: enable=unneeded-not - - if already_flattened: - tree_map = map_flattened - else: - tree_map = pytree.tree_map - - def init_fn(params): - mu = tree_map(torch.zeros_like, params) # first moment - nu = tree_map(lambda n: torch.full_like(n, initial_scale), params) # second moment - return ScaleByRStdDevState(mu=mu, nu=nu) - - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument - mu = _update_moment( - updates, state.mu, alpha, order=1, inplace=inplace, already_flattened=already_flattened - ) - nu = _update_moment( - updates, state.nu, alpha, order=2, inplace=inplace, already_flattened=already_flattened - ) - - if inplace: - - def f(g, m, n): - return g.div_(n.addcmul(m, m, value=-1.0).sqrt_().add_(eps)) - - else: - - def f(g, m, n): - return g.div(n.addcmul(m, m, value=-1.0).sqrt_().add_(eps)) - - updates = tree_map(f, updates, mu, nu) - return updates, ScaleByRStdDevState(mu=mu, nu=nu) - - return base.GradientTransformation(init_fn, update_fn) - - -class MaskedState(NamedTuple): - """Maintains inner transform state for masked transformations.""" - - inner_state: Any - - -class MaskedNode(NamedTuple): - """A node used to mask out unspecified parts of a tree. - - This node is ignored when mapping functions across the tree e.g. using - :func:`pytree.tree_map` since it is a container without children. It can - therefore be used to mask out parts of a tree. - """ - - -def masked( - inner: base.GradientTransformation, - mask: Union[Any, Callable[[base.Params], Any]], -) -> base.GradientTransformation: - """Mask updates so only some are transformed, the rest are passed through. - - For example, it is common to skip weight decay for BatchNorm scale and all - bias parameters. In many networks, these are the only parameters with only - one dimension. So, you may create a mask function to mask these out as - follows:: - mask_fn = lambda p: pytree.tree_map(lambda x: x.ndim != 1, p) - weight_decay = torchopt.masked(torchopt.add_decayed_weights(0.001), mask_fn) - You may alternatively create the mask pytree upfront:: - mask = pytree.tree_map(lambda x: x.ndim != 1, params) - weight_decay = torchopt.masked(torchopt.add_decayed_weights(0.001), mask) - For the ``inner`` transform, state will only be stored for the parameters that - have a mask value of ``True``. - - Args: - inner: Inner transformation to mask. - mask: a PyTree with same structure as (or a prefix of) the params PyTree, or - a Callable that returns such a pytree given the params/updates. The leaves - should be booleans, ``True`` for leaves/subtrees you want to apply the - transformation to, and ``False`` for those you want to skip. The mask must - be static for the gradient transformation to be jit-compilable. - - Returns: - New GradientTransformation wrapping ``inner``. - """ - return _masked( - inner=inner, - mask=mask, - already_flattened=False, - ) - - -def _masked( - inner: base.GradientTransformation, - mask: Union[Any, Callable[[base.Params], Any]], - *, - already_flattened: bool = False, -) -> base.GradientTransformation: - - if already_flattened: - tree_map = map_flattened - else: - tree_map = pytree.tree_map - - def tree_mask(params, mask_tree): - return tree_map(lambda p, m: p if m else MaskedNode(), params, mask_tree) - - def init_fn(params): - mask_tree = mask(params) if callable(mask) else mask - masked_params = tree_mask(params, mask_tree) - return MaskedState(inner_state=inner.init(masked_params)) - - def update_fn(updates, state, params=None, inplace=True): # pylint: disable=unused-argument - mask_tree = mask(updates) if callable(mask) else mask - masked_updates = tree_mask(updates, mask_tree) - masked_params = None if params is None else tree_mask(params, mask_tree) - - new_masked_updates, new_inner_state = inner.update( - masked_updates, state.inner_state, params=masked_params, inplace=inplace - ) - - new_updates = tree_map( - lambda new_u, old_u, m: new_u if m else old_u, new_masked_updates, updates, mask_tree - ) - return new_updates, MaskedState(inner_state=new_inner_state) - - return base.GradientTransformation(init_fn, update_fn) - - -AddDecayedWeightsState = base.EmptyState - - -# mypy: ignore-errors -def add_decayed_weights( - weight_decay: float = 0.0, - mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, -) -> base.GradientTransformation: - """Add parameter scaled by `weight_decay`. - - Args: - weight_decay: a scalar weight decay rate. - mask: a tree with same structure as (or a prefix of) the params PyTree, - or a Callable that returns such a pytree given the params/updates. - The leaves should be booleans, `True` for leaves/subtrees you want to - apply the transformation to, and `False` for those you want to skip. - - Returns: - An (init_fn, update_fn) tuple. - """ - return _add_decayed_weights( - weight_decay=weight_decay, - mask=mask, - already_flattened=False, - ) - - -# mypy: ignore-errors -def _add_decayed_weights( - weight_decay: float = 0.0, - mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, - *, - already_flattened: bool = False, -) -> base.GradientTransformation: - if not 0.0 <= weight_decay: # pylint: disable=unneeded-not - raise ValueError(f'Invalid weight_decay value: {weight_decay}') - - if weight_decay == 0.0 and mask is None: - return base.identity() - - if already_flattened: - tree_map = map_flattened - else: - tree_map = pytree.tree_map - - def init_fn(params): # pylint: disable=unused-argument - return AddDecayedWeightsState() - - def update_fn(updates, state, params=None, inplace=True): # pylint: disable=unused-argument - assert params is not None, ( - 'Parameters are required for weight decay. ' - 'Call `update(updates, state, params=params)` instead.' - ) - - if inplace: - - def f(g, p): - if g is not None: - if g.requires_grad: - return g.add_(p, alpha=weight_decay) - return g.add_(p.data, alpha=weight_decay) - return None - - else: - - def f(g, p): - return g.add(p, alpha=weight_decay) if g is not None else None - - updates = tree_map(f, updates, params) - return updates, state - - # If mask is not `None`, apply mask to the gradient transformation. - # E.g. it is common to skip weight decay on bias units and batch stats. - if mask is not None: - return _masked( - inner=base.GradientTransformation(init_fn, update_fn), - mask=mask, - already_flattened=already_flattened, - ) - return base.GradientTransformation(init_fn, update_fn) diff --git a/torchopt/_src/utils.py b/torchopt/_src/utils.py deleted file mode 100644 index 6bfd5bbe..00000000 --- a/torchopt/_src/utils.py +++ /dev/null @@ -1,194 +0,0 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from typing import Dict, List, NamedTuple, Union - -import optree as pytree -import torch -import torch.nn as nn - - -class _ModuleState(NamedTuple): - params: List[Dict] - visual_contents: Union[None, Dict] = None - - -# mypy: ignore-errors -def stop_gradient(target): - """Stop the gradient for the input object. - - Since a tensor use :attr:`grad_fn` to connect itself with the previous computation graph, the - back-propagated gradient will flow over the tensor and continue flow to the tensors that is - connected by :attr:`grad_fn`. Some algorithms requires manually detaching tensors from the - computation graph. - - Note that the :func:`stop_gradient` operation is in-place. - - Args: - target: The target that to be detached from the computation graph, it could be a - :class:`nn.Module`, :class:`torchopt.MetaOptimizer`, state of the - :class:`torchopt.MetaOptimizer`, or just a plain list of tensors. - inplace: If :data:`True`, the target will be detached in-place. if :data:`Frue`, this - function will return a detached copy of the target. The in-place operation is fast and - memory efficient but may raise back-propagation error. - """ - # pylint: disable-next=import-outside-toplevel,cyclic-import - from torchopt._src.optimizer.meta.base import MetaOptimizer - - def f(obj): - if isinstance(obj, torch.Tensor): - requires_grad = obj.requires_grad - obj.detach_().requires_grad_(requires_grad) - - if isinstance(target, _ModuleState): - true_target = target.params - elif isinstance(target, nn.Module): - true_target = tuple(target.parameters()) - elif isinstance(target, MetaOptimizer): - true_target = pytree.tree_leaves(target.state_dict()) - else: - true_target = target - - pytree.tree_map(f, true_target) - - -# pylint: disable-next=too-many-branches,too-many-locals -def extract_state_dict(mod, copy=False, *, with_buffer=True, enable_visual=False, visual_prefix=''): - """Extract target state. - - Since a tensor use :attr:`grad_fn` to connect itself with the previous computation graph, the - back-propagated gradient will flow over the tensor and continue flow to the tensors that is - connected by :attr:`grad_fn`. Some algorithms requires manually detaching tensors from the - computation graph. - - Note that the extracted state is a reference, which means any in-place operator will affect the - target that the state is extracted from. - - Args: - mod: It could be a :class:`nn.Module` or :class:`torchopt.MetaOptimizer`. - with_buffer: - Extract buffer together with parameters, this argument is only used if the input target - is :class:`nn.Module`. - enable_visual: - Add additional annotations, which could be used in computation graph visualization. - Currently, this flag only has effect on :class:`nn.Module` but we will support - :class:`torchopt.MetaOptimizer` later. - visual_prefix: Prefix for the visualization annotations. - - Returns: - State extracted of the input object. - """ - # pylint: disable=import-outside-toplevel,cyclic-import - from torchopt._src.optimizer.meta.base import MetaOptimizer - - if isinstance(mod, nn.Module): # pylint: disable=no-else-return - if enable_visual: - visual_contents = {} - - for k, v in mod.named_parameters(): # pylint: disable=invalid-name - if v.grad_fn is not None: - visual_contents.update({v.grad_fn: (visual_prefix + k, v)}) - else: - visual_contents.update({v: visual_prefix + k}) - else: - visual_contents = None - - params = [] - - def get_variable(t): - if copy: - requires_grad = t.requires_grad - return t.clone().detach_().requires_grad_(requires_grad) - return t - - def _update(term): - if len(term) != 0: - params.append({k: get_variable(v) for k, v in term.items()}) - - # pylint: disable=protected-access - _update(mod._parameters) - if with_buffer: - _update(mod._buffers) - for module in mod.modules(): - if module is mod: - continue - _update(module._parameters) - if with_buffer: - _update(module._buffers) - return _ModuleState(params=tuple(params), visual_contents=visual_contents) - - elif isinstance(mod, MetaOptimizer): - state = mod.state_dict() - if copy: - - def get_variable(t): - if not isinstance(t, torch.Tensor): - return t - requires_grad = t.requires_grad - return t.clone().detach_().requires_grad_(requires_grad) - - return pytree.tree_map(get_variable, state) - - return state - - raise RuntimeError(f'Unexpected class of {mod}') - - -def _extract_container(mod, with_buffer=True): - if isinstance(mod, nn.Module): - containers = [] - - def _update(term): - if len(term) != 0: - containers.append(term) - - # pylint: disable=protected-access - _update(mod._parameters) - if with_buffer: - _update(mod._buffers) - for module in mod.modules(): - if module is mod: - continue - _update(module._parameters) - if with_buffer: - _update(module._buffers) - return tuple(containers) - - raise RuntimeError(f'Unexpected class of {mod}') - - -def recover_state_dict(mod, state): - """Recover state. - - This function is compatible for the ``extract_state``. - - Note that the recovering process is not in-place, so the tensors of the object will not be - modified. - - Args: - mod: Target that need to recover. - state: The recovering state. - """ - # pylint: disable-next=import-outside-toplevel,cyclic-import - from torchopt._src.optimizer.meta.base import MetaOptimizer - - if isinstance(mod, nn.Module): - target_container = _extract_container(mod) - for target, source in zip(target_container, state.params): - target.update(source) - elif isinstance(mod, MetaOptimizer): - mod.load_state_dict(state) - else: - raise RuntimeError(f'Unexpected class of {mod}') diff --git a/torchopt/_src/accelerated_op/__init__.py b/torchopt/accelerated_op/__init__.py similarity index 72% rename from torchopt/_src/accelerated_op/__init__.py rename to torchopt/accelerated_op/__init__.py index 4c7f1cd9..90452046 100644 --- a/torchopt/_src/accelerated_op/__init__.py +++ b/torchopt/accelerated_op/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,17 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""The accelerated Ops.""" -from typing import Iterable, Optional, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterable import torch -from torchopt._src.accelerated_op.adam_op import AdamOp +from torchopt.accelerated_op.adam_op import AdamOp + + +if TYPE_CHECKING: + from torchopt.typing import Device -def accelerated_op_available( - devices: Optional[Union[str, torch.device, Iterable[Union[str, torch.device]]]] = None -) -> bool: +def is_available(devices: Device | Iterable[Device] | None = None) -> bool: """Check the availability of accelerated optimizer.""" op = AdamOp() @@ -30,7 +35,7 @@ def accelerated_op_available( devices = [torch.device('cuda'), torch.device('cpu')] elif isinstance(devices, torch.device): devices = [devices] - elif isinstance(devices, str): + elif isinstance(devices, (int, str)): devices = [torch.device(devices)] try: @@ -40,6 +45,6 @@ def accelerated_op_available( return False updates = torch.tensor(1.0, device=device) op(updates, updates, updates, 1) - return True - except BaseException: # pylint: disable=broad-except + except Exception: # noqa: BLE001 # pylint: disable=broad-except return False + return True diff --git a/torchopt/_src/__init__.py b/torchopt/accelerated_op/_src/__init__.py similarity index 84% rename from torchopt/_src/__init__.py rename to torchopt/accelerated_op/_src/__init__.py index 75b3cf8d..8c2f7b03 100644 --- a/torchopt/_src/__init__.py +++ b/torchopt/accelerated_op/_src/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,5 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - -from torchopt._src.accelerated_op import accelerated_op_available +"""The Python implementation of accelerated ops.""" diff --git a/torchopt/accelerated_op/_src/adam_op.py b/torchopt/accelerated_op/_src/adam_op.py new file mode 100644 index 00000000..d7f9796d --- /dev/null +++ b/torchopt/accelerated_op/_src/adam_op.py @@ -0,0 +1,130 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Python implementation of accelerated AdamOp.""" + +# pylint: disable=invalid-name,too-many-arguments,unused-argument + +from __future__ import annotations + +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + import torch + + +def forward_( + updates: torch.Tensor, + mu: torch.Tensor, + nu: torch.Tensor, + b1: float, + b2: float, + eps: float, + eps_root: float, + count: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Adam forward inplace.""" + mu = mu.mul_(b1).add_(updates, alpha=1.0 - b1) + nu = nu.mul_(b2).addcmul_(updates, updates, value=1.0 - b2) + updates.copy_( + mu.div(1.0 - pow(b1, count)).div_( + nu.div(1.0 - pow(b2, count)).add_(eps_root).sqrt_().add_(eps), + ), + ) + return updates, mu, nu + + +def forward_mu( + updates: torch.Tensor, + mu: torch.Tensor, + b1: float, +) -> torch.Tensor: + """Adam forward mu.""" + return mu.mul(b1).add_(updates, alpha=1.0 - b1) + + +def forward_nu( + updates: torch.Tensor, + nu: torch.Tensor, + b2: float, +) -> torch.Tensor: + """Adam forward nu.""" + return nu.mul(b2).addcmul_(updates, updates, value=1.0 - b2) + + +def forward_updates( + new_mu: torch.Tensor, + new_nu: torch.Tensor, + b1: float, + b2: float, + eps: float, + eps_root: float, + count: int, +) -> torch.Tensor: + """Adam forward updates.""" + return new_mu.div(1.0 - pow(b1, count)).div_( + new_nu.div(1.0 - pow(b2, count)).add_(eps_root).sqrt_().add_(eps), + ) + + +def backward_mu( + dmu: torch.Tensor, + updates: torch.Tensor, + mu: torch.Tensor, + b1: float, +) -> tuple[torch.Tensor, torch.Tensor]: + """Adam backward mu.""" + dupdates = dmu.mul(1.0 - b1) + dmu = dmu.mul(b1) + return dupdates, dmu + + +def backward_nu( + dnu: torch.Tensor, + updates: torch.Tensor, + nu: torch.Tensor, + b2: float, +) -> tuple[torch.Tensor, torch.Tensor]: + """Adam backward nu.""" + dupdates = updates.mul(dnu).mul_(2.0 * (1.0 - b2)) + dnu = dnu.mul(b2) + return dupdates, dnu + + +def backward_updates( + dupdates: torch.Tensor, + updates: torch.Tensor, + new_mu: torch.Tensor, + new_nu: torch.Tensor, + b1: float, + b2: float, + eps_root: float, + count: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Adam backward updates.""" + one_minus_pow_b1 = 1.0 - pow(b1, count) + inv_one_minus_pow_b2 = 1.0 / (1.0 - pow(b2, count) + eps_root) + + updates_div_new_mu = updates.div(new_mu) + dnew_mu_out = dupdates.mul(updates_div_new_mu) + denominator = updates_div_new_mu.mul_(one_minus_pow_b1) + dnew_nu_out = ( + denominator.square_().mul_(dupdates).mul_(updates).mul_(-0.5 * inv_one_minus_pow_b2) + ) + + mask = new_mu == 0 + dnew_mu_out[mask] = 0 + dnew_nu_out[mask] = 0 + return dnew_mu_out, dnew_nu_out diff --git a/torchopt/_src/accelerated_op/adam_op.py b/torchopt/accelerated_op/adam_op.py similarity index 51% rename from torchopt/_src/accelerated_op/adam_op.py rename to torchopt/accelerated_op/adam_op.py index 00261c1a..43ac26cd 100644 --- a/torchopt/_src/accelerated_op/adam_op.py +++ b/torchopt/accelerated_op/adam_op.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,14 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""The accelerated AdamOp.""" # pylint: disable=c-extension-no-member,invalid-name -from typing import Any, Optional, Tuple +from __future__ import annotations + +import contextlib +from typing import Any import torch -from torchopt._C import adam_op # pylint: disable=no-name-in-module + +try: + from torchopt._C import adam_op # pylint: disable=no-name-in-module +except ImportError: + from torchopt.accelerated_op._src import adam_op # type: ignore[no-redef] class AdamOp: # pylint: disable=too-few-public-methods @@ -30,14 +38,13 @@ class MuOp(torch.autograd.Function): # pylint: disable=abstract-method @staticmethod def jvp(ctx: Any, *grad_inputs: Any) -> Any: - # pylint: disable-next=line-too-long - """Defines a formula for differentiating the operation with forward mode automatic differentiation.""" + """Define a formula for differentiating the operation with forward mode automatic differentiation.""" @staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: - """Performs the operation.""" + """Perform the operation.""" updates, mu, b1 = args - new_mu = adam_op.forwardMu(updates, mu, b1) + new_mu = adam_op.forward_mu(updates, mu, b1) ctx.save_for_backward(updates, mu) ctx.b1 = b1 return new_mu @@ -45,11 +52,11 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: @staticmethod def backward(ctx: Any, *args: Any) -> Any: # pylint: disable-next=line-too-long - """Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the :meth:`vjp` method).""" + """Define a formula for differentiating the operation with backward mode automatic differentiation (alias to the :meth:`vjp` method).""" dmu = args[0] updates, mu = ctx.saved_tensors b1 = ctx.b1 - result = adam_op.backwardMu(dmu, updates, mu, b1) + result = adam_op.backward_mu(dmu, updates, mu, b1) return result[0], result[1], None class NuOp(torch.autograd.Function): # pylint: disable=abstract-method @@ -57,14 +64,13 @@ class NuOp(torch.autograd.Function): # pylint: disable=abstract-method @staticmethod def jvp(ctx: Any, *grad_inputs: Any) -> Any: - # pylint: disable-next=line-too-long - """Defines a formula for differentiating the operation with forward mode automatic differentiation.""" + """Define a formula for differentiating the operation with forward mode automatic differentiation.""" @staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: - """Performs the operation.""" + """Perform the operation.""" updates, nu, b2 = args - new_nu = adam_op.forwardNu(updates, nu, b2) + new_nu = adam_op.forward_nu(updates, nu, b2) ctx.save_for_backward(updates, nu) ctx.b2 = b2 return new_nu @@ -72,11 +78,11 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: @staticmethod def backward(ctx: Any, *args: Any) -> Any: # pylint: disable-next=line-too-long - """Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the :meth:`vjp` function).""" + """Define a formula for differentiating the operation with backward mode automatic differentiation (alias to the :meth:`vjp` function).""" dnu = args[0] updates, nu = ctx.saved_tensors b2 = ctx.b2 - result = adam_op.backwardNu(dnu, updates, nu, b2) + result = adam_op.backward_nu(dnu, updates, nu, b2) return result[0], result[1], None class UpdatesOp(torch.autograd.Function): # pylint: disable=abstract-method @@ -84,14 +90,13 @@ class UpdatesOp(torch.autograd.Function): # pylint: disable=abstract-method @staticmethod def jvp(ctx: Any, *grad_inputs: Any) -> Any: - # pylint: disable-next=line-too-long - """Defines a formula for differentiating the operation with forward mode automatic differentiation.""" + """Define a formula for differentiating the operation with forward mode automatic differentiation.""" @staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: - """Performs the operation.""" + """Perform the operation.""" new_mu, new_nu, (b1, b2, eps, eps_root, count) = args - new_updates = adam_op.forwardUpdates(new_mu, new_nu, b1, b2, eps, eps_root, count) + new_updates = adam_op.forward_updates(new_mu, new_nu, b1, b2, eps, eps_root, count) ctx.save_for_backward(new_updates, new_mu, new_nu) ctx.others = (b1, b2, eps, eps_root, count) return new_updates @@ -99,11 +104,20 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: @staticmethod def backward(ctx: Any, *args: Any) -> Any: # pylint: disable-next=line-too-long - """Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the :meth:`vjp` function).""" + """Define a formula for differentiating the operation with backward mode automatic differentiation (alias to the :meth:`vjp` function).""" dupdates = args[0] updates, new_mu, new_nu = ctx.saved_tensors - b1, b2, _, _, count = ctx.others - result = adam_op.backwardUpdates(dupdates, updates, new_mu, new_nu, b1, b2, count) + b1, b2, _, eps_root, count = ctx.others + result = adam_op.backward_updates( + dupdates, + updates, + new_mu, + new_nu, + b1, + b2, + eps_root, + count, + ) return result[0], result[1], None # pylint: disable-next=too-many-arguments @@ -116,7 +130,7 @@ def __init__( eps_root: float = 0.0, inplace: bool = True, ) -> None: - """The :meth:`__init__` function.""" + """Initialize the Adam operator.""" self.b1 = b1 self.b2 = b2 self.eps = eps @@ -124,24 +138,44 @@ def __init__( self.inplace = inplace def __call__( - self, mu: torch.Tensor, nu: torch.Tensor, updates: Optional[torch.Tensor], count: int - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """The :meth:`__call__` function.""" + self, + mu: torch.Tensor, + nu: torch.Tensor, + updates: torch.Tensor | None, + count: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: + """Apply the Adam operator.""" if updates is None: return mu, nu, None - if updates.is_cuda: - current_device = torch.cuda.current_device() - torch.cuda.set_device(updates.device) - if self.inplace: - new_updates, new_mu, new_nu = adam_op.forward_( - updates, mu, nu, self.b1, self.b2, self.eps, self.eps_root, count - ) - else: - new_mu = self.MuOp.apply(updates, mu, self.b1) - new_nu = self.NuOp.apply(updates, nu, self.b2) - new_updates = self.UpdatesOp.apply( - new_mu, new_nu, (self.b1, self.b2, self.eps, self.eps_root, count) - ) - if updates.is_cuda: - torch.cuda.set_device(current_device) + device_context = ( + torch.cuda.device(torch.cuda.current_device()) + if updates.is_cuda + else contextlib.nullcontext() + ) + with device_context: # type: ignore[attr-defined] + if self.inplace: + new_updates, new_mu, new_nu = adam_op.forward_( + updates, + mu, + nu, + self.b1, + self.b2, + self.eps, + self.eps_root, + count, + ) + else: + new_mu = self.MuOp.apply(updates, mu, self.b1) + new_nu = self.NuOp.apply(updates, nu, self.b2) + new_updates = self.UpdatesOp.apply( + new_mu, + new_nu, + ( + self.b1, + self.b2, + self.eps, + self.eps_root, + count, + ), + ) return new_mu, new_nu, new_updates diff --git a/torchopt/_src/combine.py b/torchopt/alias/__init__.py similarity index 67% rename from torchopt/_src/combine.py rename to torchopt/alias/__init__.py index 00e90bc1..5767c5d7 100644 --- a/torchopt/_src/combine.py +++ b/torchopt/alias/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,22 +29,25 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +r"""The aliases of preset :class:`GradientTransformation`\s for optimizers.""" -from torchopt._src import base +from torchopt.alias.adadelta import adadelta +from torchopt.alias.adagrad import adagrad +from torchopt.alias.adam import adam +from torchopt.alias.adamax import adamax +from torchopt.alias.adamw import adamw +from torchopt.alias.radam import radam +from torchopt.alias.rmsprop import rmsprop +from torchopt.alias.sgd import sgd -def chain(*args: base.GradientTransformation) -> base.GradientTransformation: - """Applies a list of chainable update transformations. - - Given a sequence of chainable transforms, :func:`chain` returns an :func:`init_fn` that - constructs a ``state`` by concatenating the states of the individual transforms, and returns an - :func:`update_fn` which chains the update transformations feeding the appropriate state to each. - - Args: - *args: - A sequence of chainable ``(init_fn, update_fn)`` tuples. - - Returns: - A single ``(init_fn, update_fn)`` tuple. - """ - return base.ChainedGradientTransformation(*args) +__all__ = [ + 'adadelta', + 'adagrad', + 'adam', + 'adamax', + 'adamw', + 'radam', + 'rmsprop', + 'sgd', +] diff --git a/torchopt/alias/adadelta.py b/torchopt/alias/adadelta.py new file mode 100644 index 00000000..910cb13e --- /dev/null +++ b/torchopt/alias/adadelta.py @@ -0,0 +1,103 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset :class:`GradientTransformation` for the Adadelta optimizer.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from torchopt.alias.utils import ( + _get_use_chain_flat, + flip_sign_and_add_weight_decay, + scale_by_neg_lr, +) +from torchopt.combine import chain +from torchopt.transform import scale_by_adadelta + + +if TYPE_CHECKING: + from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +__all__ = ['adadelta'] + + +# pylint: disable-next=too-many-arguments +def adadelta( + lr: ScalarOrSchedule = 1e-3, + rho: float = 0.9, + eps: float = 1e-6, + weight_decay: float = 0.0, + *, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """Create a functional version of the AdaDelta optimizer. + + Adadelta is a per-dimension learning rate method for gradient descent. + + References: + - Zeiler, 2012: https://arxiv.org/abs/1212.5701 + + Args: + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + rho (float, optional): Coefficients used for computing running averages of gradient and its square. + (default: :const:`0.9`) + eps (float, optional): A small constant applied to the square root (as in the Adadelta paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with + flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms. + (default: :data:`False`) + + Returns: + The corresponding :class:`GradientTransformation` instance. + + See Also: + The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. + """ + # pylint: disable=unneeded-not + if not (callable(lr) or lr >= 0.0): # pragma: no cover + raise ValueError(f'Invalid learning rate: {lr}') + if not 0 <= rho <= 1: # pragma: no cover + raise ValueError(f'Invalid rho value: {rho}') + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not weight_decay >= 0.0: # pragma: no cover + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + # pylint: enable=unneeded-not + + chain_fn = chain + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay + adadelta_scaler_fn = scale_by_adadelta + scale_by_neg_lr_fn = scale_by_neg_lr + + if _get_use_chain_flat(): # default behavior + chain_fn = chain_fn.flat # type: ignore[attr-defined] + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined] + adadelta_scaler_fn = adadelta_scaler_fn.flat # type: ignore[attr-defined] + scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined] + + return chain_fn( + flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=False), + adadelta_scaler_fn( + rho=rho, + eps=eps, + moment_requires_grad=moment_requires_grad, + ), + scale_by_neg_lr_fn(lr), + ) diff --git a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py new file mode 100644 index 00000000..6fdb4aa3 --- /dev/null +++ b/torchopt/alias/adagrad.py @@ -0,0 +1,166 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/alias.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset :class:`GradientTransformation` for the AdaGrad optimizer.""" + +import logging + +from torchopt.alias.utils import ( + _get_use_chain_flat, + flip_sign_and_add_weight_decay, + scale_by_neg_lr, +) +from torchopt.combine import chain +from torchopt.transform import scale_by_rss, scale_by_schedule +from torchopt.typing import GradientTransformation, Numeric, Scalar, ScalarOrSchedule, Schedule + + +__all__ = ['adagrad'] + + +def _adagrad_lr_schedule( + decay_rate: Scalar, + transition_begin: int = 0, +) -> Schedule: + """Construct a schedule dedicated to AdaGrad optimizer. + + This function applies an learning rate decay function to a provided initial value. The function + returns the decayed value as follows: + + .. code-block:: python + + decayed_value = init_value / (1 + count * decay_rate) + + Args: + decay_rate (float): The decay rate. + transition_begin (int, optional): Must be *positive*. After how many steps to start + annealing. (default: :const:`0`) + + Returns: + schedule: A function that maps step counts to values. + """ + if transition_begin < 0: # pragma: no cover + logging.info( + 'The AdaGrad learning rate schedule was set with a negative `transition_begin` ' + 'value; this will result in `transition_begin` falling back to `0`.', + ) + transition_begin = 0 + + def schedule(count: Numeric) -> Numeric: + decreased_count = count - transition_begin + return 1 / (1 + decay_rate * decreased_count) + + return schedule + + +# pylint: disable-next=too-many-arguments +def adagrad( + lr: ScalarOrSchedule = 1e-2, + lr_decay: float = 0.0, + weight_decay: float = 0.0, + initial_accumulator_value: float = 0.0, + eps: float = 1e-10, + *, + maximize: bool = False, +) -> GradientTransformation: + """Create a functional version of the AdaGrad optimizer. + + AdaGrad is an algorithm for gradient based optimization that anneals the learning rate for each + parameter during the course of training. + + .. warning:: + AdaGrad's main limit is the monotonic accumulation of squared gradients in the denominator. + Since all terms are ``> 0``, the sum keeps growing during training, and the learning rate + eventually becomes very small. + + References: + Duchi et al., 2011: https://jmlr.org/papers/v12/duchi11a.html + + Args: + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-2`) + lr_decay (float, optional): Learning rate decay. (default: :const:`0.0`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + initial_accumulator_value (float, optional): Initial value for the accumulator. + (default: :const:`0.0`) + eps (float, optional): A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-10`) + maximize (bool, optional): Maximize the params based on the objective, instead of minimizing. + (default: :data:`False`) + + Returns: + The corresponding :class:`GradientTransformation` instance. + + See Also: + The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. + """ + # pylint: disable=unneeded-not + if not (callable(lr) or lr >= 0.0): # pragma: no cover + raise ValueError(f'Invalid learning rate: {lr}') + if not lr_decay >= 0.0: # pragma: no cover + raise ValueError(f'Invalid lr_decay value: {lr_decay}') + if not weight_decay >= 0.0: # pragma: no cover + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + if not initial_accumulator_value >= 0.0: # pragma: no cover + raise ValueError(f'Invalid initial_accumulator_value value: {initial_accumulator_value}') + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + # pylint: enable=unneeded-not + + chain_fn = chain + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay + adagrad_scaler_fn = scale_by_rss + scale_by_neg_lr_fn = scale_by_neg_lr + scale_by_schedule_fn = scale_by_schedule + + if _get_use_chain_flat(): # default behavior + chain_fn = chain_fn.flat # type: ignore[attr-defined] + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined] + adagrad_scaler_fn = adagrad_scaler_fn.flat # type: ignore[attr-defined] + scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined] + scale_by_schedule_fn = scale_by_schedule_fn.flat # type: ignore[attr-defined] + + return chain_fn( + flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=maximize), + adagrad_scaler_fn( + initial_accumulator_value=initial_accumulator_value, + eps=eps, + ), + scale_by_schedule_fn( + step_size_fn=_adagrad_lr_schedule( + decay_rate=lr_decay, + transition_begin=0, + ), + ), + scale_by_neg_lr_fn(lr), + ) diff --git a/torchopt/alias/adam.py b/torchopt/alias/adam.py new file mode 100644 index 00000000..0ae0eb8e --- /dev/null +++ b/torchopt/alias/adam.py @@ -0,0 +1,137 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/alias.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset :class:`GradientTransformation` for the Adam optimizer.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from torchopt.alias.utils import ( + _get_use_chain_flat, + flip_sign_and_add_weight_decay, + scale_by_neg_lr, +) +from torchopt.combine import chain +from torchopt.transform import scale_by_accelerated_adam, scale_by_adam + + +if TYPE_CHECKING: + from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +__all__ = ['adam'] + + +# pylint: disable-next=too-many-arguments +def adam( + lr: ScalarOrSchedule = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + *, + eps_root: float = 0.0, + moment_requires_grad: bool = False, + maximize: bool = False, + use_accelerated_op: bool = False, +) -> GradientTransformation: + """Create a functional version of the Adam optimizer. + + Adam is an SGD variant with learning rate adaptation. The *learning rate* used for each weight + is computed from estimates of first- and second-order moments of the gradients (using suitable + exponential moving averages). + + References: + - Kingma et al., 2014: https://arxiv.org/abs/1412.6980 + + Args: + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + eps_root (float, optional): A small constant applied to denominator inside the square root + (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example + when computing (meta-)gradients through Adam. (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with + flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms. + (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of minimizing. + (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. + (default: :data:`False`) + + Returns: + The corresponding :class:`GradientTransformation` instance. + + See Also: + The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. + """ + b1, b2 = betas # pylint: disable=invalid-name + # pylint: disable=unneeded-not + if not (callable(lr) or lr >= 0.0): # pragma: no cover + raise ValueError(f'Invalid learning rate: {lr}') + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= b1 < 1.0: # pragma: no cover + raise ValueError(f'Invalid beta parameter at index 0: {b1}') + if not 0.0 <= b2 < 1.0: # pragma: no cover + raise ValueError(f'Invalid beta parameter at index 1: {b2}') + if not weight_decay >= 0.0: # pragma: no cover + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + # pylint: enable=unneeded-not + + chain_fn = chain + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay + adam_scaler_fn = scale_by_accelerated_adam if use_accelerated_op else scale_by_adam + scale_by_neg_lr_fn = scale_by_neg_lr + + if _get_use_chain_flat(): # default behavior + chain_fn = chain_fn.flat # type: ignore[attr-defined] + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined] + adam_scaler_fn = adam_scaler_fn.flat # type: ignore[attr-defined] + scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined] + + return chain_fn( + flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=maximize), + adam_scaler_fn( + b1=b1, + b2=b2, + eps=eps, + eps_root=eps_root, + moment_requires_grad=moment_requires_grad, + ), + scale_by_neg_lr_fn(lr), + ) diff --git a/torchopt/alias/adamax.py b/torchopt/alias/adamax.py new file mode 100644 index 00000000..3da16713 --- /dev/null +++ b/torchopt/alias/adamax.py @@ -0,0 +1,105 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset :class:`GradientTransformation` for the Adamax optimizer.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from torchopt.alias.utils import ( + _get_use_chain_flat, + flip_sign_and_add_weight_decay, + scale_by_neg_lr, +) +from torchopt.combine import chain +from torchopt.transform import scale_by_adamax + + +if TYPE_CHECKING: + from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +__all__ = ['adamax'] + + +# pylint: disable-next=too-many-arguments +def adamax( + lr: ScalarOrSchedule = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-6, + weight_decay: float = 0.0, + *, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """Create a functional version of the AdaMax optimizer. + + References: + - Kingma et al., 2014: https://arxiv.org/abs/1412.6980 + + Args: + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to the square root (as in the RAdam paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with + flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms. + (default: :data:`False`) + + Returns: + The corresponding :class:`GradientTransformation` instance. + + See Also: + The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. + """ + b1, b2 = betas # pylint: disable=invalid-name + # pylint: disable=unneeded-not + if not (callable(lr) or lr >= 0.0): # pragma: no cover + raise ValueError(f'Invalid learning rate: {lr}') + if not 0 <= b1 <= 1: # pragma: no cover + raise ValueError(f'Invalid rho value: {b1}') + if not 0 <= b2 <= 1: # pragma: no cover + raise ValueError(f'Invalid rho value: {b2}') + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not weight_decay >= 0.0: # pragma: no cover + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + # pylint: enable=unneeded-not + + chain_fn = chain + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay + adamax_scaler_fn = scale_by_adamax + scale_by_neg_lr_fn = scale_by_neg_lr + + if _get_use_chain_flat(): # default behavior + chain_fn = chain_fn.flat # type: ignore[attr-defined] + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined] + adamax_scaler_fn = adamax_scaler_fn.flat # type: ignore[attr-defined] + scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined] + + return chain_fn( + flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=False), + adamax_scaler_fn( + b1=b1, + b2=b2, + eps=eps, + moment_requires_grad=moment_requires_grad, + ), + scale_by_neg_lr_fn(lr), + ) diff --git a/torchopt/alias/adamw.py b/torchopt/alias/adamw.py new file mode 100644 index 00000000..2dc72ef1 --- /dev/null +++ b/torchopt/alias/adamw.py @@ -0,0 +1,151 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/alias.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset :class:`GradientTransformation` for the AdamW optimizer.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable + +from torchopt.alias.utils import ( + _get_use_chain_flat, + flip_sign_and_add_weight_decay, + scale_by_neg_lr, +) +from torchopt.combine import chain +from torchopt.transform import add_decayed_weights, scale_by_accelerated_adam, scale_by_adam + + +if TYPE_CHECKING: + from torchopt.typing import GradientTransformation, OptState, Params, ScalarOrSchedule + + +__all__ = ['adamw'] + + +# pylint: disable-next=too-many-arguments,too-many-locals +def adamw( + lr: ScalarOrSchedule = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + *, + eps_root: float = 0.0, + mask: OptState | Callable[[Params], OptState] | None = None, + moment_requires_grad: bool = False, + maximize: bool = False, + use_accelerated_op: bool = False, +) -> GradientTransformation: + """Create a functional version of the Adam optimizer with weight decay regularization. + + AdamW uses weight decay to regularize learning towards small weights, as + this leads to better generalization. In SGD you can also use L2 regularization + to implement this as an additive loss term, however L2 regularization + does not behave as intended for adaptive gradient algorithms such as Adam. + + References: + - Loshchilov et al., 2019: https://arxiv.org/abs/1711.05101 + + Args: + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with other + frameworks such as PyTorch, but different from (Loshchilov et al., 2019) where the weight + decay is only multiplied with the "schedule multiplier", but not the base learning rate. + (default: :const:`1e-2`) + eps_root (float, optional): A small constant applied to denominator inside the square root + (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example + when computing (meta-)gradients through Adam. (default: :const:`0.0`) + mask (tree of Tensor, callable, or None, optional): + A tree with same structure as (or a prefix of) the params pytree, or a function that + returns such a pytree given the params/updates. The leaves should be booleans, + :data:`True` for leaves/subtrees you want to apply the weight decay to, and + :data:`False` for those you want to skip. Note that the Adam gradient transformations + are applied to all parameters. (default: :data:`None`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with + flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms. + (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. + (default: :data:`False`) + + Returns: + The corresponding :class:`GradientTransformation` instance. + + See Also: + The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. + """ + b1, b2 = betas # pylint: disable=invalid-name + # pylint: disable=unneeded-not + if not (callable(lr) or lr >= 0.0): # pragma: no cover + raise ValueError(f'Invalid learning rate: {lr}') + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= b1 < 1.0: # pragma: no cover + raise ValueError(f'Invalid beta parameter at index 0: {b1}') + if not 0.0 <= b2 < 1.0: # pragma: no cover + raise ValueError(f'Invalid beta parameter at index 1: {b2}') + if not weight_decay >= 0.0: # pragma: no cover + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + # pylint: enable=unneeded-not + + chain_fn = chain + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay + adam_scaler_fn = scale_by_accelerated_adam if use_accelerated_op else scale_by_adam + add_decayed_weights_fn = add_decayed_weights + scale_by_neg_lr_fn = scale_by_neg_lr + + if _get_use_chain_flat(): # default behavior + chain_fn = chain_fn.flat # type: ignore[attr-defined] + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined] + adam_scaler_fn = adam_scaler_fn.flat # type: ignore[attr-defined] + add_decayed_weights_fn = add_decayed_weights_fn.flat # type: ignore[attr-defined] + scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined] + + return chain_fn( + flip_sign_and_add_weight_decay_fn(weight_decay=0.0, maximize=maximize), + adam_scaler_fn( + b1=b1, + b2=b2, + eps=eps, + eps_root=eps_root, + moment_requires_grad=moment_requires_grad, + ), + add_decayed_weights_fn(weight_decay=weight_decay, mask=mask), + scale_by_neg_lr_fn(lr), + ) diff --git a/torchopt/alias/radam.py b/torchopt/alias/radam.py new file mode 100644 index 00000000..9e2880ee --- /dev/null +++ b/torchopt/alias/radam.py @@ -0,0 +1,107 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset :class:`GradientTransformation` for the RAdam optimizer.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from torchopt.alias.utils import ( + _get_use_chain_flat, + flip_sign_and_add_weight_decay, + scale_by_neg_lr, +) +from torchopt.combine import chain +from torchopt.transform import scale_by_radam + + +if TYPE_CHECKING: + from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +__all__ = ['radam'] + + +# pylint: disable-next=too-many-arguments +def radam( + lr: ScalarOrSchedule = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-6, + weight_decay: float = 0.0, + *, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """Create a functional version of the RAdam optimizer. + + RAdam is a variance of the adaptive learning rate rectified optimizer. + + References: + - Liu, 2019: https://arxiv.org/abs/1908.03265 + + Args: + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to the square root (as in the RAdam paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with + flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms. + (default: :data:`False`) + + Returns: + The corresponding :class:`GradientTransformation` instance. + + See Also: + The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. + """ + b1, b2 = betas # pylint: disable=invalid-name + # pylint: disable=unneeded-not + if not (callable(lr) or lr >= 0.0): # pragma: no cover + raise ValueError(f'Invalid learning rate: {lr}') + if not 0 <= b1 <= 1: # pragma: no cover + raise ValueError(f'Invalid rho value: {b1}') + if not 0 <= b2 <= 1: # pragma: no cover + raise ValueError(f'Invalid rho value: {b2}') + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not weight_decay >= 0.0: # pragma: no cover + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + # pylint: enable=unneeded-not + + chain_fn = chain + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay + radam_scaler_fn = scale_by_radam + scale_by_neg_lr_fn = scale_by_neg_lr + + if _get_use_chain_flat(): # default behavior + chain_fn = chain_fn.flat # type: ignore[attr-defined] + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined] + radam_scaler_fn = radam_scaler_fn.flat # type: ignore[attr-defined] + scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined] + + return chain_fn( + flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=False), + radam_scaler_fn( + b1=b1, + b2=b2, + eps=eps, + moment_requires_grad=moment_requires_grad, + ), + scale_by_neg_lr_fn(lr), + ) diff --git a/torchopt/alias/rmsprop.py b/torchopt/alias/rmsprop.py new file mode 100644 index 00000000..612e4f45 --- /dev/null +++ b/torchopt/alias/rmsprop.py @@ -0,0 +1,133 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/alias.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset :class:`GradientTransformation` for the RMSProp optimizer.""" + +from torchopt.alias.utils import ( + _get_use_chain_flat, + flip_sign_and_add_weight_decay, + scale_by_neg_lr, +) +from torchopt.combine import chain +from torchopt.transform import scale_by_rms, scale_by_stddev, trace +from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +__all__ = ['rmsprop'] + + +# pylint: disable-next=too-many-arguments +def rmsprop( + lr: ScalarOrSchedule = 1e-2, + alpha: float = 0.99, + eps: float = 1e-8, + weight_decay: float = 0.0, + momentum: float = 0.0, + centered: bool = False, + *, + initial_scale: float = 0.0, + nesterov: bool = False, + maximize: bool = False, +) -> GradientTransformation: + """Create a functional version of the RMSProp optimizer. + + RMSProp is an SGD variant with learning rate adaptation. The *learning rate* used for each + weight is scaled by a suitable estimate of the magnitude of the gradients on previous steps. + Several variants of RMSProp can be found in the literature. This alias provides an easy to + configure RMSProp optimizer that can be used to switch between several of these variants. + + References: + - Tieleman and Hinton, 2012: http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf + - Graves, 2013: https://arxiv.org/abs/1308.0850 + + Args: + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-2`) + alpha (float, optional): Smoothing constant, the decay used to track the magnitude of + previous gradients. (default: :const:`0.99`) + eps (float, optional): A small numerical constant to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + momentum (float, optional): The decay rate used by the momentum term. The momentum is not + used when it is set to :const:`0.0`. (default: :const:`0.0`) + centered (bool, optional): If :data:`True`, use the variance of the past gradients to + rescale the latest gradients. (default: :data:`False`) + initial_scale (float, optional): Initialization of accumulators tracking the magnitude of + previous updates. PyTorch uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When + reproducing results from a paper, verify the value used by the authors. + (default: :data:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + + Returns: + The corresponding :class:`GradientTransformation` instance. + + See Also: + The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. + """ + # pylint: disable=unneeded-not + if not (callable(lr) or lr >= 0.0): # pragma: no cover + raise ValueError(f'Invalid learning rate: {lr}') + if not alpha >= 0.0: # pragma: no cover + raise ValueError(f'Invalid alpha value: {alpha}') + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not momentum >= 0.0: # pragma: no cover + raise ValueError(f'Invalid momentum value: {momentum}') + if not weight_decay >= 0.0: # pragma: no cover + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + # pylint: enable=unneeded-not + + chain_fn = chain + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay + rmsprop_scaler_fn = scale_by_stddev if centered else scale_by_rms + trace_fn = trace + scale_by_neg_lr_fn = scale_by_neg_lr + + if _get_use_chain_flat(): # default behavior + chain_fn = chain_fn.flat # type: ignore[attr-defined] + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined] + rmsprop_scaler_fn = rmsprop_scaler_fn.flat # type: ignore[attr-defined] + trace_fn = trace_fn.flat # type: ignore[attr-defined] + scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined] + + return chain_fn( + flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=maximize), + rmsprop_scaler_fn( + alpha=alpha, + eps=eps, + initial_scale=initial_scale, + ), + trace_fn(momentum=momentum, nesterov=nesterov), + scale_by_neg_lr_fn(lr), + ) diff --git a/torchopt/alias/sgd.py b/torchopt/alias/sgd.py new file mode 100644 index 00000000..6d5935bc --- /dev/null +++ b/torchopt/alias/sgd.py @@ -0,0 +1,119 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/alias.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset :class:`GradientTransformation` for the SGD optimizer.""" + +from torchopt.alias.utils import ( + _get_use_chain_flat, + flip_sign_and_add_weight_decay, + scale_by_neg_lr, +) +from torchopt.combine import chain +from torchopt.transform import trace +from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +__all__ = ['sgd'] + + +# pylint: disable-next=too-many-arguments +def sgd( + lr: ScalarOrSchedule, + momentum: float = 0.0, + dampening: float = 0.0, + weight_decay: float = 0.0, + nesterov: bool = False, + *, + moment_requires_grad: bool = False, + maximize: bool = False, +) -> GradientTransformation: + """Create a functional version of the canonical Stochastic Gradient Descent optimizer. + + This implements stochastic gradient descent. It also includes support for momentum, and nesterov + acceleration, as these are standard practice when using stochastic gradient descent to train + deep neural networks. + + References: + - Sutskever et al., 2013: http://proceedings.mlr.press/v28/sutskever13.pdf + + Args: + lr (float or callable): This is a fixed global scaling factor or a learning rate + scheduler. + momentum (float, optional): The decay rate used by the momentum term. The momentum is not + used when it is set to :const:`0.0`. (default: :const:`0.0`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + dampening (float, optional): Dampening for momentum. (default: :const:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with + flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms. + (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + + Returns: + The corresponding :class:`GradientTransformation` instance. + + See Also: + The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. + """ + # pylint: disable=unneeded-not + if not (callable(lr) or lr >= 0.0): # pragma: no cover + raise ValueError(f'Invalid learning rate: {lr}') + if not momentum >= 0.0: # pragma: no cover + raise ValueError(f'Invalid momentum value: {momentum}') + if not weight_decay >= 0.0: # pragma: no cover + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + if nesterov and (momentum <= 0.0 or dampening != 0.0): # pragma: no cover + raise ValueError('Nesterov momentum requires a momentum and zero dampening') + # pylint: enable=unneeded-not + + chain_fn = chain + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay + trace_fn = trace + scale_by_neg_lr_fn = scale_by_neg_lr + + if _get_use_chain_flat(): # default behavior + chain_fn = chain_fn.flat # type: ignore[attr-defined] + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined] + trace_fn = trace_fn.flat # type: ignore[attr-defined] + scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined] + + return chain_fn( + flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=maximize), + trace_fn( + momentum=momentum, + dampening=dampening, + nesterov=nesterov, + moment_requires_grad=moment_requires_grad, + ), + scale_by_neg_lr_fn(lr), + ) diff --git a/torchopt/alias/utils.py b/torchopt/alias/utils.py new file mode 100644 index 00000000..0f41e822 --- /dev/null +++ b/torchopt/alias/utils.py @@ -0,0 +1,230 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r"""Utilities for the aliases of preset :class:`GradientTransformation`\s for optimizers.""" + +from __future__ import annotations + +import threading +from typing import TYPE_CHECKING + +from torchopt import pytree +from torchopt.base import EmptyState, GradientTransformation, identity +from torchopt.transform import scale, scale_by_schedule +from torchopt.transform.utils import tree_map_flat, tree_map_flat_ + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import Numeric, OptState, Params, ScalarOrSchedule, Updates + + +__all__ = ['flip_sign_and_add_weight_decay', 'scale_by_neg_lr'] + + +__USE_CHAIN_FLAT_LOCK = threading.Lock() +__USE_CHAIN_FLAT = True + + +def _set_use_chain_flat(use_chain_flat: bool) -> None: # only used for testing purposes + global __USE_CHAIN_FLAT # pylint: disable=global-statement + with __USE_CHAIN_FLAT_LOCK: + __USE_CHAIN_FLAT = use_chain_flat + + +def _get_use_chain_flat() -> bool: # only used for testing purposes + with __USE_CHAIN_FLAT_LOCK: + return __USE_CHAIN_FLAT + + +def flip_sign_and_add_weight_decay( + weight_decay: float = 0.0, + maximize: bool = False, +) -> GradientTransformation: + """Flip the sign of the updates and adds weight decay.""" + return _flip_sign_and_add_weight_decay( + weight_decay=weight_decay, + maximize=maximize, + already_flattened=False, + ) + + +def _flip_sign_and_add_weight_decay_flat( + weight_decay: float = 0.0, + maximize: bool = False, +) -> GradientTransformation: + """Flip the sign of the updates and adds weight decay.""" + return _flip_sign_and_add_weight_decay( + weight_decay=weight_decay, + maximize=maximize, + already_flattened=True, + ) + + +def _flip_sign_and_add_weight_decay( # noqa: C901 + weight_decay: float = 0.0, + maximize: bool = False, + *, + already_flattened: bool = False, +) -> GradientTransformation: + """Flip the sign of the updates and adds weight decay.""" + # pylint: disable-next=unneeded-not + if not weight_decay >= 0.0: # pragma: no cover + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + + if not maximize and weight_decay == 0.0: + return identity() + + if already_flattened: + tree_map = tree_map_flat + tree_map_ = tree_map_flat_ + else: + tree_map = pytree.tree_map # type: ignore[assignment] + tree_map_ = pytree.tree_map_ # type: ignore[assignment] + + def init_fn(params: Params) -> OptState: # pylint: disable=unused-argument + return EmptyState() + + if not maximize: # gradient descent + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, + inplace: bool = True, + ) -> tuple[Updates, OptState]: + assert params is not None, ( + 'Parameters are required for weight decay. ' + 'Call `update(updates, state, params=params)` instead.' + ) + + if inplace: + + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g + if g.requires_grad: + return g.add_(p, alpha=weight_decay) + return g.add_(p.data, alpha=weight_decay) + + tree_map_(f, params, updates) + + else: + + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.add(p, alpha=weight_decay) if g is not None else g + + updates = tree_map(f, params, updates) + + return updates, state + + else: # gradient ascent + if weight_decay == 0.0: + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: + if inplace: + + def f(g: torch.Tensor) -> torch.Tensor: + return g.neg_() + + tree_map_(f, updates) + + else: + + def f(g: torch.Tensor) -> torch.Tensor: + return g.neg() + + updates = tree_map(f, updates) + + return updates, state + + else: + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, + inplace: bool = True, + ) -> tuple[Updates, OptState]: + assert params is not None, ( + 'Parameters are required for weight decay. ' + 'Call `update(updates, state, params=params)` instead.' + ) + + if inplace: + + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g + if g.requires_grad: + return g.neg_().add_(p, alpha=weight_decay) + return g.neg_().add_(p.data, alpha=weight_decay) + + tree_map_(f, params, updates) + + else: + + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.neg().add_(p, alpha=weight_decay) if g is not None else g + + updates = tree_map(f, params, updates) + + return updates, state + + return GradientTransformation(init_fn, update_fn) + + +flip_sign_and_add_weight_decay.flat = _flip_sign_and_add_weight_decay_flat # type: ignore[attr-defined] +flip_sign_and_add_weight_decay.impl = _flip_sign_and_add_weight_decay # type: ignore[attr-defined] + + +def scale_by_neg_lr(lr: ScalarOrSchedule) -> GradientTransformation: + """Scale the updates by the negative learning rate.""" + return _scale_by_neg_lr(lr=lr, already_flattened=False) + + +def _scale_by_neg_lr_flat(lr: ScalarOrSchedule) -> GradientTransformation: + return _scale_by_neg_lr(lr=lr, already_flattened=True) + + +def _scale_by_neg_lr( + lr: ScalarOrSchedule, + *, + already_flattened: bool = False, +) -> GradientTransformation: + if not (callable(lr) or lr >= 0.0): # pragma: no cover + raise ValueError(f'Invalid learning rate: {lr}') + + if callable(lr): + + def schedule_wrapper(count: Numeric) -> Numeric: + return -lr(count) + + return scale_by_schedule.impl( # type: ignore[attr-defined] + schedule_wrapper, + already_flattened=already_flattened, + ) + return scale.impl(-lr, already_flattened=already_flattened) # type: ignore[attr-defined] + + +scale_by_neg_lr.flat = _scale_by_neg_lr_flat # type: ignore[attr-defined] +scale_by_neg_lr.impl = _scale_by_neg_lr # type: ignore[attr-defined] diff --git a/torchopt/_src/base.py b/torchopt/base.py similarity index 67% rename from torchopt/_src/base.py rename to torchopt/base.py index f17bf00f..81892e17 100644 --- a/torchopt/_src/base.py +++ b/torchopt/base.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,32 +29,37 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""The base classes for gradient transformation.""" + +from __future__ import annotations import itertools from abc import abstractmethod -from typing import Callable, NamedTuple, Optional, Tuple - -from torchopt._src.typing import Numeric, TensorTree - +from typing import TYPE_CHECKING, Callable, NamedTuple, Protocol +from typing_extensions import Self # Python 3.11+ -try: - from typing import Protocol -except ImportError: - from typing_extensions import Protocol # type: ignore[misc] +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates -OptState = TensorTree # States are arbitrary nests of `torch.Tensor`. -# Parameters are arbitrary nests of `torch.Tensor`. -Params = TensorTree -Updates = Params # Gradient updates are of the same type as parameters. -Schedule = Callable[[Numeric], Numeric] +__all__ = [ + 'ChainedGradientTransformation', + 'EmptyState', + 'GradientTransformation', + 'UninitializedState', + 'identity', +] class EmptyState(NamedTuple): """An empty state for the simplest stateless transformations.""" +class UninitializedState(NamedTuple): + """A state that is not initialized yet.""" + + class TransformInitFn(Protocol): # pylint: disable=too-few-public-methods """A callable type for the :func:`init` step of a :class:`GradientTransformation`. @@ -65,11 +70,10 @@ class TransformInitFn(Protocol): # pylint: disable=too-few-public-methods @abstractmethod def __call__(self, params: Params) -> OptState: - """The `init` function. + """Initialize the gradient transformation state. Args: - params: - The initial value of the parameters. + params (tree of Tensor): The initial value of the parameters. Returns: The initial state of the gradient transformation. @@ -93,18 +97,18 @@ def __call__( updates: Updates, state: OptState, *, - params: Optional[Params] = None, + params: Params | None = None, inplace: bool = True, - ) -> Tuple[Updates, OptState]: - """The `update` function. + ) -> tuple[Updates, OptState]: + """Transform the updates and state. Args: - updates: A tree of candidate updates. - state: The state of the gradient transformation. - params: (optional) - The current value of the parameters. - inplace: (optional) - If :data:`True`, modify updates and state using inplace operations. + updates (tree of Tensor): A tree of candidate updates. + state (tree of Tensor): The state of the gradient transformation. + params (tree of Tensor or None, optional): The current value of the parameters. + (default: :data:`None`) + inplace (bool, optional): If :data:`True`, modify updates and state using inplace + operations. (default: :data:`True`) Returns: The transformed ``updates``, and the updated ``state``. @@ -131,9 +135,9 @@ class GradientTransformation(NamedTuple): optimizer state. update: A pure function which takes as input a pytree of updates (with the same tree structure - as the original params ``pytree`` passed to :attr:`init`), the previous optimizer state - (which may have been initialized using the :attr:`init` function), and optionally the - ``inplace`` flag. The :attr:`update` function then returns the computed gradient + as the original params ``pytree`` passed to ``init``), the previous optimizer state + (which may have been initialized using the ``init`` function), and optionally the + ``inplace`` flag. The ``update`` function then returns the computed gradient updates, and a updates optimizer state. If the ``inplace`` flag is :data:`True`, the output results are the same instance as the input. """ @@ -142,7 +146,7 @@ class GradientTransformation(NamedTuple): update: TransformUpdateFn # pylint: disable-next=redefined-builtin - def chain(self, next: 'GradientTransformation') -> 'ChainedGradientTransformation': + def chain(self, next: GradientTransformation) -> ChainedGradientTransformation: """Chain two gradient transformations together.""" return ChainedGradientTransformation(self, next) @@ -154,29 +158,40 @@ class ChainedGradientTransformation(GradientTransformation): gradient transformations. """ - transformations: Tuple[GradientTransformation, ...] + transformations: tuple[GradientTransformation, ...] - def __new__(cls, *transformations: GradientTransformation) -> 'ChainedGradientTransformation': - """Creates a new chained gradient transformation.""" + def __new__(cls, *transformations: GradientTransformation) -> Self: + """Create a new chained gradient transformation.""" transformations = tuple( itertools.chain.from_iterable( - t.transformations - if isinstance(t, ChainedGradientTransformation) - else ((t,) if not isinstance(t, IdentityGradientTransformation) else ()) + ( + t.transformations + if isinstance(t, ChainedGradientTransformation) + else ((t,) if not isinstance(t, IdentityGradientTransformation) else ()) + ) for t in transformations - ) + ), ) + if len(transformations) == 0: + transformations = (IdentityGradientTransformation(),) + init_fns, update_fns = tuple(zip(*transformations)) - def init_fn(params): + def init_fn(params: Params) -> OptState: return tuple(fn(params) for fn in init_fns) - def update_fn(updates, state, *, params=None, inplace=True): + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, + inplace: bool = True, + ) -> tuple[Updates, OptState]: if len(update_fns) != len(state): raise ValueError( 'The number of updates and states has to be the same in chain! Make sure you' - 'have called init first!' + 'have called init first!', ) new_state = [] for s, fn in zip(state, update_fns): # pylint: disable=invalid-name @@ -188,16 +203,15 @@ def update_fn(updates, state, *, params=None, inplace=True): instance.transformations = transformations return instance - def __str__(self): - """Returns a string representation of the chained gradient transformation.""" - return '{}(\n {}\n)'.format( - self.__class__.__name__, ',\n '.join(repr(t) for t in self.transformations) + def __repr__(self) -> str: + """Return a string representation of the chained gradient transformation.""" + return '{}(\n {},\n)'.format( + self.__class__.__name__, + ',\n '.join(repr(t) for t in self.transformations), ) - __repr__ = __str__ - def __eq__(self, other: object) -> bool: - """Returns whether two chained gradient transformations are equal.""" + """Return whether two chained gradient transformations are equal.""" if isinstance(other, ChainedGradientTransformation): return self.transformations == other.transformations if isinstance(other, GradientTransformation): @@ -205,44 +219,43 @@ def __eq__(self, other: object) -> bool: return False def __hash__(self) -> int: - """Returns the hash of the chained gradient transformation.""" + """Return the hash of the chained gradient transformation.""" return hash(self.transformations) - def __getstate__(self) -> Tuple[GradientTransformation, ...]: - """Returns the state of the chained gradient transformation for serialization.""" + def __getstate__(self) -> tuple[GradientTransformation, ...]: + """Return the state of the chained gradient transformation for serialization.""" return self.transformations - def __setstate__(self, state: Tuple[GradientTransformation, ...]) -> None: - """Sets the state of the chained gradient transformation from serialization.""" + def __setstate__(self, state: tuple[GradientTransformation, ...]) -> None: + """Set the state of the chained gradient transformation from serialization.""" self.transformations = state - def __reduce__(self) -> Tuple[Callable, Tuple[Tuple[GradientTransformation, ...]]]: - """Serialization support for chained gradient transformation.""" + def __reduce__(self) -> tuple[Callable, tuple[tuple[GradientTransformation, ...]]]: + """Serialize the chained gradient transformation.""" return ChainedGradientTransformation, (self.transformations,) class IdentityGradientTransformation(GradientTransformation): """A gradient transformation that does nothing.""" - def __new__(cls): + def __new__(cls) -> Self: """Create a new gradient transformation that does nothing.""" return super().__new__(cls, init=cls.init_fn, update=cls.update_fn) @staticmethod def init_fn(params: Params) -> OptState: # pylint: disable=unused-argument - """Returns empty state.""" + """Return empty state.""" return EmptyState() @staticmethod - # pylint: disable-next=unused-argument def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, - inplace: bool = True, - ) -> Tuple[Updates, OptState]: - """Returns updates unchanged.""" + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, # pylint: disable=unused-argument + ) -> tuple[Updates, OptState]: + """Return updates unchanged.""" return updates, state diff --git a/torchopt/clip.py b/torchopt/clip.py new file mode 100644 index 00000000..d64afc58 --- /dev/null +++ b/torchopt/clip.py @@ -0,0 +1,107 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/clip_grad.py +# ============================================================================== +"""Utilities for gradient clipping.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from torchopt import pytree +from torchopt.base import EmptyState, GradientTransformation + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates + + +__all__ = ['clip_grad_norm'] + + +ClipState = EmptyState + + +def clip_grad_norm( + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, +) -> GradientTransformation: + """Clip gradient norm of an iterable of parameters. + + Args: + max_norm (float): The maximum absolute value for each element in the update. + norm_type (float, optional): Type of the used p-norm. Can be ``'inf'`` for infinity norm. + (default: :const:`2.0`) + error_if_nonfinite (bool, optional): If :data:`True`, an error is thrown if the total norm + of the gradients from ``updates`` is ``nan``, ``inf``, or ``-inf``. + (default: :data:`False`) + + Returns: + An ``(init_fn, update_fn)`` tuple. + """ + + def init_fn(params: Params) -> OptState: # pylint: disable=unused-argument + return ClipState() + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: + available_updates = pytree.tree_leaves(updates) + if len(available_updates) == 0: + return updates, state + device = available_updates[0].device + with torch.no_grad(): + if norm_type == torch.inf: + norms = [p.abs().max().to(device) for p in available_updates] + total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms)) + else: + total_norm = torch.norm( + torch.stack([torch.norm(p, norm_type).to(device) for p in available_updates]), + norm_type, + ) + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f'The total norm of order {norm_type} for gradients from `parameters` is ' + f'non-finite, so it cannot be clipped. To disable this error and scale the ' + f'gradients by the non-finite norm anyway, set `error_if_nonfinite=False`', + ) + clip_coefficient = max_norm / (float(total_norm) + 1e-6) + # Note: multiplying by the clamped coefficient is redundant when the coefficient is + # clamped to 1, but doing so avoids a `if clip_coefficient < 1:` conditional which + # can require a CPU <=> device synchronization when the gradients do not reside in + # CPU memory. + clip_coefficient_clamped = min(clip_coefficient, 1.0) + if inplace: + + def f(g: torch.Tensor) -> torch.Tensor: + return g.mul_(clip_coefficient_clamped) + + else: + + def f(g: torch.Tensor) -> torch.Tensor: + return g.mul(clip_coefficient_clamped) + + new_updates = pytree.tree_map(f, updates) + return new_updates, state + + return GradientTransformation(init_fn, update_fn) diff --git a/torchopt/combine.py b/torchopt/combine.py new file mode 100644 index 00000000..15345286 --- /dev/null +++ b/torchopt/combine.py @@ -0,0 +1,105 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/alias.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities to define a chained transformation.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from torchopt import pytree +from torchopt.base import ChainedGradientTransformation, GradientTransformation, identity + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates + + +__all__ = ['chain', 'chain_flat'] + + +def chain(*transformations: GradientTransformation) -> GradientTransformation: + """Apply a list of chainable update transformations. + + Given a sequence of chainable transforms, :func:`chain` returns an :func:`init_fn` that + constructs a ``state`` by concatenating the states of the individual transforms, and returns an + :func:`update_fn` which chains the update transformations feeding the appropriate state to each. + + Args: + *transformations (iterable of GradientTransformation): A sequence of chainable + ``(init_fn, update_fn)`` tuples. + + Returns: + A single ``(init_fn, update_fn)`` tuple. + """ + if len(transformations) == 0: + return identity() + if len(transformations) == 1: + return transformations[0] + return ChainedGradientTransformation(*transformations) + + +def chain_flat(*transformations: GradientTransformation) -> GradientTransformation: + """Wrap around the inner transformations that manipulate the flattened tree structure (:class:``list``). + + Args: + *transformations (iterable of GradientTransformation): A sequence of chainable + ``(init_fn, update_fn)`` tuples. + + Returns: + A single ``(init_fn, update_fn)`` tuple. + """ + if len(transformations) == 0: + return identity() + inner = transformations[0] if len(transformations) == 1 else chain(*transformations) + + def init_fn(params: Params) -> OptState: + return inner.init(pytree.tree_leaves(params, none_is_leaf=True)) + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, + inplace: bool = True, + ) -> tuple[Updates, OptState]: + flat_updates, treespec = pytree.tree_flatten(updates, none_is_leaf=True) + flat_params = pytree.tree_leaves(params, none_is_leaf=True) if params is not None else None + + flat_updates, state = inner.update(flat_updates, state, params=flat_params, inplace=inplace) + updates: Updates + updates = pytree.tree_unflatten(treespec, flat_updates) + return updates, state + + return GradientTransformation(init_fn, update_fn) + + +chain.flat = chain_flat # type: ignore[attr-defined] diff --git a/torchopt/_src/optimizer/__init__.py b/torchopt/diff/__init__.py similarity index 64% rename from torchopt/_src/optimizer/__init__.py rename to torchopt/diff/__init__.py index 8501fb15..194512f5 100644 --- a/torchopt/_src/optimizer/__init__.py +++ b/torchopt/diff/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Differentiable Gradient Estimation.""" -from torchopt._src.optimizer import meta -from torchopt._src.optimizer.adam import Adam -from torchopt._src.optimizer.adamw import AdamW -from torchopt._src.optimizer.base import Optimizer -from torchopt._src.optimizer.rmsprop import RMSProp, RMSprop -from torchopt._src.optimizer.sgd import SGD +from torchopt.diff import implicit, zero_order +from torchopt.diff.implicit import ImplicitMetaGradientModule +from torchopt.diff.zero_order import ZeroOrderGradientModule diff --git a/torchopt/_src/optimizer/meta/__init__.py b/torchopt/diff/implicit/__init__.py similarity index 64% rename from torchopt/_src/optimizer/meta/__init__.py rename to torchopt/diff/implicit/__init__.py index ec227474..4cff14c6 100644 --- a/torchopt/_src/optimizer/meta/__init__.py +++ b/torchopt/diff/implicit/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Implicit Meta-Gradient.""" -from torchopt._src.optimizer.meta.adam import MetaAdam -from torchopt._src.optimizer.meta.adamw import MetaAdamW -from torchopt._src.optimizer.meta.base import MetaOptimizer -from torchopt._src.optimizer.meta.rmsprop import MetaRMSProp, MetaRMSprop -from torchopt._src.optimizer.meta.sgd import MetaSGD +from torchopt.diff.implicit import nn +from torchopt.diff.implicit.decorator import custom_root +from torchopt.diff.implicit.nn import ImplicitMetaGradientModule + + +__all__ = ['ImplicitMetaGradientModule', 'custom_root'] diff --git a/torchopt/diff/implicit/decorator.py b/torchopt/diff/implicit/decorator.py new file mode 100644 index 00000000..11ba0153 --- /dev/null +++ b/torchopt/diff/implicit/decorator.py @@ -0,0 +1,506 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/google/jaxopt/blob/main/jaxopt/_src/implicit_diff.py +# ============================================================================== +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implicit Meta-Gradient.""" + +# pylint: disable=invalid-name + +from __future__ import annotations + +import functools +import inspect +from typing import TYPE_CHECKING, Any, Callable, Dict, Sequence, Tuple + +import functorch +import torch +from torch.autograd import Function + +from torchopt import linear_solve, pytree + + +if TYPE_CHECKING: + from torchopt.typing import ( + ListOfOptionalTensors, + ListOfTensors, + TensorOrTensors, + TupleOfOptionalTensors, + TupleOfTensors, + ) + + +__all__ = ['custom_root'] + + +Args = Tuple[Any, ...] +KwArgs = Dict[str, Any] + + +class MaskedOptimalityFn: # pylint: disable=missing-class-docstring,too-few-public-methods + def __init__( + self, + optimality_fn: Callable[..., TensorOrTensors], + solution: TensorOrTensors, + output_is_tensor: bool, + argnums: tuple[int, ...], + *args: Any, + ) -> None: + self.optimality_fn = optimality_fn + self.solution = solution + self.output_is_tensor = output_is_tensor + self.argnums = argnums + + pre_filled = [] + post_filled = [] + for idx, arg in enumerate(args): + if idx + 1 in argnums: # plus 1 because we exclude the first argument + post_filled.append(arg) + else: + pre_filled.append(arg) + self.len_args = len(pre_filled) + len(post_filled) + self.pre_filled = tuple(pre_filled) + self.post_filled = tuple(post_filled) + + def __call__(self, *args: Any) -> TensorOrTensors: + true_args = [] + pre_filled_counter = 0 + for idx in range(self.len_args): + if idx + 1 in self.argnums: # plus 1 because we exclude the first argument + arg = args[idx] + else: + arg = self.pre_filled[pre_filled_counter] + pre_filled_counter += 1 + true_args.append(arg) + if self.output_is_tensor: + return self.optimality_fn(self.solution[0], *true_args) + return self.optimality_fn(self.solution, *true_args) + + +# pylint: disable-next=too-many-arguments,too-many-locals,too-many-branches +def _root_vjp( + optimality_fn: Callable[..., TensorOrTensors], + solution: TupleOfTensors, + args: Args, + grad_outputs: TupleOfTensors, + output_is_tensor: bool, + argnums: tuple[int, ...], + solve: Callable[..., TensorOrTensors], +) -> TupleOfOptionalTensors: + if output_is_tensor: + + def optimality_cond(solution: TupleOfTensors) -> TensorOrTensors: + return optimality_fn(solution[0], *args) + + else: + + def optimality_cond(solution: TupleOfTensors) -> TensorOrTensors: + return optimality_fn(solution, *args) + + _, optimality_cond_vjp_fn, *_ = functorch.vjp(optimality_cond, solution) + + # Compute the multiplication A^T u = (u^T A)^T. + if output_is_tensor: + + def matvec(u: TupleOfTensors) -> TupleOfTensors: + return optimality_cond_vjp_fn(u[0])[0] + + else: + + def matvec(u: TupleOfTensors) -> TupleOfTensors: + return optimality_cond_vjp_fn(u)[0] + + # The solution of A^T u = v, where + # A = jacobian(optimality_fn, argnums=0) + # v = -grad_outputs. + v: TupleOfTensors = pytree.tree_map(torch.neg, grad_outputs) # type: ignore[arg-type,assignment] + u: TupleOfTensors = solve(matvec, v) # type: ignore[assignment] + + masked_optimality_fn = MaskedOptimalityFn( + optimality_fn, + solution, + output_is_tensor, + argnums, + *args, + ) + + _, optimality_vjp_fn, *_ = functorch.vjp( + masked_optimality_fn, + *masked_optimality_fn.post_filled, + ) + + output: TupleOfTensors + output = optimality_vjp_fn(u[0]) if output_is_tensor else optimality_vjp_fn(u) + + # Prepend None as the vjp for init_params. + true_output: ListOfOptionalTensors = [None] + for idx in range(masked_optimality_fn.len_args): + if idx + 1 in argnums: # plus 1 because we exclude the first argument + true_output.append(output[idx]) + else: + true_output.append(None) + + return tuple(true_output) + + +def _extract_kwargs(kwarg_keys: Sequence[str], flat_args: tuple[Any, ...]) -> tuple[Args, KwArgs]: + nargs = len(flat_args) - len(kwarg_keys) + args, kwarg_vals = flat_args[:nargs], flat_args[nargs:] + kwargs = dict(zip(kwarg_keys, kwarg_vals)) + return args, kwargs + + +def _signature_bind(signature: inspect.Signature, *args: Any, **kwargs: Any) -> tuple[Args, KwArgs]: + bound = signature.bind(*args, **kwargs) + bound.apply_defaults() + return bound.args, bound.kwargs + + +def _signature_bind_and_match( + signature: inspect.Signature, + *args: Any, + **kwargs: Any, +) -> tuple[Args, KwArgs, Callable[[Args], tuple[Args, KwArgs]]]: + # We want to bind *args and **kwargs based on the provided signature, but also to associate the + # resulting positional arguments back. To achieve this, we lift arguments to a triple: + # + # (was_kwarg, ref, value) + # + # where ref is an index position (int) if the original argument was from *args and a dictionary + # key if the original argument was from **kwargs. After binding to the inspected signature, we + # use the tags to associate the resolved positional arguments back to their args and kwargs + # source. + + args = [(False, i, v) for i, v in enumerate(args)] + kwargs = {k: (True, k, v) for (k, v) in kwargs.items()} + bound = signature.bind(*args, **kwargs) + + mapping = [(was_kwarg, ref) for was_kwarg, ref, _ in bound.args] + + def map_args_back(out_args: Args) -> tuple[Args, KwArgs]: + src_args = [None] * len(args) + src_kwargs = {} + for (was_kwarg, ref), out_arg in zip(mapping, out_args): + if was_kwarg: + src_kwargs[ref] = out_arg + else: + src_args[ref] = out_arg + return tuple(src_args), src_kwargs + + out_args = tuple(v for _, _, v in bound.args) + out_kwargs = {k: v for k, (_, _, v) in bound.kwargs.items()} + return out_args, out_kwargs, map_args_back + + +def _split_tensor_and_others( + mixed_tuple: tuple[Any, ...], +) -> tuple[pytree.PyTreeSpec, tuple[bool, ...], TupleOfTensors, tuple[Any, ...]]: + flattened: list[Any] + flattened, treespec = pytree.tree_flatten(mixed_tuple, none_is_leaf=True) # type: ignore[arg-type] + tensors: ListOfTensors = [] + non_tensors: list[Any] = [] + is_tensor_mask: list[bool] = [] + for item in flattened: + is_tensor = isinstance(item, torch.Tensor) + is_tensor_mask.append(is_tensor) + if is_tensor: + tensors.append(item.data) + else: + non_tensors.append(item) + return treespec, tuple(is_tensor_mask), tuple(tensors), tuple(non_tensors) + + +def _merge_tensor_and_others( + treespec: pytree.PyTreeSpec, + is_tensor_mask: tuple[bool, ...], + tensors: TupleOfTensors, + non_tensors: tuple[Any, ...], +) -> tuple[Any, ...]: + tensor_counter = 0 + non_tensor_counter = 0 + results = [] + for is_tensor in is_tensor_mask: + if is_tensor: + results.append(tensors[tensor_counter]) + tensor_counter += 1 + else: + results.append(non_tensors[non_tensor_counter]) + non_tensor_counter += 1 + return pytree.tree_unflatten(treespec, results) # type: ignore[return-value] + + +# pylint: disable-next=too-many-arguments,too-many-statements +def _custom_root( # noqa: C901 + solver_fn: Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]], + optimality_fn: Callable[..., TensorOrTensors], + solve: Callable[..., TensorOrTensors], + argnums: tuple[int, ...], + has_aux: bool, + reference_signature: inspect.Signature | Callable | None = None, +) -> Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]]: + solver_fn_signature = inspect.signature(solver_fn) + + if reference_signature is None: + reference_signature = inspect.signature(optimality_fn) + elif not isinstance(reference_signature, inspect.Signature): + # If is a CompositeLinearFunction, accesses subfn. + # Otherwise, assumes a Callable. + fn = getattr(reference_signature, 'subfn', reference_signature) + reference_signature = inspect.signature(fn) + + def make_custom_vjp_solver_fn( # noqa: C901 + solver_fn: Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]], + kwarg_keys: Sequence[str], + args_signs: tuple[tuple[int, int, type[tuple | list] | None], ...], + ) -> type[Function]: + # pylint: disable-next=missing-class-docstring,abstract-method + class ImplicitMetaGradient(Function): + @staticmethod + def forward( # pylint: disable=arguments-differ + ctx: Any, + *flat_args: Any, + ) -> tuple[Any, ...]: + output, aux, output_is_tensor = None, None, False + + args = [] + for offset, nargs, arg_seq_type in args_signs: + if arg_seq_type is not None: + args.append(arg_seq_type(flat_args[offset : offset + nargs])) + else: + args.append(flat_args[offset]) + args = tuple(args) + + args, kwargs = _extract_kwargs(kwarg_keys, args) + output = solver_fn(*args, **kwargs) + if has_aux: + if not (isinstance(output, tuple) and len(output) == 2): + raise RuntimeError( + f'custom_root(optimality_fn)(solver_fn)(*args): output of function ' + f'solver_fn should be a tuple: (output, aux) if has_aux is True. ' + f'Got {output}', + ) + output, aux = output + if isinstance(output, torch.Tensor): + output_is_tensor = True + output = (output,) + elif not (isinstance(output, tuple) and all(map(torch.is_tensor, output))): + raise RuntimeError( + f'custom_root(optimality_fn)(solver_fn)(*args): output of function ' + f'solver_fn should be a torch.Tensor or a tuple of torch.Tensor. ' + f'Got {output}', + ) + output = tuple(t.data for t in output) + + ( + args_treespec, + args_is_tensor_mask, + args_tensors, + args_non_tensors, + ) = _split_tensor_and_others(args) + ctx.args_treespec = args_treespec + ctx.args_is_tensor_mask = args_is_tensor_mask + ctx.args_non_tensors = args_non_tensors + + ctx.save_for_backward(*output, *args_tensors) + ctx.output_is_tensor = output_is_tensor + + return (*output, aux, output_is_tensor, type(output)) + + @staticmethod + def backward( # pylint: disable=too-many-locals + ctx: Any, + *grad_outputs: Any, + ) -> TupleOfTensors: + grad_outputs: TupleOfTensors = grad_outputs[:-3] + + saved_tensors = ctx.saved_tensors + output = saved_tensors[: len(grad_outputs)] + args_tensors = saved_tensors[len(grad_outputs) :] + args_treespec = ctx.args_treespec + args_is_tensor_mask = ctx.args_is_tensor_mask + args_non_tensors = ctx.args_non_tensors + args = _merge_tensor_and_others( + args_treespec, + args_is_tensor_mask, + args_tensors, + args_non_tensors, + ) + + args, kwargs = _extract_kwargs(kwarg_keys, args) + + bound_args, bound_kwargs, map_args_back = _signature_bind_and_match( + reference_signature, # type: ignore[arg-type] + *args, + **kwargs, + ) + if bound_kwargs: + raise TypeError( + f'keyword arguments to solver_fn could not be resolved to positional ' + f'arguments based on the signature {reference_signature}. This can ' + f'happen under custom_root if optimality_fn takes catch-all **kwargs, or ' + f'under custom_fixed_point if fixed_point_fn takes catch-all **kwargs, ' + f'both of which are currently unsupported.', + ) + + # Compute VJPs w.r.t. args. + vjps = _root_vjp( + optimality_fn=optimality_fn, + solution=output, + args=bound_args[1:], + grad_outputs=grad_outputs, + output_is_tensor=ctx.output_is_tensor, + argnums=argnums, + solve=solve, + ) + + args_vjps, kwargs_vjps = map_args_back(vjps) + ordered_vjps = tuple(args_vjps) + tuple(kwargs_vjps[k] for k in kwargs) + true_vjps = [] + for (_, _, arg_seq_type), vjp in zip(args_signs, ordered_vjps): + if arg_seq_type is not None: + true_vjps.extend(vjp) + else: + true_vjps.append(vjp) + return tuple(true_vjps) + + return ImplicitMetaGradient + + @functools.wraps(solver_fn) + def wrapped_solver_fn( + *args: Any, + **kwargs: Any, + ) -> TensorOrTensors | tuple[TensorOrTensors, Any]: + args, kwargs = _signature_bind(solver_fn_signature, *args, **kwargs) + keys, vals = list(kwargs.keys()), list(kwargs.values()) + + args_signs: list[tuple[int, int, type[tuple | list] | None]] = [] + flat_args: list[Any] = [] + args_offset = 0 + for idx, arg in enumerate(args): + if idx in argnums: + if isinstance(arg, torch.Tensor): + args_signs.append((args_offset, 1, None)) # start position, None + flat_args.append(arg) + args_offset += 1 + elif isinstance(arg, (tuple, list)) and all(map(torch.is_tensor, arg)): + nargs = len(arg) + args_signs.append( + (args_offset, nargs, type(arg)), # start position, sequence type + ) + flat_args.extend(arg) + args_offset += nargs + else: + raise RuntimeError( + 'custom_root(optimality_fn)(solver_fn)(*args): argument of function ' + 'solver_fn specified with `argnums` should be a torch.Tensor or a tuple of ' + 'torch.Tensor', + ) + else: + args_signs.append((args_offset, 1, None)) # start position, None + flat_args.append(arg) + args_offset += 1 + + args_signs = tuple(args_signs) + flat_args = tuple(flat_args) + + result = make_custom_vjp_solver_fn(solver_fn, keys, args_signs).apply(*flat_args, *vals) + *output, aux, output_is_tensor, output_type = result + output = output[0] if output_is_tensor else output_type(output) + if has_aux: + return output, aux + return output + + return wrapped_solver_fn + + +def custom_root( + optimality_fn: Callable[..., TensorOrTensors], + argnums: int | tuple[int, ...], + has_aux: bool = False, + solve: Callable[..., TensorOrTensors] | None = None, +) -> Callable[ + [Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]]], + Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]], +]: + """Return a decorator for adding implicit differentiation to a root solver. + + This wrapper should be used as a decorator: + + .. code-block:: python + + def optimality_fn(optimal_params, ...): + ... + return residual + + @custom_root(optimality_fn, argnums=argnums) + def solver_fn(params, arg1, arg2, ...): + ... + return optimal_params + + optimal_params = solver_fn(init_params, ...) + + The first argument to ``optimality_fn`` and ``solver_fn`` is preserved as the parameter input. + The ``argnums`` argument refers to the indices of the variables in ``solver_fn``'s signature. + For example, setting ``argnums=(1, 2)`` will compute the gradient of ``optimal_params`` with + respect to ``arg1`` and ``arg2`` in the above snippet. Note that, except the first argument, the + keyword arguments of the ``optimality_fn`` should be a subset of the ones of ``solver_fn``. + **In best practice, the ``optimality_fn`` should have the same signature as ``solver_fn``.** + + Args: + optimality_fn (callable): An equation function, ``optimality_fn(params, *args)``. The + invariant is ``optimality_fn(solution, *args) == 0`` at the solution / root of + ``solution``. + argnums (int or tuple of int): Specifies arguments to compute gradients with respect to. The + ``argnums`` can be an integer or a tuple of integers, which respect to the zero-based + indices of the arguments of the ``solver_fn(params, *args)`` function. The argument + ``params`` is included for the counting, while it is indexed as ``argnums=0``. + has_aux (bool, optional): Whether the decorated solver function returns auxiliary data. + (default: :data:`False`) + solve (callable, optional): A linear solver of the form ``solve(matvec, b)``. + (default: :func:`linear_solve.solve_normal_cg`) + + Returns: + A solver function decorator, i.e., ``custom_root(optimality_fn)(solver_fn)``. + """ + if isinstance(argnums, int): + assert argnums != 0 + argnums = (argnums,) + else: + assert 0 not in argnums + + if solve is None: + solve = linear_solve.solve_normal_cg() + + return functools.partial( + _custom_root, + optimality_fn=optimality_fn, + solve=solve, + argnums=argnums, + has_aux=has_aux, + ) diff --git a/torchopt/diff/implicit/nn/__init__.py b/torchopt/diff/implicit/nn/__init__.py new file mode 100644 index 00000000..e91ef8ed --- /dev/null +++ b/torchopt/diff/implicit/nn/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The base class for differentiable implicit meta-gradient models.""" + +import torchopt.nn.module # preload to resolve circular references +from torchopt.diff.implicit.nn.module import ImplicitMetaGradientModule + + +__all__ = ['ImplicitMetaGradientModule'] + +del torchopt diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py new file mode 100644 index 00000000..6b214cb8 --- /dev/null +++ b/torchopt/diff/implicit/nn/module.py @@ -0,0 +1,297 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The base class for differentiable implicit meta-gradient models.""" + +# pylint: disable=redefined-builtin + +from __future__ import annotations + +import abc +import functools +import inspect +import itertools +from typing import TYPE_CHECKING, Any, Iterable + +import functorch + +from torchopt.diff.implicit.decorator import custom_root +from torchopt.nn.module import MetaGradientModule +from torchopt.nn.stateless import reparametrize, swap_state + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import LinearSolver, TupleOfTensors + + +__all__ = ['ImplicitMetaGradientModule'] + + +def _stateless_objective_fn( + flat_params: TupleOfTensors, + flat_meta_params: TupleOfTensors, + params_names: Iterable[str], + meta_params_names: Iterable[str], + self: ImplicitMetaGradientModule, + /, + *input: Any, + **kwargs: Any, +) -> torch.Tensor: + with reparametrize( + self, + itertools.chain( + zip(params_names, flat_params), + zip(meta_params_names, flat_meta_params), + ), + ): + return self.objective(*input, **kwargs) + + +def _stateless_optimality_fn( + flat_params: TupleOfTensors, + flat_meta_params: TupleOfTensors, + params_names: Iterable[str], + meta_params_names: Iterable[str], + self: ImplicitMetaGradientModule, + /, + *input: Any, + **kwargs: Any, +) -> TupleOfTensors: + with reparametrize( + self, + itertools.chain( + zip(params_names, flat_params), + zip(meta_params_names, flat_meta_params), + ), + ): + return self.optimality(*input, **kwargs) + + +def make_optimality_from_objective( + cls: type[ImplicitMetaGradientModule], +) -> type[ImplicitMetaGradientModule]: + """Derive the optimality function of the objective function.""" + static_super_objective = inspect.getattr_static(ImplicitMetaGradientModule, 'objective') + static_cls_objective = inspect.getattr_static(cls, 'objective', static_super_objective) + if static_cls_objective is static_super_objective: + raise TypeError('The objective function is not defined.') + + def optimality(self: ImplicitMetaGradientModule, *input: Any, **kwargs: Any) -> TupleOfTensors: + named_params = tuple(self.named_parameters()) + named_meta_params = tuple(self.named_meta_parameters()) + if len(named_params) == 0: + raise RuntimeError('The module has no parameters.') + if len(named_meta_params) == 0: + raise RuntimeError('The module has no meta-parameters.') + params_names, flat_params = tuple(zip(*named_params)) + meta_params_names, flat_meta_params = tuple(zip(*named_meta_params)) + + objective_grad_fn = functorch.grad(_stateless_objective_fn, argnums=0) + return objective_grad_fn( + flat_params, + flat_meta_params, + params_names, + meta_params_names, + self, + *input, + **kwargs, + ) + + cls.optimality = optimality # type: ignore[method-assign] + return cls + + +def enable_implicit_gradients( + cls: type[ImplicitMetaGradientModule], +) -> type[ImplicitMetaGradientModule]: + """Enable implicit gradients for the :func:`solve` method.""" + cls_solve = cls.solve + if getattr(cls_solve, '__implicit_gradients_enabled__', False): + raise TypeError('Implicit gradients are already enabled for the `solve` method.') + + solve_kwargs = {'solve': cls.linear_solve} if cls.linear_solve is not None else {} + + @custom_root(_stateless_optimality_fn, argnums=1, has_aux=True, **solve_kwargs) + def stateless_solver_fn( + # pylint: disable=unused-argument + flat_params: TupleOfTensors, + flat_meta_params: TupleOfTensors, + params_names: Iterable[str], + meta_params_names: Iterable[str], + # pylint: enable=unused-argument + self: ImplicitMetaGradientModule, + /, + *input: Any, + **kwargs: Any, + ) -> tuple[TupleOfTensors, Any]: + """Solve the optimization problem.""" + output = cls_solve(self, *input, **kwargs) + flat_optimal_params = tuple(p.detach_() for p in self.parameters()) + return flat_optimal_params, output + + @functools.wraps(cls_solve) + def wrapped(self: ImplicitMetaGradientModule, *input: Any, **kwargs: Any) -> Any: + """Solve the optimization problem.""" + named_params = tuple(self.named_parameters()) + named_meta_params = tuple(self.named_meta_parameters()) + if len(named_params) == 0: + raise RuntimeError('The module has no parameters.') + if len(named_meta_params) == 0: + raise RuntimeError('The module has no meta-parameters.') + params_names, flat_params = tuple(zip(*named_params)) + meta_params_names, flat_meta_params = tuple(zip(*named_meta_params)) + + flat_optimal_params, output = stateless_solver_fn( + flat_params, + flat_meta_params, + params_names, + meta_params_names, + self, + *input, + **kwargs, + ) + swap_state(self, zip(params_names, flat_optimal_params)) + return output + + wrapped.__implicit_gradients_enabled__ = True # type: ignore[attr-defined] + cls.solve = wrapped # type: ignore[method-assign] + return cls + + +class ImplicitMetaGradientModule(MetaGradientModule, metaclass=abc.ABCMeta): + """The base class for differentiable implicit meta-gradient models.""" + + _custom_optimality: bool + _custom_objective: bool + linear_solve: LinearSolver | None + + def __init_subclass__(cls, linear_solve: LinearSolver | None = None) -> None: + """Validate and initialize the subclass.""" + super().__init_subclass__() + cls.linear_solve = linear_solve + + static_super_optimality = inspect.getattr_static(ImplicitMetaGradientModule, 'optimality') + static_super_objective = inspect.getattr_static(ImplicitMetaGradientModule, 'objective') + static_cls_optimality = inspect.getattr_static(cls, 'optimality') + static_cls_objective = inspect.getattr_static(cls, 'objective') + cls._custom_optimality = static_cls_optimality is not static_super_optimality + cls._custom_objective = static_cls_objective is not static_super_objective + + if cls._custom_optimality: + if isinstance(static_cls_optimality, staticmethod): + raise TypeError('method optimality() must not be a staticmethod.') + if isinstance(static_cls_optimality, classmethod): + raise TypeError('method optimality() must not be a classmethod.') + if not callable(static_cls_optimality): + raise TypeError('method optimality() must be callable.') + elif not cls._custom_objective: + raise TypeError( + 'ImplicitMetaGradientModule requires either an optimality() method or an objective() method', + ) + else: + if isinstance(static_cls_objective, staticmethod): + raise TypeError('method objective() must not be a staticmethod.') + if isinstance(static_cls_objective, classmethod): + raise TypeError('method objective() must not be a classmethod.') + if not callable(static_cls_objective): + raise TypeError('method objective() must be callable.') + + make_optimality_from_objective(cls) + + enable_implicit_gradients(cls) + + @abc.abstractmethod + def solve(self, *input: Any, **kwargs: Any) -> Any: + """Solve the inner optimization problem. + + .. warning:: + For gradient-based optimization methods, the parameter inputs should be explicitly + specified in the :func:`torch.autograd.backward` function as argument ``inputs``. + Otherwise, if not provided, the gradient is accumulated into all the leaf Tensors + (including the meta-parameters) that were used to compute the objective output. + Alternatively, please use :func:`torch.autograd.grad` instead. + + Examples: + .. code-block:: python + + def solve(self, batch, labels): + parameters = tuple(self.parameters()) + optimizer = torch.optim.Adam(parameters, lr=1e-3) + with torch.enable_grad(): + for _ in range(100): + loss = self.objective(batch, labels) + optimizer.zero_grad() + # Only update the `.grad` attribute for parameters + # and leave the meta-parameters unchanged + loss.backward(inputs=parameters) + optimizer.step() + return self + """ + raise NotImplementedError # update parameters + + def optimality(self, *input: Any, **kwargs: Any) -> TupleOfTensors: + r"""Compute the optimality residual. + + This method stands for the optimality residual to the optimal parameters after solving the + inner optimization problem (:meth:`solve`), i.e.: + + .. code-block:: python + + module.solve(*input, **kwargs) + module.optimality(*input, **kwargs) # -> 0 + + 1. For gradient-based optimization, the :meth:`optimality` function is the KKT condition, + usually it is the gradients of the :meth:`objective` function with respect to the module + parameters (not the meta-parameters). If this method is not implemented, it will be + automatically derived from the gradient of the :meth:`objective` function. + + .. math:: + + \text{optimality residual} = \nabla_{\boldsymbol{x}} f (\boldsymbol{x}, \boldsymbol{\theta}) \to \boldsymbol{0} + + where :math:`\boldsymbol{x}` is the joint vector of the module parameters and + :math:`\boldsymbol{\theta}` is the joint vector of the meta-parameters. + + References: + - Karush-Kuhn-Tucker (KKT) conditions: https://en.wikipedia.org/wiki/Karush-Kuhn-Tucker_conditions + + 2. For fixed point iteration, the :meth:`optimality` function can be the residual of the + parameters between iterations, i.e.: + + .. math:: + + \text{optimality residual} = f (\boldsymbol{x}, \boldsymbol{\theta}) - \boldsymbol{x} \to \boldsymbol{0} + + where :math:`\boldsymbol{x}` is the joint vector of the module parameters and + :math:`\boldsymbol{\theta}` is the joint vector of the meta-parameters. + + Returns: + A tuple of tensors, the optimality residual to the optimal parameters after solving the + inner optimization problem. The returned tensors should correspond to the outputs of + `tuple(self.parameters())`. + """ # pylint: disable=line-too-long + raise NotImplementedError + + def objective(self, *input: Any, **kwargs: Any) -> torch.Tensor: + """Compute the objective function value. + + This method is used to calculate the :meth:`optimality` if it is not implemented. + Otherwise, this method is optional. + + Returns: + A scalar tensor (``dim=0``), the objective function value. + """ + raise NotImplementedError diff --git a/torchopt/diff/zero_order/__init__.py b/torchopt/diff/zero_order/__init__.py new file mode 100644 index 00000000..4369f4e5 --- /dev/null +++ b/torchopt/diff/zero_order/__init__.py @@ -0,0 +1,43 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Zero-Order Gradient.""" + +import sys as _sys +from types import ModuleType as _ModuleType +from typing import Any, Callable + +import torch + +from torchopt.diff.zero_order import nn +from torchopt.diff.zero_order.decorator import zero_order +from torchopt.diff.zero_order.nn import ZeroOrderGradientModule + + +__all__ = ['ZeroOrderGradientModule', 'zero_order'] + + +class _CallableModule(_ModuleType): # pylint: disable=too-few-public-methods + def __call__( + self, + *args: Any, + **kwargs: Any, + ) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]: + return self.zero_order(*args, **kwargs) + + +# Replace entry in sys.modules for this module with an instance of _CallableModule +_modself = _sys.modules[__name__] +_modself.__class__ = _CallableModule +del _sys, _ModuleType, _modself, _CallableModule diff --git a/torchopt/diff/zero_order/decorator.py b/torchopt/diff/zero_order/decorator.py new file mode 100644 index 00000000..e498b43c --- /dev/null +++ b/torchopt/diff/zero_order/decorator.py @@ -0,0 +1,417 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Zero-Order Gradient Estimation.""" + +from __future__ import annotations + +import functools +import itertools +from typing import Any, Callable, Literal, Sequence +from typing_extensions import TypeAlias # Python 3.10+ + +import torch +from torch.autograd import Function + +from torchopt import pytree +from torchopt.typing import ListOfTensors, Numeric, Samplable, SampleFunc, TupleOfOptionalTensors + + +class WrappedSamplable(Samplable): # pylint: disable=too-few-public-methods + """A wrapper that wraps a sample function to a :class:`Samplable` object.""" + + def __init__(self, sample_fn: SampleFunc) -> None: + """Wrap a sample function to make it a :class:`Samplable` object.""" + self.sample_fn = sample_fn + + def sample( + self, + sample_shape: torch.Size = torch.Size(), # noqa: B008 + ) -> torch.Tensor | Sequence[Numeric]: + # pylint: disable-next=line-too-long + """Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.""" + return self.sample_fn(sample_shape) + + +def _zero_order_naive( # noqa: C901 # pylint: disable=too-many-statements + fn: Callable[..., torch.Tensor], + distribution: Samplable, + argnums: tuple[int, ...], + num_samples: int, + sigma: float, +) -> Callable[..., torch.Tensor]: + @functools.wraps(fn) + def apply(*args: Any) -> torch.Tensor: # noqa: C901 # pylint: disable=too-many-statements + diff_params = [args[argnum] for argnum in argnums] + flat_diff_params: list[Any] + flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] + + class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-method + @staticmethod + def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: + flat_diff_params = args[:-1] + origin_args = list(args[-1][0]) + flat_args: list[Any] + flat_args, args_treespec = pytree.tree_flatten(origin_args, none_is_leaf=True) # type: ignore[arg-type] + ctx.args_treespec = args_treespec + + is_tensor_mask = [] + tensors = [] + non_tensors = [] + for origin_arg in flat_args: + is_tensor = isinstance(origin_arg, torch.Tensor) + is_tensor_mask.append(is_tensor) + if is_tensor: + tensors.append(origin_arg) + else: + non_tensors.append(origin_arg) + + ctx.non_tensors = non_tensors + ctx.is_tensor_mask = is_tensor_mask + + output = fn(*origin_args) + if not isinstance(output, torch.Tensor): + raise TypeError('`output` must be a tensor.') + if output.ndim != 0: + raise RuntimeError('`output` must be a scalar tensor.') + ctx.save_for_backward(*flat_diff_params, *tensors) + ctx.len_args = len(args) + ctx.len_params = len(flat_diff_params) + return output + + @staticmethod + def backward( # pylint: disable=too-many-locals + ctx: Any, + *grad_outputs: Any, + ) -> TupleOfOptionalTensors: + saved_tensors = ctx.saved_tensors + flat_diff_params = saved_tensors[: ctx.len_params] + tensors = saved_tensors[ctx.len_params :] + non_tensors = ctx.non_tensors + + flat_args = [] + tensors_counter = 0 + non_tensors_counter = 0 + for is_tensor in ctx.is_tensor_mask: + if is_tensor: + flat_args.append(tensors[tensors_counter]) + tensors_counter += 1 + else: + flat_args.append(non_tensors[non_tensors_counter]) + non_tensors_counter += 1 + + args: list[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] + + def add_perturbation( + tensor: torch.Tensor, + noise: torch.Tensor | Numeric, + ) -> torch.Tensor: + return tensor.add(noise, alpha=sigma) + + param_grads: ListOfTensors = [0.0 for _ in range(len(flat_diff_params))] # type: ignore[misc] + + for _ in range(num_samples): + noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params] + flat_noisy_params = list( + itertools.starmap(add_perturbation, zip(flat_diff_params, noises)), + ) + noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment] + diff_params_treespec, + flat_noisy_params, + ) + + for argnum, noisy_param in zip(argnums, noisy_params): + args[argnum] = noisy_param + + output = fn(*args) + weighted_grad = grad_outputs[0].mul(output).mul_(1 / sigma) + + for i, noise in enumerate(noises): + param_grads[i] += weighted_grad * noise + + for i in range(len(flat_diff_params)): + param_grads[i] /= num_samples + + return tuple(param_grads + [None] * (ctx.len_args - len(flat_diff_params))) + + return ZeroOrder.apply(*flat_diff_params, (args,)) + + return apply + + +def _zero_order_forward( # noqa: C901 # pylint: disable=too-many-statements + fn: Callable[..., torch.Tensor], + distribution: Samplable, + argnums: tuple[int, ...], + num_samples: int, + sigma: float, +) -> Callable[..., torch.Tensor]: + @functools.wraps(fn) + def apply(*args: Any) -> torch.Tensor: # noqa: C901 # pylint: disable=too-many-statements + diff_params = [args[argnum] for argnum in argnums] + flat_diff_params: list[Any] + flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] + + class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-method + @staticmethod + def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: + flat_diff_params = args[:-1] + origin_args = list(args[-1][0]) + flat_args: list[Any] + flat_args, args_treespec = pytree.tree_flatten(origin_args, none_is_leaf=True) # type: ignore[arg-type] + ctx.args_treespec = args_treespec + + is_tensor_mask = [] + tensors = [] + non_tensors = [] + for origin_arg in flat_args: + is_tensor = isinstance(origin_arg, torch.Tensor) + is_tensor_mask.append(is_tensor) + if is_tensor: + tensors.append(origin_arg) + else: + non_tensors.append(origin_arg) + + ctx.non_tensors = non_tensors + ctx.is_tensor_mask = is_tensor_mask + + output = fn(*origin_args) + if not isinstance(output, torch.Tensor): + raise TypeError('`output` must be a tensor.') + if output.ndim != 0: + raise RuntimeError('`output` must be a scalar tensor.') + ctx.save_for_backward(*flat_diff_params, *tensors, output) + ctx.len_args = len(args) + ctx.len_params = len(flat_diff_params) + return output + + @staticmethod + def backward( # pylint: disable=too-many-locals + ctx: Any, + *grad_outputs: Any, + ) -> TupleOfOptionalTensors: + saved_tensors = ctx.saved_tensors + flat_diff_params = saved_tensors[: ctx.len_params] + tensors = saved_tensors[ctx.len_params : -1] + output = saved_tensors[-1] + non_tensors = ctx.non_tensors + + flat_args = [] + tensors_counter = 0 + non_tensors_counter = 0 + for is_tensor in ctx.is_tensor_mask: + if is_tensor: + flat_args.append(tensors[tensors_counter]) + tensors_counter += 1 + else: + flat_args.append(non_tensors[non_tensors_counter]) + non_tensors_counter += 1 + + args: list[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] + + def add_perturbation(tensor: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: + return tensor.add(noise, alpha=sigma) + + param_grads: ListOfTensors = [0.0 for _ in range(len(flat_diff_params))] # type: ignore[misc] + + for _ in range(num_samples): + noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params] + flat_noisy_params = list( + itertools.starmap(add_perturbation, zip(flat_diff_params, noises)), + ) + noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment] + diff_params_treespec, + flat_noisy_params, + ) + + for argnum, noisy_param in zip(argnums, noisy_params): + args[argnum] = noisy_param + + noisy_output = fn(*args) + output = noisy_output - output + weighted_grad = grad_outputs[0].mul(output).div_(1.0 / sigma) + + for i, noise in enumerate(noises): + param_grads[i] += weighted_grad * noise + + for i in range(len(flat_diff_params)): + param_grads[i] /= num_samples + + return tuple(param_grads + [None] * (ctx.len_args - len(flat_diff_params))) + + return ZeroOrder.apply(*flat_diff_params, (args,)) + + return apply + + +def _zero_order_antithetic( # noqa: C901 # pylint: disable=too-many-statements + fn: Callable[..., torch.Tensor], + distribution: Samplable, + argnums: tuple[int, ...], + num_samples: int, + sigma: float, +) -> Callable[..., torch.Tensor]: + @functools.wraps(fn) + def apply(*args: Any) -> torch.Tensor: # noqa: C901 # pylint: disable=too-many-statements + diff_params = [args[argnum] for argnum in argnums] + flat_diff_params: list[Any] + flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] + + class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-method + @staticmethod + def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: + flat_diff_params = args[:-1] + origin_args = list(args[-1][0]) + flat_args: list[Any] + flat_args, args_treespec = pytree.tree_flatten(origin_args, none_is_leaf=True) # type: ignore[arg-type] + ctx.args_treespec = args_treespec + + is_tensor_mask = [] + tensors = [] + non_tensors = [] + for origin_arg in flat_args: + is_tensor = isinstance(origin_arg, torch.Tensor) + is_tensor_mask.append(is_tensor) + if is_tensor: + tensors.append(origin_arg) + else: + non_tensors.append(origin_arg) + + ctx.non_tensors = non_tensors + ctx.is_tensor_mask = is_tensor_mask + + output = fn(*origin_args) + if not isinstance(output, torch.Tensor): + raise TypeError('`output` must be a tensor.') + if output.ndim != 0: + raise RuntimeError('`output` must be a scalar tensor.') + ctx.save_for_backward(*flat_diff_params, *tensors) + ctx.len_args = len(args) + ctx.len_params = len(flat_diff_params) + return output + + @staticmethod + def backward( # pylint: disable=too-many-locals + ctx: Any, + *grad_outputs: Any, + ) -> TupleOfOptionalTensors: + saved_tensors = ctx.saved_tensors + flat_diff_params = saved_tensors[: ctx.len_params] + tensors = saved_tensors[ctx.len_params :] + non_tensors = ctx.non_tensors + + flat_args = [] + tensors_counter = 0 + non_tensors_counter = 0 + for is_tensor in ctx.is_tensor_mask: + if is_tensor: + flat_args.append(tensors[tensors_counter]) + tensors_counter += 1 + else: + flat_args.append(non_tensors[non_tensors_counter]) + non_tensors_counter += 1 + + args: list[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] + + param_grads: ListOfTensors = [0.0 for _ in range(len(flat_diff_params))] # type: ignore[misc] + + def get_output( + add_perturbation_fn: Callable, + noises: Sequence[torch.Tensor | Numeric], + ) -> torch.Tensor: + flat_noisy_params = [ + add_perturbation_fn(t, n, alpha=sigma) + for t, n in zip(flat_diff_params, noises) + ] + noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment] + diff_params_treespec, + flat_noisy_params, + ) + + for argnum, noisy_param in zip(argnums, noisy_params): + args[argnum] = noisy_param + + return fn(*args) + + for _ in range(num_samples): + noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params] + output = get_output(torch.add, noises) - get_output(torch.sub, noises) # type: ignore[arg-type] + weighted_grad = grad_outputs[0].mul(output).mul_(0.5 / sigma) + + for i, noise in enumerate(noises): + param_grads[i] += weighted_grad * noise + + for i in range(len(flat_diff_params)): + param_grads[i] /= num_samples + + return tuple(param_grads + [None] * (ctx.len_args - len(flat_diff_params))) + + return ZeroOrder.apply(*flat_diff_params, (args,)) + + return apply + + +Method: TypeAlias = Literal['naive', 'forward', 'antithetic'] + + +def zero_order( + distribution: SampleFunc | Samplable, + method: Method = 'naive', + argnums: int | tuple[int, ...] = (0,), + num_samples: int = 1, + sigma: float = 1.0, +) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]: + """Return a decorator for applying zero-order differentiation. + + Args: + distribution (callable or Samplable): A samplable object that has method + ``samplable.sample(sample_shape)`` or a function that takes the shape as input and + returns a shaped batch of samples. This is used to sample perturbations from the given + distribution. The distribution should be sphere symmetric. + method (str, optional): The algorithm to use. The currently supported algorithms are + :const:`'naive'`, :const:`'forward'`, and :const:`'antithetic'`. + (default: :const:`'naive'`) + argnums (int or tuple of int, optional): Specifies arguments to compute gradients with + respect to. (default: :const:`0`) + num_samples (int, optional): The number of sample to get the averaged estimated gradient. + (default: :const:`1`) + sigma (float, optional): The standard deviation of the perturbation. + (default: :const:`1.0`) + + Returns: + A function decorator that enables zero-order gradient estimation. + """ + assert method in ('naive', 'forward', 'antithetic') + if method == 'naive': + method_fn = _zero_order_naive + elif method == 'forward': + method_fn = _zero_order_forward + else: + method_fn = _zero_order_antithetic + + if isinstance(argnums, int): + argnums = (argnums,) + + if not isinstance(distribution, Samplable): + if not callable(distribution): + raise TypeError('`distribution` must be a callable or an instance of `Samplable`.') + distribution = WrappedSamplable(distribution) + + return functools.partial( + method_fn, + distribution=distribution, + argnums=argnums, + num_samples=num_samples, + sigma=sigma, + ) diff --git a/torchopt/diff/zero_order/nn/__init__.py b/torchopt/diff/zero_order/nn/__init__.py new file mode 100644 index 00000000..f2753b27 --- /dev/null +++ b/torchopt/diff/zero_order/nn/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The base class for zero-order gradient models.""" + +import torchopt.nn.module # preload to resolve circular references +from torchopt.diff.zero_order.nn.module import ZeroOrderGradientModule + + +__all__ = ['ZeroOrderGradientModule'] + +del torchopt diff --git a/torchopt/diff/zero_order/nn/module.py b/torchopt/diff/zero_order/nn/module.py new file mode 100644 index 00000000..eeddabeb --- /dev/null +++ b/torchopt/diff/zero_order/nn/module.py @@ -0,0 +1,106 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The base class for zero-order gradient models.""" + +# pylint: disable=redefined-builtin + +from __future__ import annotations + +import abc +import functools +from typing import TYPE_CHECKING, Any, Sequence + +import torch +import torch.nn as nn + +from torchopt.diff.zero_order.decorator import Method, Samplable, zero_order +from torchopt.nn.stateless import reparametrize + + +if TYPE_CHECKING: + from torchopt.typing import Numeric, TupleOfTensors + + +__all__ = ['ZeroOrderGradientModule'] + + +def enable_zero_order_gradients( + cls: type[ZeroOrderGradientModule], + method: Method = 'naive', + num_samples: int = 1, + sigma: float = 1.0, +) -> type[ZeroOrderGradientModule]: + """Enable zero-order gradient estimation for the :func:`forward` method.""" + cls_forward = cls.forward + if getattr(cls_forward, '__zero_order_gradients_enabled__', False): + raise TypeError( + 'Zero-order gradient estimation is already enabled for the `forward` method.', + ) + + @functools.wraps(cls_forward) + def wrapped(self: ZeroOrderGradientModule, *input: Any, **kwargs: Any) -> torch.Tensor: + """Do the forward pass calculation.""" + named_params = tuple(self.named_parameters()) + if len(named_params) == 0: + raise RuntimeError('The module has no parameters.') + params_names, flat_params = tuple(zip(*named_params)) + + @zero_order(self.sample, argnums=0, method=method, num_samples=num_samples, sigma=sigma) + def forward_fn( + __flat_params: TupleOfTensors, + *input: Any, + **kwargs: Any, + ) -> torch.Tensor: + with reparametrize(self, zip(params_names, __flat_params)): + return cls_forward(self, *input, **kwargs) + + return forward_fn(flat_params, *input, **kwargs) + + wrapped.__zero_order_gradients_enabled__ = True # type: ignore[attr-defined] + cls.forward = wrapped # type: ignore[method-assign] + return cls + + +class ZeroOrderGradientModule(nn.Module, Samplable): + """The base class for zero-order gradient models.""" + + def __init_subclass__( # pylint: disable=arguments-differ + cls, + method: Method = 'naive', + num_samples: int = 1, + sigma: float = 1.0, + ) -> None: + """Validate and initialize the subclass.""" + super().__init_subclass__() + enable_zero_order_gradients( + cls, + method=method, + num_samples=num_samples, + sigma=sigma, + ) + + @abc.abstractmethod + def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: + """Do the forward pass of the model.""" + raise NotImplementedError + + @abc.abstractmethod + def sample( + self, + sample_shape: torch.Size = torch.Size(), # noqa: B008 # pylint: disable=unused-argument + ) -> torch.Tensor | Sequence[Numeric]: + # pylint: disable-next=line-too-long + """Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.""" + raise NotImplementedError diff --git a/torchopt/distributed/__init__.py b/torchopt/distributed/__init__.py new file mode 100644 index 00000000..31f1283b --- /dev/null +++ b/torchopt/distributed/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Distributed utilities.""" + +import torch.distributed as dist +import torch.distributed.rpc as rpc + +from torchopt.distributed import api, autograd, world +from torchopt.distributed.api import * # noqa: F403 +from torchopt.distributed.world import * # noqa: F403 + + +__all__ = ['is_available', *api.__all__, *world.__all__] + + +def is_available() -> bool: + """Check if the distributed module is available.""" + return dist.is_available() and rpc.is_available() and autograd.is_available() diff --git a/torchopt/distributed/api.py b/torchopt/distributed/api.py new file mode 100644 index 00000000..97be682f --- /dev/null +++ b/torchopt/distributed/api.py @@ -0,0 +1,482 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Distributed APIs.""" + +from __future__ import annotations + +import functools +import sys +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + Tuple, + TypeVar, + Union, + cast, +) + +import torch +import torch.distributed.rpc as rpc + +from torchopt import pytree +from torchopt.distributed.world import get_worker_id, get_world_rank, get_world_size +from torchopt.typing import Future + + +__all__ = [ + 'TensorDimensionPartitioner', + 'batch_partitioner', + 'dim_partitioner', + 'mean_reducer', + 'parallelize', + 'parallelize_async', + 'parallelize_sync', + 'remote_async_call', + 'remote_sync_call', + 'sum_reducer', +] + + +UNSET_RPC_TIMEOUT = rpc.api.UNSET_RPC_TIMEOUT if rpc.is_available() else -1.0 + + +T = TypeVar('T') +U = TypeVar('U') +Args = Tuple[Any, ...] +KwArgs = Dict[str, Any] +PartitionFunction = Callable[..., Sequence[Tuple[int, Optional[Args], Optional[KwArgs]]]] +Partitioner = Union[int, str, PartitionFunction] + + +class TensorDimensionPartitioner: + """Partitioner class that partitions a batch of inputs along a given dimension. + + All tensors in the ``args`` and ``kwargs`` will be partitioned along the dimension ``dim``, + while the non-tensor values will be broadcasted to partitions. + + Args: + dim (int): The dimension to partition. + exclusive (bool, optional): Whether to partition the batch exclusively. (default: :data:`False`) + If :data:`True`, the batch will be partitioned into ``batch_size`` partitions, where + ``batch_size`` is the size of the batch along the given dimension. Each batch sample + will be assigned to a separate RPC call. + If :data:`False`, the batch will be partitioned into ``min(batch_size, num_workers)`` + partitions, where ``num_workers`` is the number of workers in the world. When + ``batch_size > num_workers``, there can be multiple batch samples forward in a single + RPC call. + keepdim (bool, optional): Whether to keep the partitioned dimension. (default: :data:`True`) + If :data:`True`, keep the batch dimension. If :data:`False`, use select instead of + slicing. This functionality should be used with ``exclusive=True``. + workers (sequence of int or str, or None, optional): The workers to partition the batch to. + If :data:`None`, the batch will be partitioned to all workers in the world. + (default: :data:`None`) + """ + + def __init__( + self, + dim: int, + *, + exclusive: bool = False, + keepdim: bool = False, + workers: Sequence[int | str] | None = None, + ) -> None: + """Initialize the partitioner instance.""" + if not keepdim and not exclusive: + raise ValueError('keepdim=False should be used with exclusive=True.') + + self.dim = dim + self.exclusive = exclusive + self.keepdim = keepdim + self.workers = workers + + # pylint: disable-next=too-many-branches,too-many-locals + def __call__( # noqa: C901 + self, + *args: Any, + **kwargs: Any, + ) -> list[tuple[int, Args | None, KwArgs | None]]: + """Partition the batch of inputs along the given dimension.""" + if self.workers is None: + workers = list(range(get_world_size())) + else: + workers = list(map(get_worker_id, self.workers)) + num_workers = len(workers) + + args_tree = (args, kwargs) + flat_args: list[Any] + flat_args, treespec = pytree.tree_flatten(args_tree) # type: ignore[arg-type] + + batch_size = None + for arg in flat_args: + if isinstance(arg, torch.Tensor): + if batch_size is None: + batch_size = arg.shape[self.dim] + elif batch_size != arg.shape[self.dim]: # type: ignore[unreachable] + raise ValueError( + f'Batch size mismatch on dim={self.dim}. ' + f'Expected {batch_size}, got {arg.shape[self.dim]} (shape: {arg.shape}).', + ) + + if batch_size is None: + return [(get_world_rank(), args, kwargs.copy())] + + dim_slices: list[int | slice] + batch_slices: list[tuple[int | slice | Ellipsis.__class__, ...]] # type: ignore[name-defined] + if self.exclusive: + num_replicas = batch_size + if self.keepdim: + dim_slices = [slice(i, i + 1) for i in range(num_replicas)] + else: + dim_slices = list(range(num_replicas)) + else: + if batch_size <= num_workers: + num_replicas = batch_size + dim_slices = [slice(i, i + 1) for i in range(batch_size)] # keepdim=True + else: + num_replicas = num_workers + local_size = batch_size // num_workers + local_batch_indices = [i * local_size for i in range(num_workers)] + [batch_size] + dim_slices = [ + slice(local_batch_indices[i], local_batch_indices[i + 1]) + for i in range(num_workers) + ] + + if self.dim >= 0: + batch_slices = [ + (slice(None, None),) * self.dim + (dim_slice,) for dim_slice in dim_slices + ] + elif self.dim < 0: + batch_slices = [ + ( + ..., + dim_slice, + ) + + (slice(None, None),) * (-self.dim - 1) + for dim_slice in dim_slices + ] + + flat_args_replicas: list[list[Any]] = [[] for _ in range(num_replicas)] + for arg in flat_args: + if isinstance(arg, torch.Tensor): + for i, batch_slice in enumerate(batch_slices): + flat_args_replicas[i].append(arg[batch_slice]) + else: + for i in range(num_replicas): + flat_args_replicas[i].append(arg) + + args_replicas: list[tuple[Args, KwArgs]] = [ + pytree.tree_unflatten(treespec, args_replica) # type: ignore[misc] + for args_replica in flat_args_replicas + ] + + return [ + (workers[i % num_workers], worker_args, worker_kwargs) + for i, (worker_args, worker_kwargs) in enumerate(args_replicas) + ] + + def __reduce__( + self, + ) -> tuple[ + Callable[..., TensorDimensionPartitioner], + tuple[int], + dict[str, bool | Sequence[int | str] | None], + ]: + """Return a tuple that allows the partitioner to be pickled.""" + return ( + TensorDimensionPartitioner, + (self.dim,), + {'exclusive': self.exclusive, 'keepdim': self.keepdim, 'workers': self.workers}, + ) + + +def dim_partitioner( + dim: int = 0, + *, + exclusive: bool = False, + keepdim: bool = True, + workers: Sequence[int | str] | None = None, +) -> PartitionFunction: + """Partition a batch of inputs along a given dimension. + + All tensors in the ``args`` and ``kwargs`` will be partitioned along the dimension ``dim``, + while the non-tensor values will be broadcasted to partitions. + + Args: + dim (int, optional): The dimension to partition. (default: :const:`0`) + exclusive (bool, optional): Whether to partition the batch exclusively. (default: :data:`False`) + If :data:`True`, the batch will be partitioned into ``batch_size`` partitions, where + ``batch_size`` is the size of the batch along the given dimension. Each batch sample + will be assigned to a separate RPC call. + If :data:`False`, the batch will be partitioned into ``min(batch_size, num_workers)`` + partitions, where ``num_workers`` is the number of workers in the world. When + ``batch_size > num_workers``, there can be multiple batch samples forward in a single + RPC call. + keepdim (bool, optional): Whether to keep the partitioned dimension. (default: :data:`False`) + If :data:`True`, keep the batch dimension. If :data:`False`, use select instead of + slicing. This functionality should be used with ``exclusive=True``. + workers (sequence of int or str, or None, optional): The workers to partition the batch to. + If :data:`None`, the batch will be partitioned to all workers in the world. + (default: :data:`None`) + + Returns: + A partition function. + """ + return TensorDimensionPartitioner(dim, exclusive=exclusive, keepdim=keepdim, workers=workers) + + +batch_partitioner: PartitionFunction = dim_partitioner(dim=0, keepdim=True, exclusive=False) +"""Partitioner for batch dimension. Divide and replicates the arguments to all workers along the first dimension. + +The batch will be partitioned into ``min(batch_size, num_workers)`` partitions, where +``num_workers`` is the number of workers in the world. +When ``batch_size > num_workers``, there can be multiple batch samples forward in a single RPC call. + +All tensors in the ``args`` and ``kwargs`` will be partitioned along the dimension ``dim``, +while the non-tensor values will be broadcasted to partitions. +""" +exclusive_batch_partitioner: PartitionFunction = dim_partitioner(dim=0, keepdim=True, exclusive=True) # fmt: skip +"""Partitioner for batch dimension. Divide and replicates the arguments to all workers along the first dimension. + +Each batch sample will be assigned to a separate RPC call. + +All tensors in the ``args`` and ``kwargs`` will be partitioned along the dimension ``dim``, +while the non-tensor values will be broadcasted to partitions. +""" + + +def mean_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor: + """Reduce the results by averaging them.""" + return torch.mean(torch.stack(tuple(results), dim=0), dim=0) + + +def sum_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor: + """Reduce the results by summing them.""" + return torch.sum(torch.stack(tuple(results), dim=0), dim=0) + + +# pylint: disable-next=too-many-arguments +def remote_async_call( + func: Callable[..., T], + *, + args: Args | None = None, + kwargs: KwArgs | None = None, + partitioner: Partitioner | None = None, + reducer: Callable[[Iterable[T]], U] | None = None, + timeout: float | None = UNSET_RPC_TIMEOUT, +) -> Future[list[T]] | Future[U]: + """Asynchronously do an RPC on remote workers and return the a :class:`torch.Future` instance at the current worker. + + Args: + func (callable): The function to call. + args (tuple of object or None, optional): The arguments to pass to the function. + (default: :data:`None`) + kwargs (dict[str, object] or None, optional): The keyword arguments to pass to the function. + (default: :data:`None`) + partitioner (int, str, or callable, optional): A partitioner that partitions the arguments + to multiple workers. (default: :func:`batch_partitioner`) + reducer (callable or None, optional): A reducer that reduces the results from multiple + workers. If :data:`None`, do not reduce the results. (default: :data:`None`) + timeout (float, optional): The timeout for the RPC call. + (default: :data:`rpc.api.UNSET_RPC_TIMEOUT`) + + Returns: + A :class:`torch.Future` instance for the result. The result is at the current worker. + """ + if args is None: + args = () + if kwargs is None: + kwargs = {} + if partitioner is None: + partitioner = batch_partitioner + if isinstance(partitioner, (int, str)): + partitions = [(get_worker_id(id=partitioner), args, kwargs)] + elif callable(partitioner): + partitions = partitioner(*args, **kwargs) # type: ignore[assignment] + else: + raise TypeError(f'Invalid partitioner: {partitioner!r}.') + + futures = [] + for rank, worker_args, worker_kwargs in partitions: + fut = rpc.rpc_async(rank, func, args=worker_args, kwargs=worker_kwargs, timeout=timeout) + futures.append(fut) + + future = cast( + Future[List[T]], + torch.futures.collect_all(futures).then(lambda fut: [f.wait() for f in fut.wait()]), + ) + if reducer is not None: + return cast( + Future[U], + future.then(lambda fut: reducer(fut.wait())), + ) + return future + + +# pylint: disable-next=too-many-arguments +def remote_sync_call( + func: Callable[..., T], + *, + args: Args | None = None, + kwargs: KwArgs | None = None, + partitioner: Partitioner | None = None, + reducer: Callable[[Iterable[T]], U] | None = None, + timeout: float | None = UNSET_RPC_TIMEOUT, +) -> list[T] | U: + """Do an RPC synchronously on remote workers and return the result to the current worker. + + Args: + func (callable): The function to call. + args (tuple of object or None, optional): The arguments to pass to the function. + (default: :data:`None`) + kwargs (dict[str, object] or None, optional): The keyword arguments to pass to the function. + (default: :data:`None`) + partitioner (int, str, or callable, optional): A partitioner that partitions the arguments + to multiple workers. (default: :func:`batch_partitioner`) + reducer (callable or None, optional): A reducer that reduces the results from multiple + workers. If :data:`None`, do not reduce the results. (default: :data:`None`) + timeout (float, optional): The timeout for the RPC call. + (default: :data:`rpc.api.UNSET_RPC_TIMEOUT`) + + Returns: + The result of the RPC call. The result is at the current worker. + """ + return remote_async_call( + func, + args=args, + kwargs=kwargs, + partitioner=partitioner, + timeout=timeout, + reducer=reducer, + ).wait() + + +def parallelize_async( + partitioner: Partitioner | None = None, + reducer: Callable[[Iterable[T]], U] | None = None, + timeout: float | None = UNSET_RPC_TIMEOUT, +) -> Callable[[Callable[..., T]], Callable[..., Future[list[T]] | Future[U]]]: + """Return a decorator for parallelizing a function. + + This decorator can be used to parallelize a function call across multiple workers. The + function will be called asynchronously on remote workers. The decorated function will + return a :class:`torch.Future` instance of the result. + + Args: + partitioner (int, str, or callable, optional): A partitioner that partitions the arguments + to multiple workers. (default: :func:`batch_partitioner`) + reducer (callable or None, optional): A reducer that reduces the results from multiple + workers. If :data:`None`, do not reduce the results. (default: :data:`None`) + timeout (float, optional): The timeout for the RPC call. + (default: :data:`rpc.api.UNSET_RPC_TIMEOUT`) + + Returns: + The decorator function. + """ + if partitioner is None: + partitioner = batch_partitioner + if reducer is None: + reducer = mean_reducer # type: ignore[assignment] + + def wrapper(func: Callable[..., T]) -> Callable[..., Future[list[T]] | Future[U]]: + @functools.wraps(func) + def wrapped(*args: Any, **kwargs: Any) -> Future[list[T]] | Future[U]: + return remote_async_call( + func, + args=args, + kwargs=kwargs, + partitioner=partitioner, + reducer=reducer, + timeout=timeout, + ) + + suffix = '__parallelize_async_unwrapped__' + module_name = func.__module__ + try: + name = func.__qualname__ + except AttributeError: + name = func.__name__ + else: + func.__qualname__ = f'{func.__qualname__}{suffix}' + func.__name__ = f'{func.__name__}{suffix}' + __import__(module_name, level=0) + module = sys.modules[module_name] + setattr(module, f'{name}{suffix}', func) + + return wrapped + + return wrapper + + +def parallelize( + partitioner: Partitioner | None = None, + reducer: Callable[[Iterable[T]], U] | None = None, + timeout: float | None = UNSET_RPC_TIMEOUT, +) -> Callable[[Callable[..., T]], Callable[..., list[T] | U]]: + """Return a decorator for parallelizing a function. + + This decorator can be used to parallelize a function call across multiple workers. + + Args: + partitioner (int, str, or callable, optional): A partitioner that partitions the arguments + to multiple workers. (default: :func:`batch_partitioner`) + reducer (callable or None, optional): A reducer that reduces the results from multiple + workers. If :data:`None`, do not reduce the results. (default: :data:`None`) + timeout (float, optional): The timeout for the RPC call. + (default: :data:`rpc.api.UNSET_RPC_TIMEOUT`) + + Returns: + The decorator function. + """ + if partitioner is None: + partitioner = batch_partitioner + if reducer is None: + reducer = mean_reducer # type: ignore[assignment] + + def wrapper(func: Callable[..., T]) -> Callable[..., list[T] | U]: + @functools.wraps(func) + def wrapped(*args: Any, **kwargs: Any) -> list[T] | U: + return remote_sync_call( + func, + args=args, + kwargs=kwargs, + partitioner=partitioner, + reducer=reducer, + timeout=timeout, + ) + + suffix = '__parallelize_unwrapped__' + module_name = func.__module__ + try: + name = func.__qualname__ + except AttributeError: + name = func.__name__ + else: + func.__qualname__ = f'{func.__qualname__}{suffix}' + func.__name__ = f'{func.__name__}{suffix}' + __import__(module_name, level=0) + module = sys.modules[module_name] + setattr(module, f'{name}{suffix}', func) + + return wrapped + + return wrapper + + +parallelize_sync = parallelize diff --git a/torchopt/distributed/autograd.py b/torchopt/distributed/autograd.py new file mode 100644 index 00000000..71afdb86 --- /dev/null +++ b/torchopt/distributed/autograd.py @@ -0,0 +1,137 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Distributed Autograd.""" + +from __future__ import annotations + +from threading import Lock +from typing import TYPE_CHECKING + +import torch +import torch.distributed.autograd as autograd +from torch.distributed.autograd import context + + +if TYPE_CHECKING: + from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors + + +__all__ = ['context', 'is_available'] + + +LOCK = Lock() + + +def is_available() -> bool: + """Check if distributed autograd module is available.""" + return autograd.is_available() + + +if is_available(): + # pylint: disable-next=unused-import,ungrouped-imports + from torch.distributed.autograd import DistAutogradContext, get_gradients + + def backward( + autograd_ctx_id: int, + tensors: TensorOrTensors, + retain_graph: bool = False, + inputs: TensorOrTensors | None = None, + ) -> None: + """Perform distributed backward pass for local parameters. + + Compute the sum of gradients of given tensors with respect to graph leaves. + + Args: + autograd_ctx_id (int): The autograd context id. + tensors (Tensor or sequence of Tensor): Tensors of which the derivative will be computed. + retain_graph (bool, optional): If :data:`False`, the graph used to compute the grad will + be freed. Note that in nearly all cases setting this option to :data:`True` is not + needed and often can be worked around in a much more efficient way. + (default: :data:`False`) + inputs (Tensor, sequence of Tensor, or None, optional): Inputs w.r.t. which the gradient + be will accumulated into ``.grad``. All other Tensors will be ignored. If not + provided, the gradient is accumulated into all the leaf Tensors that were used to + compute the ``tensors``. (default: :data:`None`) + """ + if inputs is not None: + if isinstance(inputs, torch.Tensor): + inputs = (inputs,) + elif len(inputs) == 0: + raise RuntimeError("'inputs' argument to backward() cannot be empty.") + else: + inputs = tuple(inputs) + if not all(t.requires_grad for t in inputs): + raise RuntimeError('One of the differentiated Tensors does not require grad') + + roots = [tensors] if isinstance(tensors, torch.Tensor) else list(tensors) + autograd.backward(autograd_ctx_id, roots=roots, retain_graph=retain_graph) + + all_local_grads = autograd.get_gradients(autograd_ctx_id) + if inputs is not None: + inputs = set(inputs) # type: ignore[assignment] + all_local_grads = {p: g for p, g in all_local_grads.items() if p in inputs} + + with LOCK: + for p, g in all_local_grads.items(): + if p.grad is not None: + p.grad = p.grad.add(g) + else: + p.grad = g + + def grad( + autograd_ctx_id: int, + outputs: TensorOrTensors, + inputs: TensorOrTensors, + retain_graph: bool = False, + allow_unused: bool = False, + ) -> TupleOfOptionalTensors: + """Compute and return the sum of gradients of outputs with respect to the inputs. + + Args: + autograd_ctx_id (int): The autograd context id. + outputs (Tensor or sequence of Tensor): Outputs of the differentiated function. + inputs (Tensor or sequence of Tensor): Inputs w.r.t. which the gradient will be returned + (and not accumulated into ``.grad``). + retain_graph (bool, optional): If :data:`False`, the graph used to compute the grad will + be freed. Note that in nearly all cases setting this option to :data:`True` is not + needed and often can be worked around in a much more efficient way. + (default: :data:`False`) + allow_unused (bool, optional): If :data:`False`, specifying inputs that were not used + when computing outputs (and therefore their grad is always zero) is an error. + (default: :data:`False`) + """ + outputs = [outputs] if isinstance(outputs, torch.Tensor) else list(outputs) + inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs) + if not all(t.requires_grad for t in inputs): + raise RuntimeError('One of the differentiated Tensors does not require grad') + + autograd.backward(autograd_ctx_id, roots=outputs, retain_graph=retain_graph) + + all_local_grads = autograd.get_gradients(autograd_ctx_id) + grads = [] + for p in inputs: + try: + grads.append(all_local_grads[p]) + except KeyError as ex: # noqa: PERF203 + if not allow_unused: + raise RuntimeError( + 'One of the differentiated Tensors appears to not have been used in the ' + 'graph. Set allow_unused=True if this is the desired behavior.', + ) from ex + grads.append(None) # type: ignore[arg-type] + + return tuple(grads) + + __all__ += ['DistAutogradContext', 'backward', 'get_gradients', 'grad'] diff --git a/torchopt/distributed/world.py b/torchopt/distributed/world.py new file mode 100644 index 00000000..610e52a0 --- /dev/null +++ b/torchopt/distributed/world.py @@ -0,0 +1,231 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for gathering information about the world.""" + +from __future__ import annotations + +import atexit +import functools +import os +from typing import Any, Callable, Iterable, NamedTuple, TypeVar + +import torch.distributed.rpc as rpc +from torch.distributed.elastic.multiprocessing.errors import record + + +__all__ = [ + 'auto_init_rpc', + 'barrier', + 'get_local_rank', + 'get_local_world_size', + 'get_rank', + 'get_worker_id', + 'get_world_info', + 'get_world_rank', + 'get_world_size', + 'not_on_rank', + 'on_rank', + 'rank_non_zero_only', + 'rank_zero_only', +] + + +def default_worker_name_format( + world_rank: int, + world_size: int, + local_rank: int, # pylint: disable=unused-argument + local_world_size: int, # pylint: disable=unused-argument +) -> str: + """Get the default worker name format.""" + return f'worker{world_rank:0{len(str(world_size))}d}' + + +F = TypeVar('F', bound=Callable[..., Any]) +_WORKER_NAME_FORMAT: Callable[..., str] = default_worker_name_format + + +class WorldInfo(NamedTuple): + """Information about the world.""" + + world_rank: int + world_size: int + local_rank: int + local_world_size: int + + @property + def rank(self) -> int: + """Get the global world rank of the current worker.""" + return self.world_rank + + @property + def worker_name(self) -> str: + """Get the name of the current worker.""" + return _WORKER_NAME_FORMAT( + world_rank=self.world_rank, + world_size=self.world_size, + local_rank=self.local_rank, + local_world_size=self.local_world_size, + ) + + +def get_world_info() -> WorldInfo: + """Get the world information.""" + world_info = getattr(get_world_info, 'world_info', None) + + if world_info is None: + world_rank = int(os.getenv('RANK', '0')) + world_size = int(os.getenv('WORLD_SIZE', '1')) + local_rank = int(os.getenv('LOCAL_RANK', '0')) + local_world_size = int(os.getenv('LOCAL_WORLD_SIZE', '1')) + world_info = WorldInfo(world_rank, world_size, local_rank, local_world_size) + # pylint: disable=line-too-long + get_world_info.world_info = get_world_info.WORLD_INFO = world_info # type: ignore[attr-defined] + get_world_info.world_rank = get_world_info.WORLD_RANK = world_rank # type: ignore[attr-defined] + get_world_info.rank = get_world_info.RANK = world_rank # type: ignore[attr-defined] + get_world_info.world_size = get_world_info.WORLD_SIZE = world_size # type: ignore[attr-defined] + get_world_info.local_rank = get_world_info.LOCAL_RANK = local_rank # type: ignore[attr-defined] + get_world_info.local_world_size = get_world_info.LOCAL_WORLD_SIZE = local_world_size # type: ignore[attr-defined] + # pylint: enable=line-too-long + + return world_info + + +def get_world_rank() -> int: + """Get the global world rank of the current worker.""" + return get_world_info().world_rank + + +get_rank = get_world_rank + + +def get_world_size() -> int: + """Get the world size.""" + return get_world_info().world_size + + +def get_local_rank() -> int: + """Get the local rank of the current worker on the current node.""" + return get_world_info().local_rank + + +def get_local_world_size() -> int: + """Get the local world size on the current node.""" + return get_world_info().local_world_size + + +get_world_info() + + +# pylint: disable-next=redefined-builtin,invalid-name +def get_worker_id(id: str | int | None = None) -> int: + """Get the worker id from the given id.""" + if isinstance(id, int): + return id + return rpc.get_worker_info(worker_name=id).id + + +def barrier(worker_names: Iterable[str] | None = None) -> None: + r"""Synchronize local and remote RPC processes. + + This will block until all local and remote RPC processes specified under worker_names + reach this method to wait for all outstanding work to complete. + + Args: + worker_names (iterable of str or None, optional): The set of workers to synchronize. + If :data:`None`, all workers. (default: :data:`None`) + """ + worker_names = {} if worker_names is None else set(worker_names) + rpc.api._barrier(worker_names) # pylint: disable=protected-access + + +def auto_init_rpc( + worker_init_fn: Callable[[], None] | None = None, + worker_name_format: Callable[..., str] = default_worker_name_format, + *, + backend: rpc.BackendType | None = None, + rpc_backend_options: rpc.RpcBackendOptions | None = None, +) -> Callable[[F], F]: + """Return a decorator to automatically initialize RPC on the decorated function.""" + global _WORKER_NAME_FORMAT # pylint: disable=global-statement + _WORKER_NAME_FORMAT = worker_name_format + + def wrapper(func: F) -> F: + world_info = get_world_info() + + @record + @functools.wraps(func) + def wrapped(*args: Any, **kwargs: Any) -> Any: + rpc.init_rpc( + name=world_info.worker_name, + rank=world_info.rank, + world_size=world_info.world_size, + backend=backend, + rpc_backend_options=rpc_backend_options, + ) + atexit.register(rpc.shutdown, graceful=True) + if worker_init_fn is not None: + barrier() + worker_init_fn() + barrier() + return func(*args, **kwargs) + + return wrapped # type: ignore[return-value] + + return wrapper + + +def __on_ranks(ranks: Iterable[int], inverse: bool = False) -> Callable[[F], F]: + ranks = frozenset(ranks) + + def wrapper(func: F) -> F: + world_rank = get_world_info().world_rank + + @functools.wraps(func) + def wrapped(*args: Any, **kwargs: Any) -> Any: + if inverse: + if world_rank not in ranks: + return func(*args, **kwargs) + elif world_rank in ranks: + return func(*args, **kwargs) + return None + + return wrapped # type: ignore[return-value] + + return wrapper + + +def on_rank(*ranks: int) -> Callable[[F], F]: + """Return a decorator to mark a function to be executed only on given ranks.""" + return __on_ranks(ranks=ranks, inverse=False) + + +def not_on_rank(*ranks: int) -> Callable[[F], F]: + """Return a decorator to mark a function to be executed only on non given ranks.""" + return __on_ranks(ranks=ranks, inverse=True) + + +def rank_all(func: F) -> F: + """Return a decorator to mark a function to be executed on all ranks.""" + return func + + +def rank_zero_only(func: F) -> F: + """Return a decorator to mark a function to be executed only on rank zero.""" + return on_rank(0)(func) + + +def rank_non_zero_only(func: F) -> F: + """Return a decorator to mark a function to be executed only on non rank zero.""" + return not_on_rank(0)(func) diff --git a/torchopt/hook.py b/torchopt/hook.py new file mode 100644 index 00000000..c11b92f6 --- /dev/null +++ b/torchopt/hook.py @@ -0,0 +1,78 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Hook utilities.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable + +from torchopt import pytree +from torchopt.base import EmptyState, GradientTransformation + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import OptState, Params, Updates + + +__all__ = ['nan_to_num_hook', 'register_hook', 'zero_nan_hook'] + + +def zero_nan_hook(g: torch.Tensor) -> torch.Tensor: + """Replace ``nan`` with zero.""" + return g.nan_to_num(nan=0.0) + + +def nan_to_num_hook( + nan: float = 0.0, + posinf: float | None = None, + neginf: float | None = None, +) -> Callable[[torch.Tensor], torch.Tensor]: + """Return a ``nan`` to num hook to replace ``nan`` / ``+inf`` / ``-inf`` with the given numbers.""" + + def hook(g: torch.Tensor) -> torch.Tensor: + """Replace ``nan`` / ``+inf`` / ``-inf`` with the given numbers.""" + return g.nan_to_num(nan=nan, posinf=posinf, neginf=neginf) + + return hook + + +def register_hook(hook: Callable[[torch.Tensor], torch.Tensor | None]) -> GradientTransformation: + """Stateless identity transformation that leaves input gradients untouched. + + This function passes through the *gradient updates* unchanged. + + Returns: + An ``(init_fn, update_fn)`` tuple. + """ + + def init_fn(params: Params) -> OptState: # pylint: disable=unused-argument + return EmptyState() + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, # pylint: disable=unused-argument + ) -> tuple[Updates, OptState]: + def f(g: torch.Tensor) -> torch.utils.hooks.RemovableHandle: + return g.register_hook(hook) + + pytree.tree_map_(f, updates) + return updates, state + + return GradientTransformation(init_fn, update_fn) diff --git a/torchopt/linalg/__init__.py b/torchopt/linalg/__init__.py new file mode 100644 index 00000000..fc499d67 --- /dev/null +++ b/torchopt/linalg/__init__.py @@ -0,0 +1,38 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/google/jax/blob/main/jax/_src/scipy/sparse/linalg.py +# ============================================================================== +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Linear algebra functions.""" + +from torchopt.linalg.cg import cg +from torchopt.linalg.ns import ns, ns_inv + + +__all__ = ['cg', 'ns', 'ns_inv'] diff --git a/torchopt/linalg/cg.py b/torchopt/linalg/cg.py new file mode 100644 index 00000000..1096a5af --- /dev/null +++ b/torchopt/linalg/cg.py @@ -0,0 +1,191 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/google/jax/blob/main/jax/_src/scipy/sparse/linalg.py +# ============================================================================== +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Conjugate Gradient iteration to solve ``Ax = b``.""" + +# pylint: disable=invalid-name + +from __future__ import annotations + +from functools import partial +from typing import TYPE_CHECKING, Callable + +import torch + +from torchopt import pytree +from torchopt.linalg.utils import cat_shapes, normalize_matvec +from torchopt.pytree import tree_vdot_real + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree + + +__all__ = ['cg'] + + +def _identity(x: TensorTree) -> TensorTree: + return x + + +# pylint: disable-next=too-many-arguments,too-many-locals +def _cg_solve( + A: Callable[[TensorTree], TensorTree], + b: TensorTree, + x0: TensorTree, + *, + maxiter: int, + rtol: float = 1e-5, + atol: float = 0.0, + M: Callable[[TensorTree], TensorTree] = _identity, +) -> TensorTree: + # https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method + + # tolerance handling uses the "non-legacy" behavior of `scipy.sparse.linalg.cg` + b2 = tree_vdot_real(b, b) + atol2 = max(rtol**2 * b2, atol**2) + + def cond_fn(value: tuple[TensorTree, TensorTree, float, TensorTree, int]) -> bool: + _, r, gamma, _, k = value + rs = gamma if M is _identity else tree_vdot_real(r, r) + return rs > atol2 and k < maxiter + + def body_fn( + value: tuple[TensorTree, TensorTree, float, TensorTree, int], + ) -> tuple[TensorTree, TensorTree, float, TensorTree, int]: + x, r, gamma, p, k = value + Ap = A(p) + alpha = gamma / tree_vdot_real(p, Ap) + x_ = pytree.tree_map(lambda a, b: a.add(b, alpha=alpha), x, p) + r_ = pytree.tree_map(lambda a, b: a.sub(b, alpha=alpha), r, Ap) + z_ = M(r_) + gamma_ = tree_vdot_real(r_, z_) + beta_ = gamma_ / gamma + p_ = pytree.tree_map(lambda a, b: a.add(b, alpha=beta_), z_, p) + return x_, r_, gamma_, p_, k + 1 + + r0 = pytree.tree_map(torch.sub, b, A(x0)) + p0 = z0 = M(r0) + gamma0 = tree_vdot_real(r0, z0) + + value = (x0, r0, gamma0, p0, 0) + while cond_fn(value): + value = body_fn(value) + + x_final, *_ = value + + return x_final + + +# pylint: disable-next=too-many-arguments +def _isolve( + _isolve_solve: Callable, + A: TensorTree | Callable[[TensorTree], TensorTree], + b: TensorTree, + x0: TensorTree | None = None, + *, + rtol: float = 1e-5, + atol: float = 0.0, + maxiter: int | None = None, + M: TensorTree | Callable[[TensorTree], TensorTree] | None = None, +) -> TensorTree: + if x0 is None: + x0 = pytree.tree_map(torch.zeros_like, b) + + if maxiter is None: + size = sum(cat_shapes(b)) + maxiter = 10 * size # copied from SciPy + + if M is None: + M = _identity + A = normalize_matvec(A) + M = normalize_matvec(M) + + if cat_shapes(x0) != cat_shapes(b): + raise ValueError( + f'Tensors in x0 and b must have matching shapes: {cat_shapes(x0)} vs. {cat_shapes(b)}.', + ) + + isolve_solve = partial(_isolve_solve, x0=x0, rtol=rtol, atol=atol, maxiter=maxiter, M=M) + return isolve_solve(A, b) + + +# pylint: disable-next=too-many-arguments +def cg( + A: TensorTree | Callable[[TensorTree], TensorTree], + b: TensorTree, + x0: TensorTree | None = None, + *, + rtol: float = 1e-5, + atol: float = 0.0, + maxiter: int | None = None, + M: TensorTree | Callable[[TensorTree], TensorTree] | None = None, +) -> TensorTree: + """Use Conjugate Gradient iteration to solve ``Ax = b``. + + The numerics of TorchOpt's ``cg`` should exact match SciPy's ``cg`` (up to numerical precision), + but note that the interface is slightly different: you need to supply the linear operator ``A`` + as a function instead of a sparse matrix or ``LinearOperator``. + + Derivatives of :func:`cg` are implemented via implicit differentiation with another :func:`cg` + solve, rather than by differentiating *through* the solver. They will be accurate only if both + solves converge. + + Args: + A (Tensor or tree of Tensor): 2D array or function that calculates the linear map + (matrix-vector product) ``Ax`` when called like ``A(x)``. ``A`` must represent a + hermitian, positive definite matrix, and must return tensor(s) with the same structure + and shape as its argument. + b (Tensor or tree of Tensor): Right hand side of the linear system representing a single + vector. Can be stored as a tensor or Python container of tensor(s) with any shape. + x0 (Tensor, tree of Tensor, or None, optional): Starting guess for the solution. Must have + the same structure as ``b``. If :data:`None`, use zero initialization. + (default: :data:`None`) + rtol (float, optional): Tolerances for convergence, ``norm(residual) <= max(rtol*norm(b), atol)``. + We do not implement SciPy's "legacy" behavior, so TorchOpt's tolerance will differ from + SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``. (default: :const:`1e-5`) + atol (float, optional): Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``. + We do not implement SciPy's "legacy" behavior, so TorchOpt's tolerance will differ from + SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``. (default: :const:`0.0`) + maxiter (int or None, optional): Maximum number of iterations. Iteration will stop after + maxiter steps even if the specified tolerance has not been achieved. If :data:`None`, + ``10 * size`` will be used, where ``size`` is the size of the flattened input tensor(s). + (default: :data:`None`) + M (Tensor, tree of Tensor, function, or None, optional): Pre-conditioner for ``A``. The + pre-conditioner should approximate the inverse of ``A``. Effective preconditioning + dramatically improves the rate of convergence, which implies that fewer iterations are + needed to reach a given error tolerance. If :data:`None`, no pre-conditioner will be + used. (default: :data:`None`) + + Returns: + the Conjugate Gradient (CG) linear solver + """ + return _isolve(_cg_solve, A=A, b=b, x0=x0, rtol=rtol, atol=atol, maxiter=maxiter, M=M) diff --git a/torchopt/linalg/ns.py b/torchopt/linalg/ns.py new file mode 100644 index 00000000..5fc8d478 --- /dev/null +++ b/torchopt/linalg/ns.py @@ -0,0 +1,166 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Neumann Series Matrix Inversion Approximation to solve ``Ax = b``.""" + +# pylint: disable=invalid-name + +from __future__ import annotations + +import functools +from typing import TYPE_CHECKING, Callable + +import torch + +from torchopt import pytree +from torchopt.linalg.utils import normalize_matvec + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree + + +__all__ = ['ns', 'ns_inv'] + + +def _ns_solve( + A: torch.Tensor, + b: torch.Tensor, + maxiter: int, + alpha: float | None = None, +) -> torch.Tensor: + """Use Neumann Series Matrix Inversion Approximation to solve ``Ax = b``.""" + if A.ndim != 2 or A.shape[0] != A.shape[1]: + raise ValueError(f'`A` must be a square matrix, but has shape: {A.shape}') + + inv_A_hat_b = b + v = b + if alpha is not None: + # A^{-1} = a [I - (I - a A)]^{-1} = a [I + (I - a A) + (I - a A)^2 + (I - a A)^3 + ...] + for _ in range(maxiter): + v = v - alpha * (A @ v) + inv_A_hat_b = inv_A_hat_b + v + inv_A_hat_b = alpha * inv_A_hat_b + else: + # A^{-1} = [I - (I - A)]^{-1} = I + (I - A) + (I - A)^2 + (I - A)^3 + ... + for _ in range(maxiter): + v = v - A @ v + inv_A_hat_b = inv_A_hat_b + v + + return inv_A_hat_b + + +def ns( + A: TensorTree | Callable[[TensorTree], TensorTree], + b: TensorTree, + maxiter: int | None = None, + *, + alpha: float | None = None, +) -> TensorTree: + """Use Neumann Series Matrix Inversion Approximation to solve ``Ax = b``. + + Args: + A (Tensor or tree of Tensor): 2D array or function that calculates the linear map + (matrix-vector product) ``Ax`` when called like ``A(x)``. ``A`` must represent a + hermitian, positive definite matrix, and must return tensor(s) with the same structure + and shape as its argument. + b (Tensor or tree of Tensor): Right hand side of the linear system representing a single + vector. Can be stored as a tensor or Python container of tensor(s) with any shape. + maxiter (int or None, optional): Maximum number of iterations. Iteration will stop after + maxiter steps even if the specified tolerance has not been achieved. If :data:`None`, + :const:`10` will be used. (default: :const:`10`) + alpha: (float or None, optional): Decay coefficient. If :data:`None`, :const:`1.0` will be + used. (default: :const:`1.0`) + + Returns: + The Neumann Series (NS) matrix inversion approximation. + """ + if maxiter is None: + maxiter = 10 + + if not callable(A): + return pytree.tree_map(functools.partial(_ns_solve, maxiter=maxiter, alpha=alpha), A, b) + + matvec = normalize_matvec(A) + inv_A_hat_b = b + v = b + if alpha is not None: + # A^{-1} = a [I - (I - a A)]^{-1} = a [I + (I - a A) + (I - a A)^2 + (I - a A)^3 + ...] + for _ in range(maxiter): + # v = v - alpha * (A @ v) + v = pytree.tree_sub_scalar_mul(v, matvec(v), alpha=alpha) + # inv_A_hat_b = inv_A_hat_b + v + inv_A_hat_b = pytree.tree_add(inv_A_hat_b, v) + # inv_A_hat_b = alpha * inv_A_hat_b + inv_A_hat_b = pytree.tree_scalar_mul(alpha, inv_A_hat_b) + else: + # A^{-1} = [I - (I - A)]^{-1} = I + (I - A) + (I - A)^2 + (I - A)^3 + ... + for _ in range(maxiter): + # v = v - A @ v + v = pytree.tree_sub(v, matvec(v)) + # inv_A_hat_b = inv_A_hat_b + v + inv_A_hat_b = pytree.tree_add(inv_A_hat_b, v) + + return inv_A_hat_b + + +def _ns_inv(A: torch.Tensor, maxiter: int, alpha: float | None = None) -> torch.Tensor: + """Use Neumann Series iteration to solve ``A^{-1}``.""" + if A.ndim != 2 or A.shape[0] != A.shape[1]: + raise ValueError(f'`A` must be a square matrix, but has shape: {A.shape}') + + I = torch.eye(*A.shape, out=torch.empty_like(A)) # noqa: E741 + inv_A_hat = torch.zeros_like(A) + if alpha is not None: + # A^{-1} = a [I - (I - a A)]^{-1} = a [I + (I - a A) + (I - a A)^2 + (I - a A)^3 + ...] + M = I - alpha * A + for rank in range(maxiter): + # pylint: disable-next=not-callable + inv_A_hat = inv_A_hat + torch.linalg.matrix_power(M, rank) + inv_A_hat = alpha * inv_A_hat + else: + # A^{-1} = [I - (I - A)]^{-1} = I + (I - A) + (I - A)^2 + (I - A)^3 + ... + M = I - A + for rank in range(maxiter): + # pylint: disable-next=not-callable + inv_A_hat = inv_A_hat + torch.linalg.matrix_power(M, rank) + return inv_A_hat + + +def ns_inv( + A: TensorTree, + maxiter: int | None = None, + *, + alpha: float | None = None, +) -> TensorTree: + """Use Neumann Series iteration to solve ``A^{-1}``. + + Args: + A (Tensor or tree of Tensor): 2D array or function that calculates the linear map + (matrix-vector product) ``Ax`` when called like ``A(x)``. ``A`` must represent a + hermitian, positive definite matrix, and must return tensor(s) with the same structure + and shape as its argument. + maxiter (int or None, optional): Maximum number of iterations. Iteration will stop after + maxiter steps even if the specified tolerance has not been achieved. If :data:`None`, + :const:`10` will be used. (default: :const:`10`) + alpha: (float or None, optional): Decay coefficient. If :data:`None`, :const:`1.0` will be + used. (default: :const:`1.0`) + + Returns: + The Neumann Series (NS) matrix inversion approximation. + """ + if maxiter is None: + maxiter = 10 + + return pytree.tree_map(functools.partial(_ns_inv, maxiter=maxiter, alpha=alpha), A) diff --git a/torchopt/linalg/utils.py b/torchopt/linalg/utils.py new file mode 100644 index 00000000..bbcc80aa --- /dev/null +++ b/torchopt/linalg/utils.py @@ -0,0 +1,60 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for linear algebra.""" + +from __future__ import annotations + +import itertools +from typing import TYPE_CHECKING, Callable + +import torch + +from torchopt import pytree + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree + + +def cat_shapes(tree: TensorTree) -> tuple[int, ...]: + """Concatenate the shapes of the leaves of a tree of tensors.""" + leaves = pytree.tree_leaves(tree) + return tuple(itertools.chain.from_iterable(tuple(leaf.shape) for leaf in leaves)) + + +def normalize_matvec( + matvec: TensorTree | Callable[[TensorTree], TensorTree], +) -> Callable[[TensorTree], TensorTree]: + """Normalize an argument for computing matrix-vector product.""" + if callable(matvec): + return matvec + + mat_flat, treespec = pytree.tree_flatten(matvec) + for mat in mat_flat: + if not isinstance(mat, torch.Tensor) or mat.ndim != 2 or mat.shape[0] != mat.shape[1]: + raise TypeError(f'Linear operator must be a square matrix, but has shape: {mat.shape}') + + def _matvec(x: TensorTree) -> TensorTree: + x_flat = pytree.tree_leaves(x) + if len(x_flat) != len(mat_flat): + raise ValueError( + f'`x` must have the same number of leaves as `matvec`, ' + f'but has {len(x_flat)} leaves and `matvec` has {len(mat_flat)} leaves', + ) + + y_flat = map(torch.matmul, mat_flat, x_flat) + return pytree.tree_unflatten(treespec, y_flat) + + return _matvec diff --git a/torchopt/linear_solve/__init__.py b/torchopt/linear_solve/__init__.py new file mode 100644 index 00000000..43ca1da0 --- /dev/null +++ b/torchopt/linear_solve/__init__.py @@ -0,0 +1,39 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/google/jaxopt/blob/main/jaxopt/_src/linear_solve.py +# ============================================================================== +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Linear algebra solvers.""" + +from torchopt.linear_solve.cg import solve_cg +from torchopt.linear_solve.inv import solve_inv +from torchopt.linear_solve.normal_cg import solve_normal_cg + + +__all__ = ['solve_cg', 'solve_inv', 'solve_normal_cg'] diff --git a/torchopt/linear_solve/cg.py b/torchopt/linear_solve/cg.py new file mode 100644 index 00000000..23814cc2 --- /dev/null +++ b/torchopt/linear_solve/cg.py @@ -0,0 +1,114 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/google/jaxopt/blob/main/jaxopt/_src/linear_solve.py +# ============================================================================== +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Linear algebra solver for ``A x = b`` using conjugate gradient.""" + +# pylint: disable=invalid-name + +from __future__ import annotations + +import functools +from typing import TYPE_CHECKING, Any, Callable + +from torchopt import linalg +from torchopt.linear_solve.utils import make_ridge_matvec + + +if TYPE_CHECKING: + from torchopt.typing import LinearSolver, TensorTree + + +__all__ = ['solve_cg'] + + +def _solve_cg( + matvec: Callable[[TensorTree], TensorTree], # (x) -> A @ x + b: TensorTree, + ridge: float | None = None, + init: TensorTree | None = None, + **kwargs: Any, +) -> TensorTree: + """Solve ``A x = b`` using conjugate gradient. + + This assumes that ``A`` is a hermitian, positive definite matrix. + + Args: + matvec (callable): A function that returns the product between ``A`` and a vector. + b (Tensor or tree of Tensor): A tree of tensors for the right hand side of the equation. + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A x + ridge x = b``. (default: :data:`None`) + init (Tensor, tree of Tensor, or None, optional): Optional initialization to be used by + conjugate gradient. If :data:`None`, uses zero initialization. (default: :data:`None`) + **kwargs: Additional keyword arguments for the conjugate gradient solver. + + Returns: + The solution with the same structure as ``b``. + """ + if ridge is not None: + # (x) -> A @ x + ridge * x + # i.e. (x) -> (A + ridge * I) @ x + matvec = make_ridge_matvec(matvec, ridge=ridge) + + # Returns solution for `(A + ridge * I) @ x = b`. + return linalg.cg(matvec, b, x0=init, **kwargs) + + +def solve_cg(**kwargs: Any) -> LinearSolver: + """Return a solver function to solve ``A x = b`` using conjugate gradient. + + This assumes that ``A`` is a hermitian, positive definite matrix. + + Args: + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A x + ridge x = b``. (default: :data:`None`) + init (Tensor, tree of Tensor, or None, optional): Optional initialization to be used by + conjugate gradient. If :data:`None`, uses zero initialization. (default: :data:`None`) + **kwargs: Additional keyword arguments for the conjugate gradient solver + :func:`torchopt.linalg.cg`. + + Returns: + A solver function with signature ``(matvec, b) -> x`` that solves ``A x = b`` using + conjugate gradient where ``matvec(v) = A v``. + + See Also: + Conjugate gradient iteration :func:`torchopt.linalg.cg`. + + Examples: + >>> A = {'a': torch.eye(5, 5), 'b': torch.eye(3, 3)} + >>> x = {'a': torch.randn(5), 'b': torch.randn(3)} + >>> def matvec(x: TensorTree) -> TensorTree: + ... return {'a': A['a'] @ x['a'], 'b': A['b'] @ x['b']} + >>> b = matvec(x) + >>> solver = solve_cg(init={'a': torch.zeros(5), 'b': torch.zeros(3)}) + >>> x_hat = solver(matvec, b) + >>> assert torch.allclose(x_hat['a'], x['a']) and torch.allclose(x_hat['b'], x['b']) + """ + return functools.partial(_solve_cg, **kwargs) diff --git a/torchopt/linear_solve/inv.py b/torchopt/linear_solve/inv.py new file mode 100644 index 00000000..4dbe1542 --- /dev/null +++ b/torchopt/linear_solve/inv.py @@ -0,0 +1,129 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/google/jaxopt/blob/main/jaxopt/_src/linear_solve.py +# ============================================================================== +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Linear algebra solver for ``A x = b`` using matrix inversion.""" + +# pylint: disable=invalid-name + +from __future__ import annotations + +import functools +from typing import TYPE_CHECKING, Any, Callable + +import torch + +from torchopt import linalg, pytree +from torchopt.linear_solve.utils import make_ridge_matvec, materialize_matvec + + +if TYPE_CHECKING: + from torchopt.typing import LinearSolver, TensorTree + + +__all__ = ['solve_inv'] + + +def _solve_inv( + matvec: Callable[[TensorTree], TensorTree], # (x) -> A @ x + b: TensorTree, + ridge: float | None = None, + ns: bool = False, + **kwargs: Any, +) -> TensorTree: + """Solve ``A x = b`` using matrix inversion. + + If ``ns = False``, this assumes the matrix ``A`` is a constant matrix and will materialize it + in memory. + + Args: + matvec (callable): A function that returns the product between ``A`` and a vector. + b (Tensor or tree of Tensor): A tree of tensors for the right hand side of the equation. + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A x + ridge x = b``. (default: :data:`None`) + ns (bool, optional): Whether to use Neumann Series matrix inversion approximation. + If :data:`False`, materialize the matrix ``A`` in memory and use :func:`torch.linalg.solve` + instead. (default: :data:`False`) + **kwargs: Additional keyword arguments for the Neumann Series matrix inversion approximation + solver :func:`torchopt.linalg.ns`. + + Returns: + The solution with the same shape as ``b``. + """ + if ridge is not None: + # (x) -> A @ x + ridge * x + # i.e. (x) -> (A + ridge * I) @ x + matvec = make_ridge_matvec(matvec, ridge=ridge) + + b_flat = pytree.tree_leaves(b) + if len(b_flat) == 1 and b_flat[0].ndim == 0: + A, *_ = materialize_matvec(matvec, b) + return pytree.tree_truediv(b, A) + + if ns: + return linalg.ns(matvec, b, **kwargs) + + A, _, tree_ravel, tree_unravel = materialize_matvec(matvec, b) + return tree_unravel(pytree.tree_map(torch.linalg.solve, A, tree_ravel(b))) + + +def solve_inv(**kwargs: Any) -> LinearSolver: + """Return a solver function to solve ``A x = b`` using matrix inversion. + + If ``ns = False``, this assumes the matrix ``A`` is a constant matrix and will materialize it + in memory. + + Args: + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A x + ridge x = b``. (default: :data:`None`) + ns (bool, optional): Whether to use Neumann Series matrix inversion approximation. + If :data:`False`, materialize the matrix ``A`` in memory and use :func:`torch.linalg.solve` + instead. (default: :data:`False`) + **kwargs: Additional keyword arguments for the Neumann Series matrix inversion approximation + solver :func:`torchopt.linalg.ns`. + + Returns: + A solver function with signature ``(matvec, b) -> x`` that solves ``A x = b`` using matrix + inversion where ``matvec(v) = A v``. + + See Also: + Neumann Series matrix inversion approximation :func:`torchopt.linalg.ns`. + + Examples: + >>> A = {'a': torch.eye(5, 5), 'b': torch.eye(3, 3)} + >>> x = {'a': torch.randn(5), 'b': torch.randn(3)} + >>> def matvec(x: TensorTree) -> TensorTree: + ... return {'a': A['a'] @ x['a'], 'b': A['b'] @ x['b']} + >>> b = matvec(x) + >>> solver = solve_inv(ns=True, maxiter=10) + >>> x_hat = solver(matvec, b) + >>> assert torch.allclose(x_hat['a'], x['a']) and torch.allclose(x_hat['b'], x['b']) + """ + return functools.partial(_solve_inv, **kwargs) diff --git a/torchopt/linear_solve/normal_cg.py b/torchopt/linear_solve/normal_cg.py new file mode 100644 index 00000000..a5af49b2 --- /dev/null +++ b/torchopt/linear_solve/normal_cg.py @@ -0,0 +1,124 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/google/jaxopt/blob/main/jaxopt/_src/linear_solve.py +# ============================================================================== +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Linear algebra solver for ``A^T A x = A^T b`` using conjugate gradient.""" + +# pylint: disable=invalid-name + +from __future__ import annotations + +import functools +from typing import TYPE_CHECKING, Any, Callable + +from torchopt import linalg +from torchopt.linear_solve.utils import make_normal_matvec, make_ridge_matvec, make_rmatvec + + +if TYPE_CHECKING: + from torchopt.typing import LinearSolver, TensorTree + + +__all__ = ['solve_normal_cg'] + + +def _solve_normal_cg( + matvec: Callable[[TensorTree], TensorTree], # (x) -> A @ x + b: TensorTree, + ridge: float | None = None, + init: TensorTree | None = None, + **kwargs: Any, +) -> TensorTree: + """Solve the normal equation ``A^T A x = A^T b`` using conjugate gradient. + + This can be used to solve ``A x = b`` using conjugate gradient when ``A`` is not hermitian, + positive definite. + + Args: + matvec (callable): A function that returns the product between ``A`` and a vector. + b (Tensor or tree of Tensor): A tree of tensors for the right hand side of the equation. + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A^T A x + ridge x = A^T b``. (default: :data:`None`) + init (Tensor, tree of Tensor, or None, optional): Optional initialization to be used by + conjugate gradient. If :data:`None`, uses zero initialization. (default: :data:`None`) + **kwargs: Additional keyword arguments for the conjugate gradient solver + :func:`torchopt.linalg.cg`. + + Returns: + The solution with the same structure as ``b``. + """ + example_x = b if init is None else init + + rmatvec = make_rmatvec(matvec, example_x) # (x) -> A.T @ x + normal_matvec = make_normal_matvec(matvec) # (x) -> A.T @ A @ x + + if ridge is not None: + # (x) -> A.T @ A @ x + ridge * x + # i.e. (x) -> (A.T @ A + ridge * I) @ x + normal_matvec = make_ridge_matvec(normal_matvec, ridge=ridge) + + rhs = rmatvec(b) # A.T @ b + + # Returns solution for `(A.T @ A + ridge * I) @ x = A.T @ b`. + return linalg.cg(normal_matvec, rhs, x0=init, **kwargs) + + +def solve_normal_cg(**kwargs: Any) -> LinearSolver: + """Return a solver function to solve ``A^T A x = A^T b`` using conjugate gradient. + + This can be used to solve ``A x = b`` using conjugate gradient when ``A`` is not hermitian, + positive definite. + + Args: + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A^T A x + ridge x = A^T b``. (default: :data:`None`) + init (Tensor, tree of Tensor, or None, optional): Optional initialization to be used by + conjugate gradient. If :data:`None`, uses zero initialization. (default: :data:`None`) + **kwargs: Additional keyword arguments for the conjugate gradient solver + :func:`torchopt.linalg.cg`. + + Returns: + A solver function with signature ``(matvec, b) -> x`` that solves ``A^T A x = A^T b`` using + conjugate gradient where ``matvec(v) = A v``. + + See Also: + Conjugate gradient iteration :func:`torchopt.linalg.cg`. + + Examples: + >>> A = {'a': torch.randn(5, 5), 'b': torch.randn(3, 3)} + >>> x = {'a': torch.randn(5), 'b': torch.randn(3)} + >>> def matvec(x: TensorTree) -> TensorTree: + ... return {'a': A['a'] @ x['a'], 'b': A['b'] @ x['b']} + >>> b = matvec(x) + >>> solver = solve_normal_cg(init={'a': torch.zeros(5), 'b': torch.zeros(3)}) + >>> x_hat = solver(matvec, b) + >>> assert torch.allclose(x_hat['a'], x['a']) and torch.allclose(x_hat['b'], x['b']) + """ + return functools.partial(_solve_normal_cg, **kwargs) diff --git a/torchopt/linear_solve/utils.py b/torchopt/linear_solve/utils.py new file mode 100644 index 00000000..9d1b8779 --- /dev/null +++ b/torchopt/linear_solve/utils.py @@ -0,0 +1,122 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/google/jaxopt/blob/main/jaxopt/_src/linear_solve.py +# ============================================================================== +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for linear algebra solvers.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable + +import functorch + +from torchopt import pytree + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree + + +def make_rmatvec( + matvec: Callable[[TensorTree], TensorTree], + example_x: TensorTree, +) -> Callable[[TensorTree], TensorTree]: + """Return a function that computes ``rmatvec(y) = A.T @ y`` from ``matvec(x) = A @ x``.""" + _, vjp, *_ = functorch.vjp(matvec, example_x) + + return lambda y: vjp(y)[0] + + +def make_normal_matvec( + matvec: Callable[[TensorTree], TensorTree], +) -> Callable[[TensorTree], TensorTree]: + """Return a function that computes ``normal_matvec(y) = A.T @ A @ y`` from ``matvec(x) = A @ x``.""" + + def normal_matvec(y: TensorTree) -> TensorTree: + """Compute ``A.T @ A @ y`` from ``matvec(x) = A @ x``.""" + matvec_y, vjp, *_ = functorch.vjp(matvec, y) + return vjp(matvec_y)[0] + + return normal_matvec + + +def make_ridge_matvec( + matvec: Callable[[TensorTree], TensorTree], + ridge: float = 0.0, +) -> Callable[[TensorTree], TensorTree]: + """Return a function that computes ``ridge_matvec(y) = A.T @ A @ y + ridge * y`` from ``matvec(x) = A @ x``.""" + + def ridge_matvec(y: TensorTree) -> TensorTree: + """Compute ``A.T @ A @ v + ridge * v`` from ``matvec(x) = A @ x``.""" + return pytree.tree_add_scalar_mul(matvec(y), y, alpha=ridge) + + return ridge_matvec + + +def materialize_matvec( + matvec: Callable[[TensorTree], TensorTree], + x: TensorTree, +) -> tuple[ + TensorTree, + Callable[[TensorTree], TensorTree], + Callable[[TensorTree], TensorTree], + Callable[[TensorTree], TensorTree], +]: + """Materialize the matrix ``A`` used in ``matvec(x) = A @ x``.""" + x_flat, treespec = pytree.tree_flatten(x) + shapes = tuple(t.shape for t in x_flat) + + if all(t.ndim == 1 for t in x_flat): + + def tree_ravel(x: TensorTree) -> TensorTree: + return x + + def tree_unravel(y: TensorTree) -> TensorTree: + return y + + matvec_ravel = matvec + + else: + + def tree_ravel(x: TensorTree) -> TensorTree: + return pytree.tree_map(lambda t: t.contiguous().view(-1), x) + + def tree_unravel(y: TensorTree) -> TensorTree: + shapes_iter = iter(shapes) + return pytree.tree_map(lambda t: t.contiguous().view(next(shapes_iter)), y) + + def matvec_ravel(y: TensorTree) -> TensorTree: + return tree_ravel(matvec(tree_unravel(y))) + + nargs = len(x_flat) + jacobian_tree = functorch.jacfwd(matvec_ravel)(tree_ravel(x)) + jacobian_flat = pytree.tree_leaves(jacobian_tree) + jacobian_diag = [jacobian_flat[i + i * nargs] for i in range(nargs)] + return pytree.tree_unflatten(treespec, jacobian_diag), matvec_ravel, tree_ravel, tree_unravel diff --git a/torchopt/nn/__init__.py b/torchopt/nn/__init__.py new file mode 100644 index 00000000..b55e49d7 --- /dev/null +++ b/torchopt/nn/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for neural network modules that hold meta-parameters and meta-modules.""" + +from torchopt.diff.implicit.nn.module import ImplicitMetaGradientModule # circular reference +from torchopt.diff.zero_order.nn.module import ZeroOrderGradientModule # circular reference +from torchopt.nn.module import MetaGradientModule +from torchopt.nn.stateless import reparameterize, reparametrize, swap_state + + +__all__ = [ + 'ImplicitMetaGradientModule', + 'MetaGradientModule', + 'ZeroOrderGradientModule', + 'reparameterize', + 'reparametrize', + 'swap_state', +] diff --git a/torchopt/nn/module.py b/torchopt/nn/module.py new file mode 100644 index 00000000..8c40f58a --- /dev/null +++ b/torchopt/nn/module.py @@ -0,0 +1,468 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for neural network modules that hold meta-parameters and meta-modules.""" + +from __future__ import annotations + +from collections import OrderedDict +from typing import TYPE_CHECKING, Any, Iterator, NamedTuple +from typing_extensions import Self # Python 3.11+ + +import torch +import torch.nn as nn + +from torchopt import pytree + + +if TYPE_CHECKING: + from torchopt.typing import TensorContainer + + +class MetaInputsContainer(NamedTuple): + """Container for parameters and modules in the constructor input arguments.""" + + meta_parameters: set[torch.Tensor] + meta_modules: set[nn.Module] + + +class MetaGradientModule(nn.Module): # pylint: disable=abstract-method + """Base class for neural network modules that hold meta-parameters and meta-modules.""" + + _meta_inputs: MetaInputsContainer + _meta_parameters: TensorContainer + _meta_modules: dict[str, nn.Module | None] + + def __new__(cls, *args: Any, **kwargs: Any) -> Self: + """Create a new module instance.""" + instance = super().__new__(cls) + flat_args: list[Any] + flat_args = pytree.tree_leaves((args, kwargs)) # type: ignore[arg-type] + meta_parameters = {x for x in flat_args if isinstance(x, torch.Tensor) and x.requires_grad} + meta_modules = {x for x in flat_args if isinstance(x, nn.Module) and x.training} + for meta_module in tuple(meta_modules): + meta_parameters.update(meta_module.parameters()) + meta_modules.update(meta_module.modules()) + + instance._meta_inputs = MetaInputsContainer(meta_parameters, meta_modules) + instance._meta_parameters: TensorContainer = OrderedDict() # type: ignore[misc] + instance._meta_modules: dict[str, nn.Module | None] = OrderedDict() # type: ignore[misc] + return instance + + def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=unused-argument + """Initialize a new module instance.""" + super().__init__() + + def __getattr__(self, name: str) -> torch.Tensor | nn.Module: # noqa: C901 + """Get an attribute of the module.""" + if '_parameters' in self.__dict__: + _parameters = self.__dict__['_parameters'] + if name in _parameters: + return _parameters[name] + if '_buffers' in self.__dict__: + _buffers = self.__dict__['_buffers'] + if name in _buffers: + return _buffers[name] + if '_modules' in self.__dict__: + modules = self.__dict__['_modules'] + if name in modules: + return modules[name] + if '_meta_parameters' in self.__dict__: + _meta_parameters = self.__dict__['_meta_parameters'] + if name in _meta_parameters: + return _meta_parameters[name] + if '_meta_modules' in self.__dict__: + _meta_modules = self.__dict__['_meta_modules'] + if name in _meta_modules: + return _meta_modules[name] + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + # pylint: disable-next=too-many-branches,too-many-statements + def __setattr__(self, name: str, value: torch.Tensor | nn.Module) -> None: # noqa: C901 + """Set an attribute of the module.""" + + def remove_from(*dicts_or_sets: dict[str, Any] | set[str]) -> None: + for dict_or_set in dicts_or_sets: + if name in dict_or_set: + if isinstance(dict_or_set, dict): + del dict_or_set[name] + else: + dict_or_set.discard(name) + + params = self.__dict__.get('_parameters') + meta_params = self.__dict__.get('_meta_parameters') + if isinstance(value, torch.Tensor) and value.requires_grad: + if params is None: + raise AttributeError('cannot assign parameters before Module.__init__() call') + if meta_params is None: + raise AttributeError( + 'cannot assign meta-parameters before MetaGradientModule.__init__() call', + ) + remove_from( + self.__dict__, + self._buffers, + self._modules, + self._non_persistent_buffers_set, + self._meta_parameters, + self._meta_modules, + ) + if value in self._meta_inputs.meta_parameters: + self.register_meta_parameter(name, value) + else: + self.register_parameter(name, value) + elif params is not None and name in params: + if value is not None: + raise TypeError( + f"cannot assign '{torch.typename(value)}' as parameter '{name}' " + f'(torch.Tensor or None expected)', + ) + self.register_parameter(name, value) # type: ignore[unreachable] + elif meta_params is not None and name in meta_params: + if value is not None: + raise TypeError( + f"cannot assign '{torch.typename(value)}' as meta-parameter '{name}' " + f'(torch.Tensor or None expected)', + ) + self.register_meta_parameter(name, value) # type: ignore[unreachable] + else: + modules = self.__dict__.get('_modules') + meta_modules = self.__dict__.get('_meta_modules') + if isinstance(value, nn.Module): + if modules is None: + raise AttributeError('cannot assign module before Module.__init__() call') + if meta_modules is None: + raise AttributeError( + 'cannot assign module before MetaGradientModule.__init__() call', + ) + remove_from( + self.__dict__, + self._parameters, + self._buffers, + self._non_persistent_buffers_set, + self._meta_parameters, + self._meta_modules, + ) + if value in self._meta_inputs.meta_modules: + meta_modules[name] = value + else: + modules[name] = value + elif modules is not None and name in modules: + if value is not None: + raise TypeError( + f"cannot assign '{torch.typename(value)}' as child module '{name}' " + f'(torch.nn.Module or None expected)', + ) + modules[name] = value # type: ignore[unreachable] + else: + buffers = self.__dict__.get('_buffers') + if buffers is not None and name in buffers: + if value is not None and not isinstance(value, torch.Tensor): + raise TypeError( + f"cannot assign '{torch.typename(value)}' as buffer '{name}' " + f'(torch.Tensor or None expected)', + ) + buffers[name] = value + else: + object.__setattr__(self, name, value) + + def __delattr__(self, name: str) -> None: + """Delete an attribute of the module.""" + if name in self._parameters: + del self._parameters[name] + elif name in self._buffers: + del self._buffers[name] + self._non_persistent_buffers_set.discard(name) + elif name in self._modules: + del self._modules[name] + elif name in self._meta_parameters: + del self._meta_parameters[name] + elif name in self._meta_modules: + del self._meta_modules[name] + else: + object.__delattr__(self, name) + + def register_parameter(self, name: str, param: torch.Tensor | None) -> None: + r"""Add a parameter to the module. + + The parameter can be accessed as an attribute using given name. + + Args: + name (str): The name of the parameter. The parameter can be accessed from this module + using the given name. + param (Tensor or None): The parameter to be added to the module. If :data:`None`, then + operations that run on parameters, such as ``cuda``, are ignored. If :data:`None`, + the parameter is **not** included in the module's ``state_dict``. + """ + if '_parameters' not in self.__dict__: + raise AttributeError('cannot assign parameter before Module.__init__() call') + if not isinstance(name, str): + raise TypeError(f'parameter name should be a string. Got {torch.typename(name)}') + if '.' in name: + raise KeyError("parameter name can't contain '.'") + if name == '': + raise KeyError("parameter name can't be empty string ''") + if hasattr(self, name) and name not in self._parameters: + raise KeyError(f"attribute '{name}' already exists") + + if param is None: + self._parameters[name] = None + return + + if not isinstance(param, torch.Tensor): + raise TypeError( + f"cannot assign '{torch.typename(param)}' object to parameter '{name}' " + f'(torch.Tensor or None required)', + ) + if not param.requires_grad: + raise ValueError( + f"cannot assign Tensor that `requires_grad=False` to parameter '{name}'", + ) + if param in self._meta_inputs.meta_parameters: + raise ValueError( + f"cannot assign Tensor that is a meta-parameter to parameter '{name}'. " + f'Use self.register_meta_parameter() instead.', + ) + + self._parameters[name] = param # type: ignore + + def register_meta_parameter(self, name: str, param: torch.Tensor | None) -> None: + r"""Add a meta-parameter to the module. + + The meta-parameter can be accessed as an attribute using given name. + + Args: + name (str): The name of the meta-parameter. The meta-parameter can be accessed from this + module using the given name. + param (Tensor or None): The meta-parameter to be added to the module. If :data:`None`, + then operations that run on meta-parameters, such as ``cuda``, are ignored. If + :data:`None`, the meta-parameter is **not** included in the module's ``state_dict``. + """ + if '_meta_parameters' not in self.__dict__: + raise AttributeError( + 'cannot assign meta-parameter before MetaGradientModule.__init__() call', + ) + if not isinstance(name, str): + raise TypeError(f'meta-parameter name should be a string. Got {torch.typename(name)}') + if '.' in name: + raise KeyError("meta-parameter name can't contain '.'") + if name == '': + raise KeyError("meta-parameter name can't be empty string ''") + if hasattr(self, name) and name not in self._meta_parameters: + raise KeyError(f"attribute '{name}' already exists") + + if param is None: + self._meta_parameters[name] = None + return + + if not isinstance(param, torch.Tensor): + raise TypeError( + f"cannot assign '{torch.typename(param)}' object to meta-parameter '{name}' " + f'(torch.Tensor or None required)', + ) + if not param.requires_grad: + raise ValueError( + f"cannot assign Tensor that `requires_grad=False` to meta-parameter '{name}'", + ) + + self._meta_parameters[name] = param + + def add_module(self, name: str, module: nn.Module | None) -> None: + r"""Add a child module to the current module. + + The module can be accessed as an attribute using the given name. + + Args: + name (str): The name of the child module. The child module can be accessed from this + module using the given name + module (nn.Module or None): The child module to be added to the module. + """ + if not isinstance(module, nn.Module) and module is not None: + raise TypeError(f'{torch.typename(module)} is not a Module subclass') + if not isinstance(name, str): + raise TypeError(f'module name should be a string. Got {torch.typename(name)}') + if hasattr(self, name) and name not in self._modules: + raise KeyError(f"attribute '{name}' already exists") + if '.' in name: + raise KeyError(f"module name can't contain '.', got: '{name}'") + if name == '': + raise KeyError("module name can't be empty string ''") + if module in self._meta_inputs.meta_modules: + raise ValueError( + f"cannot add module that is a meta-module to module '{name}'. " + f'Use self.add_meta_module() instead.', + ) + + self._modules[name] = module + + def register_module(self, name: str, module: nn.Module | None) -> None: + r"""Alias for :func:`add_module`.""" + self.add_module(name, module) + + def add_meta_module(self, name: str, meta_module: nn.Module | None) -> None: + r"""Add a child meta-module to the current module. + + The meta-module can be accessed as an attribute using the given name. + + Args: + name (str): The name of the child meta-module. The child meta-module can be accessed + from this module using the given name + meta_module (nn.Module or None): The child meta-module to be added to the module. + """ + if not isinstance(meta_module, nn.Module) and meta_module is not None: + raise TypeError(f'{torch.typename(meta_module)} is not a Module subclass') + if not isinstance(name, str): + raise TypeError(f'meta-module name should be a string. Got {torch.typename(name)}') + if hasattr(self, name) and name not in self._meta_modules: + raise KeyError(f"attribute '{name}' already exists") + if '.' in name: + raise KeyError(f"meta-module name can't contain '.', got: '{name}'") + if name == '': + raise KeyError("meta-module name can't be empty string ''") + + self._meta_modules[name] = meta_module + + def register_meta_module(self, name: str, meta_module: nn.Module | None) -> None: + r"""Alias for :func:`add_meta_module`.""" + self.add_meta_module(name, meta_module) + + def meta_parameters(self, recurse: bool = True) -> Iterator[torch.Tensor]: + r"""Return an iterator over module meta-parameters. + + This is typically passed to an optimizer. + + Args: + recurse (bool, optional): If :data:`True`, then yields parameters of this module and + all submodules. Otherwise, yields only meta-parameters that are direct members of + this module. (default: :data:`True`) + + Yields: + Parameter: module meta-parameter + + Examples: + >>> for param in model.meta_parameters(): + >>> print(type(param), param.size()) + (20L,) + (20L, 1L, 5L, 5L) + """ + for _, meta_param in self.named_meta_parameters(recurse=recurse): + yield meta_param + + def named_meta_parameters( + self, + prefix: str = '', + recurse: bool = True, + ) -> Iterator[tuple[str, torch.Tensor]]: + r"""Return an iterator over module meta-parameters, yielding both the name of the meta-parameter as well as the meta-parameter itself. + + Args: + prefix (str, optional): The prefix to prepend to all meta-parameter names. + (default: :const:`''`) + recurse (bool, optional): if :data:`True`, then yields meta-parameters of this module + and all submodules. Otherwise, yields only meta-parameters that are direct members + of this module. (default: :data:`True`) + + Yields: + (string, Parameter): Tuple containing the name and parameter + + Examples: + >>> for name, meta_param in self.named_meta_parameters(): + >>> if name in ['bias']: + >>> print(meta_param.size()) + """ # pylint: disable=line-too-long + memo = set() + for name, param in getattr(self, '_meta_parameters', {}).items(): + if param is None or param in memo: + continue + memo.add(param) + yield prefix + name, param + for name, meta_module in getattr(self, '_meta_modules', {}).items(): + if meta_module is None: + continue + submodule_prefix = prefix + name + yield from meta_module.named_parameters(submodule_prefix, recurse) + + def meta_children(self) -> Iterator[nn.Module]: + r"""Return an iterator over immediate children meta-modules. + + Yields: + Module: a child meta-module + """ + for _, module in self.named_meta_children(): + yield module + + def named_meta_children(self) -> Iterator[tuple[str, nn.Module]]: + r"""Return an iterator over immediate children meta-modules, yielding both the name of the meta-module as well as the meta-module itself. + + Yields: + (string, Module): Tuple containing a name and child meta-module + + Examples: + >>> for name, meta_module in model.named_meta_children(): + >>> if name in ['conv4', 'conv5']: + >>> print(meta_module) + """ # pylint: disable=line-too-long + memo = set() + for name, meta_module in self._meta_modules.items(): + if meta_module is not None and meta_module not in memo: + memo.add(meta_module) + yield name, meta_module + + def meta_modules(self) -> Iterator[nn.Module]: + r"""Return an iterator over all meta-modules in the network. + + Yields: + Module: a meta-module in the network + + Note: + Duplicate meta-modules are returned only once. + """ + for _, meta_module in self.named_meta_modules(): + yield meta_module + + def named_meta_modules( + self, + memo: set[nn.Module] | None = None, + prefix: str = '', + remove_duplicate: bool = True, + ) -> Iterator[tuple[str, nn.Module]]: + r"""Return an iterator over all meta-modules in the network, yielding both the name of the meta-module as well as the meta-module itself. + + Args: + memo (set of nn.Module or None, optional): A memory to store the set of meta-modules + already added to the result. If not provided, a new set will be created. + (default: :const:`None`) + prefix (str, optional): A prefix that will be added to the name of the meta-module. + (default: :const:`''`) + remove_duplicate (bool, optional): whether to remove the duplicated meta-module + instances in the result or not. (default: :const:`True`) + + Yields: + (string, Module): Tuple of name and meta-module + + Note: + Duplicate modules are returned only once. + """ # pylint: disable=line-too-long + if memo is None: + memo = set() + if self in memo: + return + + if remove_duplicate: + memo.add(self) + + for name, meta_module in self._meta_modules.items(): + if meta_module is None: + continue + submodule_prefix = prefix + ('.' if prefix else '') + name + yield from meta_module.named_modules(memo, submodule_prefix, remove_duplicate) diff --git a/torchopt/nn/stateless.py b/torchopt/nn/stateless.py new file mode 100644 index 00000000..c7f92b86 --- /dev/null +++ b/torchopt/nn/stateless.py @@ -0,0 +1,100 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions for stateless module calls.""" + +from __future__ import annotations + +import contextlib +from typing import TYPE_CHECKING, Generator, Iterable + + +if TYPE_CHECKING: + import torch + import torch.nn as nn + + +__all__ = ['reparameterize', 'reparametrize', 'swap_state'] + + +MISSING: torch.Tensor = object() # type: ignore[assignment] + + +def swap_state( + module: nn.Module, + named_tensors: dict[str, torch.Tensor] | Iterable[tuple[str, torch.Tensor]], + allow_missing: bool = False, +) -> dict[str, torch.Tensor]: + """Swap the module parameters and/or buffers.""" + if not isinstance(named_tensors, dict): + named_tensors = dict(named_tensors) + + submodules = {'': module} + + def get_submodule(path: str) -> nn.Module: + """Get submodules recursively.""" + try: + return submodules[path] + except KeyError: + prefix, dot, attr = path.rpartition('.') + if dot: + submodule = submodules[path] = getattr(get_submodule(prefix), attr) + else: + submodule = submodules[path] = getattr(module, attr) + return submodule + + def recursive_setattr(path: str, value: torch.Tensor) -> torch.Tensor: + """Set attribute recursively.""" + prefix, _, attr = path.rpartition('.') + mod = get_submodule(prefix) + + orig = getattr(mod, attr, MISSING) if allow_missing else getattr(mod, attr) + + # pylint: disable=protected-access + if value is MISSING: + delattr(mod, attr) + elif hasattr(mod, '_parameters') and attr in mod._parameters: + mod._parameters[attr] = value # type: ignore[assignment] + elif hasattr(mod, '_buffers') and attr in mod._buffers: + mod._buffers[attr] = value + elif hasattr(mod, '_meta_parameters') and attr in mod._meta_parameters: + mod._meta_parameters[attr] = value + else: + setattr(mod, attr, value) + # pylint: enable=protected-access + + return orig + + return {name: recursive_setattr(name, tensor) for name, tensor in named_tensors.items()} + + +@contextlib.contextmanager +def reparametrize( + module: nn.Module, + named_tensors: dict[str, torch.Tensor] | Iterable[tuple[str, torch.Tensor]], + allow_missing: bool = False, +) -> Generator[nn.Module, None, None]: + """Reparameterize the module parameters and/or buffers.""" + if not isinstance(named_tensors, dict): + named_tensors = dict(named_tensors) + + orig_named_tensors = {} + try: + orig_named_tensors = swap_state(module, named_tensors, allow_missing=allow_missing) + yield module + finally: + swap_state(module, orig_named_tensors, allow_missing=allow_missing) + + +reparameterize = reparametrize diff --git a/torchopt/optim/__init__.py b/torchopt/optim/__init__.py new file mode 100644 index 00000000..f620608c --- /dev/null +++ b/torchopt/optim/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""object oriented optimizer implementations.""" + +from torchopt.optim import meta +from torchopt.optim.adadelta import AdaDelta, Adadelta +from torchopt.optim.adagrad import AdaGrad, Adagrad +from torchopt.optim.adam import Adam +from torchopt.optim.adamax import AdaMax, Adamax +from torchopt.optim.adamw import AdamW +from torchopt.optim.base import Optimizer +from torchopt.optim.func import FuncOptimizer +from torchopt.optim.radam import RAdam +from torchopt.optim.rmsprop import RMSProp, RMSprop +from torchopt.optim.sgd import SGD diff --git a/torchopt/optim/adadelta.py b/torchopt/optim/adadelta.py new file mode 100644 index 00000000..600b69c5 --- /dev/null +++ b/torchopt/optim/adadelta.py @@ -0,0 +1,78 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Adadelta optimizer.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterable + +from torchopt import alias +from torchopt.optim.base import Optimizer + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import ScalarOrSchedule + + +__all__ = ['AdaDelta', 'Adadelta'] + + +class AdaDelta(Optimizer): + """The classic AdaDelta optimizer. + + See Also: + - The functional AdaDelta optimizer: :func:`torchopt.adadelta`. + - The differentiable meta-AdaDelta optimizer: :class:`torchopt.MetaAdaDetla`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + params: Iterable[torch.Tensor], + lr: ScalarOrSchedule = 1.0, + rho: float = 0.9, + eps: float = 1e-6, + weight_decay: float = 0.0, + ) -> None: + r"""Initialize the AdaDelta optimizer. + + Args: + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + rho (float, optional): Coefficients used for computing running averages of gradient and its square. + (default: :const:`0.9`) + eps (float, optional): A small constant applied to the square root (as in the AdaDelta paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + """ + super().__init__( + params, + alias.adadelta( + lr=lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + moment_requires_grad=False, + ), + ) + + +Adadelta = AdaDelta # alias for PyTorch compatibility diff --git a/torchopt/optim/adagrad.py b/torchopt/optim/adagrad.py new file mode 100644 index 00000000..06091281 --- /dev/null +++ b/torchopt/optim/adagrad.py @@ -0,0 +1,85 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""AdaGrad optimizer.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterable + +from torchopt import alias +from torchopt.optim.base import Optimizer + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import ScalarOrSchedule + + +__all__ = ['AdaGrad', 'Adagrad'] + + +class AdaGrad(Optimizer): + """The classic AdaGrad optimizer. + + See Also: + - The functional AdaGrad optimizer: :func:`torchopt.adagrad`. + - The differentiable meta-AdaGrad optimizer: :class:`torchopt.MetaAdaGrad`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + params: Iterable[torch.Tensor], + lr: ScalarOrSchedule = 1e-2, + lr_decay: float = 0.0, + weight_decay: float = 0.0, + initial_accumulator_value: float = 0.0, + eps: float = 1e-10, + *, + maximize: bool = False, + ) -> None: + r"""Initialize the AdaGrad optimizer. + + Args: + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-2`) + lr_decay (float, optional): Learning rate decay. (default: :const:`0.0`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + initial_accumulator_value (float, optional): Initial value for the accumulator. + (default: :const:`0.0`) + eps (float, optional): A small constant applied to denominator outside of the square + root (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-10`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + """ + super().__init__( + params, + alias.adagrad( + lr=lr, + lr_decay=lr_decay, + weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + eps=eps, + maximize=maximize, + ), + ) + + +Adagrad = AdaGrad # alias for PyTorch compatibility diff --git a/torchopt/optim/adam.py b/torchopt/optim/adam.py new file mode 100644 index 00000000..555af22e --- /dev/null +++ b/torchopt/optim/adam.py @@ -0,0 +1,92 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Adam optimizer.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterable + +from torchopt import alias +from torchopt.optim.base import Optimizer + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import ScalarOrSchedule + + +__all__ = ['Adam'] + + +class Adam(Optimizer): + """The classic Adam optimizer. + + See Also: + - The functional Adam optimizer: :func:`torchopt.adam`. + - The differentiable meta-Adam optimizer: :class:`torchopt.MetaAdam`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + params: Iterable[torch.Tensor], + lr: ScalarOrSchedule = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + *, + eps_root: float = 0.0, + maximize: bool = False, + use_accelerated_op: bool = False, + ) -> None: + r"""Initialize the Adam optimizer. + + Args: + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square + root (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + eps_root (float, optional): A small constant applied to denominator inside the square + root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for + example when computing (meta-)gradients through Adam. (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. + (default: :data:`False`) + """ + super().__init__( + params, + alias.adam( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + eps_root=eps_root, + moment_requires_grad=False, + maximize=maximize, + use_accelerated_op=use_accelerated_op, + ), + ) diff --git a/torchopt/optim/adamax.py b/torchopt/optim/adamax.py new file mode 100644 index 00000000..e4996e85 --- /dev/null +++ b/torchopt/optim/adamax.py @@ -0,0 +1,78 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Adamax optimizer.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterable + +from torchopt import alias +from torchopt.optim.base import Optimizer + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import ScalarOrSchedule + + +__all__ = ['AdaMax', 'Adamax'] + + +class AdaMax(Optimizer): + """The classic AdaMax optimizer. + + See Also: + - The functional AdaMax optimizer: :func:`torchopt.adamax`. + - The differentiable meta-AdaMax optimizer: :class:`torchopt.MetaAdaMax`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + params: Iterable[torch.Tensor], + lr: ScalarOrSchedule = 2e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + ) -> None: + r"""Initialize the AdaMax optimizer. + + Args: + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to the square root (as in the AdaMax paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + """ + super().__init__( + params, + alias.adamax( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + moment_requires_grad=False, + ), + ) + + +Adamax = AdaMax # alias for PyTorch compatibility diff --git a/torchopt/optim/adamw.py b/torchopt/optim/adamw.py new file mode 100644 index 00000000..a60061ea --- /dev/null +++ b/torchopt/optim/adamw.py @@ -0,0 +1,103 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""AdamW optimizer.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Iterable + +from torchopt import alias +from torchopt.optim.base import Optimizer + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import OptState, Params, ScalarOrSchedule + + +__all__ = ['AdamW'] + + +class AdamW(Optimizer): + """The classic AdamW optimizer. + + See Also: + - The functional AdamW optimizer: :func:`torchopt.adamw`. + - The differentiable meta-AdamW optimizer: :class:`torchopt.MetaAdamW`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + params: Iterable[torch.Tensor], + lr: ScalarOrSchedule = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + *, + eps_root: float = 0.0, + mask: OptState | Callable[[Params], OptState] | None = None, + maximize: bool = False, + use_accelerated_op: bool = False, + ) -> None: + r"""Initialize the AdamW optimizer. + + Args: + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square + root (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Strength of the weight decay regularization. Note that + this weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al., 2019) where + the weight decay is only multiplied with the "schedule multiplier", but not the base + learning rate. (default: :const:`1e-2`) + eps_root (float, optional): A small constant applied to denominator inside the square + root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for + example when computing (meta-)gradients through Adam. (default: :const:`0.0`) + mask (tree of Tensor, callable, or None, optional): + A tree with same structure as (or a prefix of) the params pytree, or a function that + returns such a pytree given the params/updates. The leaves should be booleans, + :data:`True` for leaves/subtrees you want to apply the weight decay to, and + :data:`False` for those you want to skip. Note that the Adam gradient + transformations are applied to all parameters. (default: :data:`None`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. + (default: :data:`False`) + """ + super().__init__( + params, + alias.adamw( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + eps_root=eps_root, + mask=mask, + moment_requires_grad=False, + maximize=maximize, + use_accelerated_op=use_accelerated_op, + ), + ) diff --git a/torchopt/optim/base.py b/torchopt/optim/base.py new file mode 100644 index 00000000..bdaa0d67 --- /dev/null +++ b/torchopt/optim/base.py @@ -0,0 +1,132 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The base class for optimizers.""" + +from __future__ import annotations + +from typing import Callable, Iterable, Sequence + +import torch + +from torchopt import pytree +from torchopt.base import UninitializedState +from torchopt.typing import GradientTransformation, OptState, Params, TupleOfTensors +from torchopt.update import apply_updates + + +__all__ = ['Optimizer'] + + +class Optimizer: + """A base class for classic optimizers that similar to :class:`torch.optim.Optimizer`.""" + + def __init__(self, params: Iterable[torch.Tensor], impl: GradientTransformation) -> None: + r"""Initialize the optimizer. + + Args: + params (iterable of torch.Tensor): An iterable of :class:`torch.Tensor`\s. Specifies + what tensors should be optimized. + impl (GradientTransformation): A low level optimizer function, it could be a optimizer + function provided in :mod:`torchopt.alias` or a customized :func:`torchopt.chain`\ed + transformation. + Note that using ``Optimizer(sgd())`` or ``Optimizer(chain(sgd()))`` is equivalent to + :class:`torchopt.SGD`. + """ + if not isinstance(impl, GradientTransformation): + raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation') + + self.impl: GradientTransformation = impl + self.param_groups: list[TupleOfTensors] = [] + self.param_treespecs: list[pytree.PyTreeSpec] = [] + self.state_groups: list[OptState] = [] + + if not isinstance(params, (list, tuple)): + params = tuple(params) + self.add_param_group(params) + + def zero_grad(self, set_to_none: bool = False) -> None: + r"""Set the gradients of all optimized :class:`torch.Tensor`\s to zero. + + The behavior is similar to :meth:`torch.optim.Optimizer.zero_grad`. + + Args: + set_to_none (bool, optional): Instead of setting to zero, set the ``grads`` to + :data:`None`. (default: :data:`False`) + """ + if set_to_none: + + def f(p: torch.Tensor) -> None: + p.grad = None + + else: + + def f(p: torch.Tensor) -> None: + if p.grad is None: + return + if p.grad.grad_fn is not None: + p.grad.detach_() + else: + p.grad.requires_grad_(False) + p.grad.zero_() + + pytree.tree_map_(f, self.param_groups) # type: ignore[arg-type] + + def state_dict(self) -> tuple[OptState, ...]: + """Return the state of the optimizer.""" + return tuple(self.state_groups) + + def load_state_dict(self, state_dict: Sequence[OptState]) -> None: + """Load the optimizer state. + + Args: + state_dict (sequence of tree of Tensor): Optimizer state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.state_groups[:] = list(state_dict) + + def step(self, closure: Callable[[], torch.Tensor] | None = None) -> torch.Tensor | None: + """Perform a single optimization step. + + The behavior is similar to :meth:`torch.optim.Optimizer.step`. + + Args: + closure (callable or None, optional): A closure that reevaluates the model and returns + the loss. Optional for most optimizers. (default: :data:`None`) + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + def f(p: torch.Tensor) -> torch.Tensor | None: + return p.grad + + for i, (params, state) in enumerate(zip(self.param_groups, self.state_groups)): + if isinstance(state, UninitializedState): + state = self.impl.init(params) + grads = pytree.tree_map(f, params) # type: ignore[arg-type] + updates, new_state = self.impl.update(grads, state, params=params, inplace=True) + self.param_groups[i] = apply_updates(params, updates, inplace=True) + self.state_groups[i] = new_state + + return loss + + def add_param_group(self, params: Params) -> None: + """Add a param group to the optimizer's ``param_groups``.""" + flat_params: TupleOfTensors + flat_params, params_treespec = pytree.tree_flatten_as_tuple(params) + self.param_groups.append(flat_params) + self.param_treespecs.append(params_treespec) + self.state_groups.append(UninitializedState()) diff --git a/torchopt/optim/func/__init__.py b/torchopt/optim/func/__init__.py new file mode 100644 index 00000000..f136f808 --- /dev/null +++ b/torchopt/optim/func/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional optimizer wrappers.""" + +from torchopt.optim.func.base import FuncOptimizer diff --git a/torchopt/optim/func/base.py b/torchopt/optim/func/base.py new file mode 100644 index 00000000..fa287f04 --- /dev/null +++ b/torchopt/optim/func/base.py @@ -0,0 +1,115 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional optimizer wrappers.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from torchopt.base import GradientTransformation, UninitializedState +from torchopt.update import apply_updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params + + +__all__ = ['FuncOptimizer'] + + +class FuncOptimizer: # pylint: disable=too-few-public-methods + """A wrapper class to hold the functional optimizer. + + This wrapper makes it easier to maintain the optimizer states. The optimizer states are held by + the wrapper internally. The wrapper provides a :meth:`step` function to compute the gradients + and update the parameters. + + See Also: + - The functional AdaDelta optimizer: :func:`torchopt.adadelta`. + - The functional AdaGrad optimizer: :func:`torchopt.adagrad`. + - The functional Adam optimizer: :func:`torchopt.adam`. + - The functional AdamW optimizer: :func:`torchopt.adamw`. + - The functional AdaMax optimizer: :func:`torchopt.adamax`. + - The functional RAdam optimizer: :func:`torchopt.radam`. + - The functional RMSprop optimizer: :func:`torchopt.rmsprop`. + - The functional SGD optimizer: :func:`torchopt.sgd`. + """ + + def __init__(self, impl: GradientTransformation, *, inplace: bool = False) -> None: + r"""Initialize the functional optimizer wrapper. + + Args: + impl (GradientTransformation): A low level optimizer function, it could be a optimizer + function provided in :mod:`torchopt.alias` or a customized :func:`torchopt.chain`\ed + transformation. + inplace (bool, optional): The default value of ``inplace`` for each optimization update. + (default: :data:`False`) + """ + if not isinstance(impl, GradientTransformation): + raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation') + + self.impl: GradientTransformation = impl + self.optim_state: OptState | None = UninitializedState() + self.inplace: bool = bool(inplace) + + def step( + self, + loss: torch.Tensor, + params: Params, + inplace: bool | None = None, + ) -> Params: + r"""Compute the gradients of loss to the network parameters and update network parameters. + + Graph of the derivative will be constructed, allowing to compute higher order derivative + products. We use the differentiable optimizer (pass argument inplace=False) to scale the + gradients and update the network parameters without modifying tensors in-place. + + Args: + loss (Tensor): The loss that is used to compute the gradients to network parameters. + params (tree of Tensor): An tree of :class:`torch.Tensor`\s. Specifies what tensors + should be optimized. + inplace (bool or None, optional): Whether to update the parameters in-place. If + :data:`None`, use the default value specified in the constructor. + (default: :data:`None`) + """ + if isinstance(self.optim_state, UninitializedState): + self.optim_state = self.impl.init(params) + + if inplace is None: + inplace = self.inplace + + # Step parameter only + grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True) + updates, self.optim_state = self.impl.update( + grads, + self.optim_state, + params=params, + inplace=inplace, + ) + return apply_updates(params, updates, inplace=inplace) + + def state_dict(self) -> OptState: + """Extract the references of the optimizer states. + + Note that the states are references, so any in-place operations will change the states + inside :class:`FuncOptimizer` at the same time. + """ + return self.optim_state + + def load_state_dict(self, state_dict: OptState) -> None: + """Load the references of the optimizer states.""" + self.optim_state = state_dict diff --git a/torchopt/optim/meta/__init__.py b/torchopt/optim/meta/__init__.py new file mode 100644 index 00000000..9e30dfef --- /dev/null +++ b/torchopt/optim/meta/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Differentiable Meta-Optimizers.""" + +from torchopt.optim.meta.adadelta import MetaAdaDelta, MetaAdadelta +from torchopt.optim.meta.adagrad import MetaAdaGrad, MetaAdagrad +from torchopt.optim.meta.adam import MetaAdam +from torchopt.optim.meta.adamax import MetaAdaMax, MetaAdamax +from torchopt.optim.meta.adamw import MetaAdamW +from torchopt.optim.meta.base import MetaOptimizer +from torchopt.optim.meta.radam import MetaRAdam +from torchopt.optim.meta.rmsprop import MetaRMSProp, MetaRMSprop +from torchopt.optim.meta.sgd import MetaSGD diff --git a/torchopt/optim/meta/adadelta.py b/torchopt/optim/meta/adadelta.py new file mode 100644 index 00000000..eb386ae3 --- /dev/null +++ b/torchopt/optim/meta/adadelta.py @@ -0,0 +1,82 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Differentiable Adadelta optimizer.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from torchopt import alias +from torchopt.optim.meta.base import MetaOptimizer + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import ScalarOrSchedule + + +__all__ = ['MetaAdaDelta', 'MetaAdadelta'] + + +class MetaAdaDelta(MetaOptimizer): + """The differentiable AdaDelta optimizer. + + See Also: + - The functional AdaDelta optimizer: :func:`torchopt.adadetla`. + - The classic AdaDelta optimizer: :class:`torchopt.Adadelta`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + module: nn.Module, + lr: ScalarOrSchedule = 1.0, + rho: float = 0.9, + eps: float = 1e-6, + weight_decay: float = 0.0, + *, + moment_requires_grad: bool = True, + ) -> None: + """Initialize the meta AdaDelta optimizer. + + Args: + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + rho (float, optional): Coefficients used for computing running averages of gradient and its square. + (default: :const:`0.9`) + eps (float, optional): A small constant applied to the square root (as in the AdaDelta paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + """ + super().__init__( + module, + alias.adadelta( + lr=lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + moment_requires_grad=moment_requires_grad, + ), + ) + + +MetaAdadelta = MetaAdaDelta # alias for PyTorch compatibility diff --git a/torchopt/optim/meta/adagrad.py b/torchopt/optim/meta/adagrad.py new file mode 100644 index 00000000..129c1338 --- /dev/null +++ b/torchopt/optim/meta/adagrad.py @@ -0,0 +1,84 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Differentiable AdaGrad optimizer.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from torchopt import alias +from torchopt.optim.meta.base import MetaOptimizer + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import ScalarOrSchedule + + +__all__ = ['MetaAdaGrad', 'MetaAdagrad'] + + +class MetaAdaGrad(MetaOptimizer): + """The differentiable AdaGrad optimizer. + + See Also: + - The functional AdaGrad optimizer: :func:`torchopt.adagrad`. + - The classic AdaGrad optimizer: :class:`torchopt.Adagrad`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + module: nn.Module, + lr: ScalarOrSchedule = 1e-2, + lr_decay: float = 0.0, + weight_decay: float = 0.0, + initial_accumulator_value: float = 0.0, + eps: float = 1e-10, + *, + maximize: bool = False, + ) -> None: + """Initialize the meta AdaGrad optimizer. + + Args: + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-2`) + lr_decay (float, optional): Learning rate decay. (default: :const:`0.0`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + initial_accumulator_value (float, optional): Initial value for the accumulator. + (default: :const:`0.0`) + eps (float, optional): A small constant applied to denominator outside of the square + root (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-10`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + """ + super().__init__( + module, + alias.adagrad( + lr=lr, + lr_decay=lr_decay, + weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + eps=eps, + maximize=maximize, + ), + ) + + +MetaAdagrad = MetaAdaGrad # alias for PyTorch compatibility diff --git a/torchopt/optim/meta/adam.py b/torchopt/optim/meta/adam.py new file mode 100644 index 00000000..7a78ea7f --- /dev/null +++ b/torchopt/optim/meta/adam.py @@ -0,0 +1,92 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Differentiable Adam optimizer.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from torchopt import alias +from torchopt.optim.meta.base import MetaOptimizer + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import ScalarOrSchedule + + +__all__ = ['MetaAdam'] + + +class MetaAdam(MetaOptimizer): + """The differentiable Adam optimizer. + + See Also: + - The functional Adam optimizer: :func:`torchopt.adam`. + - The classic Adam optimizer: :class:`torchopt.Adam`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + module: nn.Module, + lr: ScalarOrSchedule = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + *, + eps_root: float = 0.0, + moment_requires_grad: bool = True, + maximize: bool = False, + use_accelerated_op: bool = False, + ) -> None: + """Initialize the meta-Adam optimizer. + + Args: + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square + root (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + eps_root (float, optional): A small constant applied to denominator inside the square + root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for + example when computing (meta-)gradients through Adam. (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. + (default: :data:`False`) + """ + super().__init__( + module, + alias.adam( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + eps_root=eps_root, + moment_requires_grad=moment_requires_grad, + maximize=maximize, + use_accelerated_op=use_accelerated_op, + ), + ) diff --git a/torchopt/optim/meta/adamax.py b/torchopt/optim/meta/adamax.py new file mode 100644 index 00000000..d6b40427 --- /dev/null +++ b/torchopt/optim/meta/adamax.py @@ -0,0 +1,82 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Differentiable Adamax optimizer.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from torchopt import alias +from torchopt.optim.meta.base import MetaOptimizer + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import ScalarOrSchedule + + +__all__ = ['MetaAdaMax', 'MetaAdamax'] + + +class MetaAdaMax(MetaOptimizer): + """The differentiable AdaMax optimizer. + + See Also: + - The functional AdaMax optimizer: :func:`torchopt.adamax`. + - The classic AdaMax optimizer: :class:`torchopt.Adamax`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + module: nn.Module, + lr: ScalarOrSchedule = 2e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + *, + moment_requires_grad: bool = True, + ) -> None: + """Initialize the meta AdaMax optimizer. + + Args: + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to the square root (as in the AdaMax paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + """ + super().__init__( + module, + alias.adamax( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + moment_requires_grad=moment_requires_grad, + ), + ) + + +MetaAdamax = MetaAdaMax # alias for PyTorch compatibility diff --git a/torchopt/optim/meta/adamw.py b/torchopt/optim/meta/adamw.py new file mode 100644 index 00000000..62864582 --- /dev/null +++ b/torchopt/optim/meta/adamw.py @@ -0,0 +1,103 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Differentiable AdamW optimizer.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable + +from torchopt import alias +from torchopt.optim.meta.base import MetaOptimizer + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import OptState, Params, ScalarOrSchedule + + +__all__ = ['MetaAdamW'] + + +class MetaAdamW(MetaOptimizer): + """The differentiable AdamW optimizer. + + See Also: + - The functional AdamW optimizer: :func:`torchopt.adamw`. + - The classic AdamW optimizer: :class:`torchopt.AdamW`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + module: nn.Module, + lr: ScalarOrSchedule = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + *, + eps_root: float = 0.0, + mask: OptState | Callable[[Params], OptState] | None = None, + moment_requires_grad: bool = False, + maximize: bool = False, + use_accelerated_op: bool = False, + ) -> None: + """Initialize the meta-AdamW optimizer. + + Args: + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square + root (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Strength of the weight decay regularization. Note that + this weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al., 2019) where + the weight decay is only multiplied with the "schedule multiplier", but not the base + learning rate. (default: :const:`1e-2`) + eps_root (float, optional): A small constant applied to denominator inside the square + root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for + example when computing (meta-)gradients through Adam. (default: :const:`0.0`) + mask (tree of Tensor, callable, or None, optional): + A tree with same structure as (or a prefix of) the params pytree, or a function that + returns such a pytree given the params/updates. The leaves should be booleans, + :data:`True` for leaves/subtrees you want to apply the weight decay to, and + :data:`False` for those you want to skip. Note that the Adam gradient + transformations are applied to all parameters. (default: :data:`None`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. + (default: :data:`False`) + """ + super().__init__( + module, + alias.adamw( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + eps_root=eps_root, + mask=mask, + moment_requires_grad=moment_requires_grad, + maximize=maximize, + use_accelerated_op=use_accelerated_op, + ), + ) diff --git a/torchopt/optim/meta/base.py b/torchopt/optim/meta/base.py new file mode 100644 index 00000000..73ecdde7 --- /dev/null +++ b/torchopt/optim/meta/base.py @@ -0,0 +1,114 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The base class for differentiable meta-optimizers.""" + +from __future__ import annotations + +from typing import Sequence + +import torch +import torch.nn as nn + +from torchopt import pytree +from torchopt.base import UninitializedState +from torchopt.typing import GradientTransformation, ModuleTensorContainers, OptState, TupleOfTensors +from torchopt.update import apply_updates +from torchopt.utils import extract_module_containers + + +__all__ = ['MetaOptimizer'] + + +class MetaOptimizer: + """The base class for high-level differentiable optimizers.""" + + def __init__(self, module: nn.Module, impl: GradientTransformation) -> None: + r"""Initialize the meta-optimizer. + + Args: + module (nn.Module): A network whose parameters should be optimized. + impl (GradientTransformation): A low level optimizer function, it could be a optimizer + function provided in :mod:`torchopt.alias` or a customized :func:`torchopt.chain`\ed + transformation. + Note that using ``MetaOptimizer(sgd(moment_requires_grad=True))`` or + ``MetaOptimizer(chain(sgd(moment_requires_grad=True)))`` is equivalent to + :class:`torchopt.MetaSGD`. + """ + if not isinstance(impl, GradientTransformation): + raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation') + + self.impl: GradientTransformation = impl + self.param_containers_groups: list[ModuleTensorContainers] = [] + self.state_groups: list[OptState] = [] + + self.add_param_group(module) + + def step(self, loss: torch.Tensor) -> None: # pylint: disable=too-many-locals + """Compute the gradients of the loss to the network parameters and update network parameters. + + Graph of the derivative will be constructed, allowing to compute higher order derivative + products. We use the differentiable optimizer (pass argument ``inplace=False``) to scale the + gradients and update the network parameters without modifying tensors in-place. + + Args: + loss (torch.Tensor): The loss that is used to compute the gradients to the network + parameters. + """ + # Step parameter only + for i, (param_container, state) in enumerate( + zip(self.param_containers_groups, self.state_groups), + ): + flat_params: TupleOfTensors + flat_params, container_treespec = pytree.tree_flatten_as_tuple(param_container) # type: ignore[arg-type] + if isinstance(state, UninitializedState): + state = self.impl.init(flat_params) + grads = torch.autograd.grad( + loss, + flat_params, + create_graph=True, + allow_unused=True, + ) + updates, new_state = self.impl.update( + grads, + state, + params=flat_params, + inplace=False, + ) + self.state_groups[i] = new_state + flat_new_params = apply_updates(flat_params, updates, inplace=False) + new_params: ModuleTensorContainers = pytree.tree_unflatten( # type: ignore[assignment] + container_treespec, + flat_new_params, + ) + for container, new_param in zip(param_container, new_params): + container.update(new_param) + + def add_param_group(self, module: nn.Module) -> None: + """Add a param group to the optimizer's ``state_groups``.""" + params_container = extract_module_containers(module, with_buffers=False)[0] + self.param_containers_groups.append(params_container) + self.state_groups.append(UninitializedState()) + + def state_dict(self) -> tuple[OptState, ...]: + """Extract the references of the optimizer states. + + Note that the states are references, so any in-place operations will change the states + inside :class:`MetaOptimizer` at the same time. + """ + return tuple(self.state_groups) + + def load_state_dict(self, state_dict: Sequence[OptState]) -> None: + """Load the references of the optimizer states.""" + self.state_groups[:] = list(state_dict) diff --git a/torchopt/optim/meta/radam.py b/torchopt/optim/meta/radam.py new file mode 100644 index 00000000..bb07b5ba --- /dev/null +++ b/torchopt/optim/meta/radam.py @@ -0,0 +1,79 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Differentiable RAdam optimizer.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from torchopt import alias +from torchopt.optim.meta.base import MetaOptimizer + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import ScalarOrSchedule + + +__all__ = ['MetaRAdam'] + + +class MetaRAdam(MetaOptimizer): + """The differentiable RAdam optimizer. + + See Also: + - The functional RAdam optimizer: :func:`torchopt.radan`. + - The classic RAdam optimizer: :class:`torchopt.RAdam`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + module: nn.Module, + lr: ScalarOrSchedule = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + *, + moment_requires_grad: bool = True, + ) -> None: + """Initialize the meta-RAdam optimizer. + + Args: + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to the square root (as in the RAdam paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + """ + super().__init__( + module, + alias.radam( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + moment_requires_grad=moment_requires_grad, + ), + ) diff --git a/torchopt/optim/meta/rmsprop.py b/torchopt/optim/meta/rmsprop.py new file mode 100644 index 00000000..a8b4abfa --- /dev/null +++ b/torchopt/optim/meta/rmsprop.py @@ -0,0 +1,90 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Differentiable RMSProp optimizer.""" + +import torch.nn as nn + +from torchopt import alias +from torchopt.optim.meta.base import MetaOptimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['MetaRMSProp', 'MetaRMSprop'] + + +class MetaRMSProp(MetaOptimizer): + """The differentiable RMSProp optimizer. + + See Also: + - The functional RMSProp optimizer: :func:`torchopt.rmsprop`. + - The classic RMSProp optimizer: :class:`torchopt.RMSProp`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + module: nn.Module, + lr: ScalarOrSchedule = 1e-2, + alpha: float = 0.99, + eps: float = 1e-8, + weight_decay: float = 0.0, + momentum: float = 0.0, + centered: bool = False, + *, + initial_scale: float = 0.0, + nesterov: bool = False, + maximize: bool = False, + ) -> None: + """Initialize the meta-RMSProp optimizer. + + Args: + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-2`) + alpha (float, optional): Smoothing constant, the decay used to track the magnitude of + previous gradients. (default: :const:`0.99`) + eps (float, optional): A small numerical constant to avoid dividing by zero when + rescaling. (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + momentum (float, optional): The decay rate used by the momentum term. The momentum is + not used when it is set to :const:`0.0`. (default: :const:`0.0`) + centered (bool, optional): If :data:`True`, use the variance of the past gradients to + rescale the latest gradients. (default: :data:`False`) + initial_scale (float, optional): Initialization of accumulators tracking the magnitude + of previous updates. PyTorch uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When + reproducing results from a paper, verify the value used by the authors. + (default: :data:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + """ + super().__init__( + module, + alias.rmsprop( + lr=lr, + alpha=alpha, + eps=eps, + weight_decay=weight_decay, + momentum=momentum, + centered=centered, + initial_scale=initial_scale, + nesterov=nesterov, + maximize=maximize, + ), + ) + + +MetaRMSprop = MetaRMSProp # alias for PyTorch compatibility diff --git a/torchopt/_src/optimizer/meta/sgd.py b/torchopt/optim/meta/sgd.py similarity index 51% rename from torchopt/_src/optimizer/meta/sgd.py rename to torchopt/optim/meta/sgd.py index b8ae5d24..81e04413 100644 --- a/torchopt/_src/optimizer/meta/sgd.py +++ b/torchopt/optim/meta/sgd.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Differentiable SGD optimizer.""" import torch.nn as nn -from torchopt._src.alias import sgd -from torchopt._src.optimizer.meta.base import MetaOptimizer -from torchopt._src.typing import ScalarOrSchedule +from torchopt import alias +from torchopt.optim.meta.base import MetaOptimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['MetaSGD'] class MetaSGD(MetaOptimizer): @@ -31,7 +35,7 @@ class MetaSGD(MetaOptimizer): # pylint: disable-next=too-many-arguments def __init__( self, - net: nn.Module, + module: nn.Module, lr: ScalarOrSchedule, momentum: float = 0.0, weight_decay: float = 0.0, @@ -39,31 +43,28 @@ def __init__( nesterov: bool = False, moment_requires_grad: bool = True, maximize: bool = False, - ): - """The :meth:`init` function. + ) -> None: + """Initialize the meta-SGD optimizer. Args: - net: (nn.Module) - A network whose parameters should be optimized. - lr: This is a fixed global scaling factor. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - dampening: (default: :const:`0.0`) - Dampening for momentum. - nesterov: (default: :const:`False`) - Whether to use Nesterov momentum. - moment_requires_grad: (default: :data:`True`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable): This is a fixed global scaling factor or a learning rate + scheduler. + momentum (float, optional): The decay rate used by the momentum term. The momentum is + not used when it is set to :const:`0.0`. (default: :const:`0.0`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + dampening (float, optional): Dampening for momentum. (default: :const:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) """ super().__init__( - net, - sgd( + module, + alias.sgd( lr=lr, momentum=momentum, weight_decay=weight_decay, diff --git a/torchopt/optim/radam.py b/torchopt/optim/radam.py new file mode 100644 index 00000000..20e9dd22 --- /dev/null +++ b/torchopt/optim/radam.py @@ -0,0 +1,75 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RAdam optimizer.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterable + +from torchopt import alias +from torchopt.optim.base import Optimizer + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import ScalarOrSchedule + + +__all__ = ['RAdam'] + + +class RAdam(Optimizer): + """The classic RAdam optimizer. + + See Also: + - The functional Adam optimizer: :func:`torchopt.radam`. + - The differentiable meta-RAdam optimizer: :class:`torchopt.MetaRAdam`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + params: Iterable[torch.Tensor], + lr: ScalarOrSchedule = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + ) -> None: + r"""Initialize the RAdam optimizer. + + Args: + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to the square root (as in the RAdam paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + """ + super().__init__( + params, + alias.radam( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + moment_requires_grad=False, + ), + ) diff --git a/torchopt/optim/rmsprop.py b/torchopt/optim/rmsprop.py new file mode 100644 index 00000000..032e5864 --- /dev/null +++ b/torchopt/optim/rmsprop.py @@ -0,0 +1,93 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RMSProp optimizer.""" + +from typing import Iterable + +import torch + +from torchopt import alias +from torchopt.optim.base import Optimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['RMSProp', 'RMSprop'] + + +class RMSProp(Optimizer): + """The classic RMSProp optimizer. + + See Also: + - The functional RMSProp optimizer: :func:`torchopt.rmsprop`. + - The differentiable meta-RMSProp optimizer: :class:`torchopt.MetaRMSProp`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + params: Iterable[torch.Tensor], + lr: ScalarOrSchedule = 1e-2, + alpha: float = 0.99, + eps: float = 1e-8, + weight_decay: float = 0.0, + momentum: float = 0.0, + centered: bool = False, + *, + initial_scale: float = 0.0, + nesterov: bool = False, + maximize: bool = False, + ) -> None: + r"""Initialize the RMSProp optimizer. + + Args: + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-2`) + alpha (float, optional): Smoothing constant, the decay used to track the magnitude of + previous gradients. (default: :const:`0.99`) + eps (float, optional): A small numerical constant to avoid dividing by zero when + rescaling. (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + momentum (float, optional): The decay rate used by the momentum term. The momentum is + not used when it is set to :const:`0.0`. (default: :const:`0.0`) + centered (bool, optional): If :data:`True`, use the variance of the past gradients to + rescale the latest gradients. (default: :data:`False`) + initial_scale (float, optional): Initialization of accumulators tracking the magnitude + of previous updates. PyTorch uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When + reproducing results from a paper, verify the value used by the authors. + (default: :data:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + """ + super().__init__( + params, + alias.rmsprop( + lr=lr, + alpha=alpha, + eps=eps, + weight_decay=weight_decay, + momentum=momentum, + centered=centered, + initial_scale=initial_scale, + nesterov=nesterov, + maximize=maximize, + ), + ) + + +RMSprop = RMSProp # alias for PyTorch compatibility diff --git a/torchopt/_src/optimizer/sgd.py b/torchopt/optim/sgd.py similarity index 51% rename from torchopt/_src/optimizer/sgd.py rename to torchopt/optim/sgd.py index a7f415f6..27cd53c1 100644 --- a/torchopt/_src/optimizer/sgd.py +++ b/torchopt/optim/sgd.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,14 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""SGD optimizer.""" from typing import Iterable import torch -from torchopt._src.alias import sgd -from torchopt._src.optimizer.base import Optimizer -from torchopt._src.typing import ScalarOrSchedule +from torchopt import alias +from torchopt.optim.base import Optimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['SGD'] class SGD(Optimizer): @@ -40,28 +44,29 @@ def __init__( dampening: float = 0.0, nesterov: bool = False, maximize: bool = False, - ): - r"""The :meth:`init` function. + ) -> None: + r"""Initialize the SGD optimizer. Args: - params: (iterable of torch.Tensor) - An iterable of :class:`torch.Tensor`\s. Specifies what tensors should be optimized. - lr: This is a fixed global scaling factor. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - dampening: (default: :const:`0.0`) - Dampening for momentum. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable): This is a fixed global scaling factor or a learning rate + scheduler. + momentum (float, optional): The decay rate used by the momentum term. The momentum is + not used when it is set to :const:`0.0`. (default: :const:`0.0`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + dampening (float, optional): Dampening for momentum. (default: :const:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) """ super().__init__( params, - sgd( + alias.sgd( lr=lr, momentum=momentum, weight_decay=weight_decay, diff --git a/torchopt/py.typed b/torchopt/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/torchopt/pytree.py b/torchopt/pytree.py new file mode 100644 index 00000000..53abc2d2 --- /dev/null +++ b/torchopt/pytree.py @@ -0,0 +1,202 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The PyTree utilities.""" + +from __future__ import annotations + +import functools +import operator +from typing import TYPE_CHECKING, Callable + +import optree +import optree.typing as typing # pylint: disable=unused-import +import torch +import torch.distributed.rpc as rpc +from optree import * # pylint: disable=wildcard-import,unused-wildcard-import + + +if TYPE_CHECKING: + from torchopt.typing import Future, RRef, Scalar, T, TensorTree + + +__all__ = [ + *optree.__all__, + 'tree_flatten_as_tuple', + 'tree_pos', + 'tree_neg', + 'tree_add', + 'tree_add_scalar_mul', + 'tree_sub', + 'tree_sub_scalar_mul', + 'tree_mul', + 'tree_matmul', + 'tree_scalar_mul', + 'tree_truediv', + 'tree_vdot_real', + 'tree_wait', +] + + +def tree_flatten_as_tuple( + tree: PyTree[T], + is_leaf: Callable[[T], bool] | None = None, + *, + none_is_leaf: bool = False, + namespace: str = '', +) -> tuple[tuple[T, ...], PyTreeSpec]: + """Flatten a pytree to a tuple of leaves and a PyTreeSpec. + + Args: + tree (pytree): The pytree to flatten. + is_leaf (callable or None, optional): An optionally specified function that returns + :data:`True` if a given node is a leaf. (default: :data:`None`) + none_is_leaf (bool, optional): If :data:`True`, :data:`None` is considered a leaf rather + than a internal node with no children. (default: :data:`False`) + namespace (str, optional): The namespace of custom tree node types. (default: :const:`''`) + + Returns: + A tuple of (leaves, treespec). + """ + leaves, treespec = tree_flatten(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) + return tuple(leaves), treespec + + +def acc_add(*args: T) -> T: + """Accumulate addition.""" + return functools.reduce(operator.add, args) + + +def acc_mul(*args: T) -> T: + """Accumulate multiplication.""" + return functools.reduce(operator.mul, args) + + +def acc_matmul(*args: T) -> T: + """Accumulate matrix multiplication.""" + return functools.reduce(operator.matmul, args) + + +def tree_pos(tree: PyTree[T]) -> PyTree[T]: + """Apply :func:`operator.pos` over leaves.""" + return tree_map(operator.pos, tree) + + +def tree_neg(tree: PyTree[T]) -> PyTree[T]: + """Apply :func:`operator.neg` over leaves.""" + return tree_map(operator.neg, tree) + + +def tree_add(*trees: PyTree[T]) -> PyTree[T]: + """Tree addition over leaves.""" + return tree_map(acc_add, *trees) + + +def tree_add_scalar_mul( + tree_x: TensorTree, + tree_y: TensorTree, + alpha: Scalar | None = None, +) -> TensorTree: + """Compute ``tree_x + alpha * tree_y``.""" + if alpha is None: + return tree_map(lambda x, y: x.add(y), tree_x, tree_y) + return tree_map(lambda x, y: x.add(y, alpha=alpha), tree_x, tree_y) + + +def tree_sub(minuend_tree: PyTree[T], subtrahend_tree: PyTree[T]) -> PyTree[T]: + """Tree subtraction over leaves.""" + return tree_map(operator.sub, minuend_tree, subtrahend_tree) + + +def tree_sub_scalar_mul( + tree_x: TensorTree, + tree_y: TensorTree, + alpha: Scalar | None = None, +) -> TensorTree: + """Compute ``tree_x - alpha * tree_y``.""" + if alpha is None: + return tree_map(lambda x, y: x.sub(y), tree_x, tree_y) + return tree_map(lambda x, y: x.sub(y, alpha=alpha), tree_x, tree_y) + + +def tree_mul(*trees: PyTree[T]) -> PyTree[T]: + """Tree multiplication over leaves.""" + return tree_map(acc_mul, *trees) + + +def tree_matmul(*trees: PyTree[T]) -> PyTree[T]: + """Tree matrix multiplication over leaves.""" + return tree_map(acc_matmul, *trees) + + +def tree_scalar_mul(scalar: Scalar, multiplicand_tree: PyTree[T]) -> PyTree[T]: + """Tree scalar multiplication over leaves.""" + return tree_map(lambda x: scalar * x, multiplicand_tree) + + +def tree_truediv(dividend_tree: PyTree[T], divisor_tree: PyTree[T]) -> PyTree[T]: + """Tree division over leaves.""" + return tree_map(operator.truediv, dividend_tree, divisor_tree) + + +def _vdot_real_kernel(x: torch.Tensor, y: torch.Tensor) -> float: + """Compute ``dot(x.conj(), y).real``.""" + x = x.contiguous().view(-1) + y = y.contiguous().view(-1) + vdot = torch.dot(x.real, y.real).item() + if x.is_complex() and y.is_complex(): + vdot += torch.dot(x.imag, y.imag).item() + return vdot + + +def tree_vdot_real(tree_x: TensorTree, tree_y: TensorTree) -> float: + """Compute ``dot(tree_x.conj(), tree_y).real.sum()``.""" + leaves_x, treespec = tree_flatten(tree_x) + leaves_y = treespec.flatten_up_to(tree_y) + return sum(map(_vdot_real_kernel, leaves_x, leaves_y)) # type: ignore[arg-type] + + +def tree_wait(future_tree: PyTree[Future[T]]) -> PyTree[T]: + r"""Convert a tree of :class:`Future`\s to a tree of results.""" + futures, treespec = tree_flatten(future_tree) + + results = torch.futures.wait_all(futures) + + return tree_unflatten(treespec, results) + + +if rpc.is_available(): # pragma: no cover + + def tree_as_rref(tree: PyTree[T]) -> PyTree[RRef[T]]: + r"""Convert a tree of local objects to a tree of :class:`RRef`\s.""" + # pylint: disable-next=import-outside-toplevel,redefined-outer-name,reimported + from torch.distributed.rpc import RRef + + return tree_map(RRef, tree) + + def tree_to_here( + rref_tree: PyTree[RRef[T]], + timeout: float = rpc.api.UNSET_RPC_TIMEOUT, + ) -> PyTree[T]: + r"""Convert a tree of :class:`RRef`\s to a tree of local objects.""" + return tree_map(lambda x: x.to_here(timeout=timeout), rref_tree) + + def tree_local_value(rref_tree: PyTree[RRef[T]]) -> PyTree[T]: + r"""Return the local value of a tree of :class:`RRef`\s.""" + return tree_map(lambda x: x.local_value(), rref_tree) + + __all__ += ['tree_as_rref', 'tree_to_here'] + + +del optree, rpc diff --git a/torchopt/schedule/__init__.py b/torchopt/schedule/__init__.py new file mode 100644 index 00000000..d3d3eff5 --- /dev/null +++ b/torchopt/schedule/__init__.py @@ -0,0 +1,38 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/schedule.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Learning rate schedules.""" + +from torchopt.schedule.exponential_decay import exponential_decay +from torchopt.schedule.polynomial import linear_schedule, polynomial_schedule + + +__all__ = ['exponential_decay', 'linear_schedule', 'polynomial_schedule'] diff --git a/torchopt/schedule/exponential_decay.py b/torchopt/schedule/exponential_decay.py new file mode 100644 index 00000000..c19c54b9 --- /dev/null +++ b/torchopt/schedule/exponential_decay.py @@ -0,0 +1,123 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/schedule.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Exponential learning rate decay.""" + +from __future__ import annotations + +import logging +import math +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from torchopt.typing import Numeric, Scalar, Schedule + + +__all__ = ['exponential_decay'] + + +# pylint: disable-next=too-many-arguments +def exponential_decay( + init_value: Scalar, + decay_rate: Scalar, + transition_begin: int = 0, + transition_steps: int = 1, + staircase: bool = False, + end_value: float | None = None, +) -> Schedule: + """Construct a schedule with either continuous or discrete exponential decay. + + This function applies an exponential decay function to a provided initial value. The function + returns the decayed value as follows: + + .. code-block:: python + + decayed_value = init_value * decay_rate**(count / transition_steps) + + If the argument ``staircase`` is :data:`True`, then ``count / transition_steps`` is an integer + division and the decayed value follows a staircase function. + + Args: + init_value (float or Tensor): Initial value for the scalar to be annealed. + decay_rate (float or Tensor): The decay rate. + transition_begin (int, optional): Must be *positive*. After how many steps to start + annealing (before this many steps the scalar value is held fixed at ``init_value``). + (default: :const:`0`) + transition_steps (int, optional): Number of steps over which annealing takes place, the + scalar starts changing at ``transition_begin`` steps and completes the transition by + ``transition_begin + transition_steps`` steps. If ``transition_steps <= 0``, then the + entire annealing process is disabled and the value is held fixed at ``init_value``. + (default: :const:`1`) + staircase (bool, optional): If :data:`True`, decay the scalar at discrete intervals. + (default: :data:`False`) + end_value (float or Tensor, optional): End value of the scalar to be annealed. + (default: :data:`None`) + + Returns: + schedule: A function that maps step counts to values. + """ + if transition_steps is not None and transition_steps <= 0: # pragma: no cover + logging.info( + 'An exponential schedule was set with a non-positive `transition_steps`' + ' value; this will result in a constant schedule with value ' + '`init_value`.', + ) + return lambda count: init_value + + if decay_rate == 0: # pragma: no cover + logging.info( + 'An exponential schedule was set with a zero `decay_rate` value; ' + 'this will result in a constant schedule with value `init_value`.', + ) + return lambda count: init_value + + if transition_begin < 0: # pragma: no cover + logging.info( + 'An exponential schedule was set with a negative `transition_begin` ' + 'value; this will result in `transition_begin` falling back to `0`.', + ) + transition_begin = 0 + + if end_value is not None: # pragma: no cover + clip_fn = max if decay_rate < 1.0 else min + + def schedule(count: Numeric) -> Numeric: + decreased_count = count - transition_begin + p = decreased_count / transition_steps + if staircase: + p = math.floor(p) + decayed_value = init_value if decreased_count <= 0.0 else init_value * (decay_rate**p) + if end_value is not None: + return clip_fn(decayed_value, end_value) + return decayed_value + + return schedule diff --git a/torchopt/_src/schedule.py b/torchopt/schedule/polynomial.py similarity index 59% rename from torchopt/_src/schedule.py rename to torchopt/schedule/polynomial.py index d7367c2b..d2a5160c 100644 --- a/torchopt/_src/schedule.py +++ b/torchopt/schedule/polynomial.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,14 +29,22 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Polynomial learning rate schedules.""" + +from __future__ import annotations import logging +from typing import TYPE_CHECKING import numpy as np +import torch + + +if TYPE_CHECKING: + from torchopt.typing import Numeric, Scalar, Schedule -from torchopt._src import base -from torchopt._src.typing import Scalar -from torchopt._src.utils import pytree + +__all__ = ['linear_schedule', 'polynomial_schedule'] def polynomial_schedule( @@ -45,48 +53,45 @@ def polynomial_schedule( power: Scalar, transition_steps: int, transition_begin: int = 0, -) -> base.Schedule: - """Constructs a schedule with polynomial transition from init to end value. +) -> Schedule: + """Construct a schedule with polynomial transition from init to end value. Args: - init_value: Initial value for the scalar to be annealed. - end_value: End value of the scalar to be annealed. - power: The power of the polynomial used to transition from ``init`` to ``end``. - transition_steps: - Number of steps over which annealing takes place, the scalar starts changing at - ``transition_begin`` steps and completes the transition by - ``transition_begin + transition_steps`` steps. - If ``transition_steps <= 0``, then the entire annealing process is disabled and the - value is held fixed at ``init_value``. - transition_begin: - Must be *positive*. After how many steps to start annealing (before this many steps the - scalar value is held fixed at ``init_value``). + init_value (float or Tensor): Initial value for the scalar to be annealed. + end_value (float or Tensor): End value of the scalar to be annealed. + power (float or Tensor): The power of the polynomial used to transition from ``init`` to + ``end``. + transition_steps (int): Number of steps over which annealing takes place, the scalar starts + changing at ``transition_begin`` steps and completes the transition by + ``transition_begin + transition_steps`` steps. If ``transition_steps <= 0``, then the + entire annealing process is disabled and the value is held fixed at ``init_value``. + transition_begin (int, optional): Must be *positive*. After how many steps to start + annealing (before this many steps the scalar value is held fixed at ``init_value``). + (default: :const:`0`) Returns: schedule: A function that maps step counts to values. """ - if transition_steps <= 0: + if transition_steps <= 0: # pragma: no cover logging.info( 'A polynomial schedule was set with a non-positive `transition_steps` value; this ' - 'results in a constant schedule with value `init_value`.' + 'results in a constant schedule with value `init_value`.', ) return lambda count: init_value - if transition_begin < 0: + if transition_begin < 0: # pragma: no cover logging.info( 'An exponential schedule was set with a negative `transition_begin` value; this will ' - 'result in `transition_begin` falling back to `0`.' + 'result in `transition_begin` falling back to `0`.', ) transition_begin = 0 - def schedule(count): - def impl(count): - count = np.clip(count - transition_begin, 0, transition_steps) - frac = 1 - count / transition_steps - return (init_value - end_value) * (frac**power) + end_value - - return pytree.tree_map(impl, count) + def schedule(count: Numeric) -> Numeric: + clip = torch.clamp if isinstance(count, torch.Tensor) else np.clip + count = clip(count - transition_begin, 0, transition_steps) # type: ignore[operator] + frac = 1.0 - count / transition_steps + return (init_value - end_value) * (frac**power) + end_value return schedule @@ -97,7 +102,7 @@ def linear_schedule( end_value: Scalar, transition_steps: int, transition_begin: int = 0, -) -> base.Schedule: +) -> Schedule: """Alias polynomial schedule to linear schedule for convenience.""" return polynomial_schedule( init_value=init_value, diff --git a/torchopt/transform/__init__.py b/torchopt/transform/__init__.py new file mode 100644 index 00000000..fa59a43b --- /dev/null +++ b/torchopt/transform/__init__.py @@ -0,0 +1,63 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset transformations.""" + +from torchopt.transform.add_decayed_weights import add_decayed_weights, masked +from torchopt.transform.nan_to_num import nan_to_num +from torchopt.transform.scale import scale +from torchopt.transform.scale_by_adadelta import scale_by_adadelta +from torchopt.transform.scale_by_adam import scale_by_accelerated_adam, scale_by_adam +from torchopt.transform.scale_by_adamax import scale_by_adamax +from torchopt.transform.scale_by_radam import scale_by_radam +from torchopt.transform.scale_by_rms import scale_by_rms +from torchopt.transform.scale_by_rss import scale_by_rss +from torchopt.transform.scale_by_schedule import scale_by_schedule +from torchopt.transform.scale_by_stddev import scale_by_stddev +from torchopt.transform.trace import trace + + +__all__ = [ + 'add_decayed_weights', + 'masked', + 'nan_to_num', + 'scale', + 'scale_by_accelerated_adam', + 'scale_by_adadelta', + 'scale_by_adam', + 'scale_by_adamax', + 'scale_by_radam', + 'scale_by_rms', + 'scale_by_rss', + 'scale_by_schedule', + 'scale_by_stddev', + 'trace', +] diff --git a/torchopt/transform/add_decayed_weights.py b/torchopt/transform/add_decayed_weights.py new file mode 100644 index 00000000..0cb67837 --- /dev/null +++ b/torchopt/transform/add_decayed_weights.py @@ -0,0 +1,262 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py +# https://github.com/deepmind/optax/blob/master/optax/_src/wrappers.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset transformations for adding weight decay to updates.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, NamedTuple + +from torchopt import pytree +from torchopt.base import EmptyState, GradientTransformation, identity +from torchopt.transform.utils import tree_map_flat, tree_map_flat_ + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import OptState, Params, Updates + + +__all__ = ['add_decayed_weights', 'masked'] + + +class MaskedState(NamedTuple): + """Maintain inner transform state for masked transformations.""" + + inner_state: Any + + +class MaskedNode(NamedTuple): + """A node used to mask out unspecified parts of a tree. + + This node is ignored when mapping functions across the tree e.g. using :func:`pytree.tree_map` + since it is a container without children. It can therefore be used to mask out parts of a tree. + """ + + +def masked( + inner: GradientTransformation, + mask: OptState | Callable[[Params], OptState] | None = None, +) -> GradientTransformation: + """Mask updates so only some are transformed, the rest are passed through. + + For example, it is common to skip weight decay for BatchNorm scale and all bias parameters. In + many networks, these are the only parameters with only one dimension. So, you may create a mask + function to mask these out as follows:: + mask_fn = lambda p: pytree.tree_map(lambda x: x.ndim != 1, p) + weight_decay = torchopt.masked(torchopt.add_decayed_weights(0.001), mask_fn) + You may alternatively create the mask pytree upfront:: + mask = pytree.tree_map(lambda x: x.ndim != 1, params) + weight_decay = torchopt.masked(torchopt.add_decayed_weights(0.001), mask) + For the ``inner`` transform, state will only be stored for the parameters that have a mask value + of :data:`True`. + + Args: + inner (GradientTransformation): Inner transformation to mask. + mask (tree of Tensor, callable, or None, optional): A tree with same structure as (or a + prefix of) the params tree, or a function that returns such a tree given the + params/updates. The leaves should be booleans, :data:`True` for leaves/subtrees you want + to apply the transformation to, and :data:`False` for those you want to skip. + (default: :data:`None`) + + Returns: + A :class:`GradientTransformation` wrapping ``inner``. + """ + return _masked(inner=inner, mask=mask, already_flattened=False) + + +def _masked_flat( + inner: GradientTransformation, + mask: OptState | Callable[[Params], OptState] | None = None, +) -> GradientTransformation: + return _masked(inner, mask, already_flattened=True) + + +def _masked( + inner: GradientTransformation, + mask: OptState | Callable[[Params], OptState] | None = None, + *, + already_flattened: bool = False, +) -> GradientTransformation: + if already_flattened: # noqa: SIM108 + tree_map = tree_map_flat + else: + tree_map = pytree.tree_map # type: ignore[assignment] + + def tree_mask(params: Params, mask_tree: OptState) -> Params: + return tree_map(lambda p, m: p if m else MaskedNode(), params, mask_tree) + + def init_fn(params: Params) -> OptState: + mask_tree = mask(params) if callable(mask) else mask + masked_params = tree_mask(params, mask_tree) + return MaskedState(inner_state=inner.init(masked_params)) + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, + inplace: bool = True, + ) -> tuple[Updates, OptState]: + mask_tree = mask(updates) if callable(mask) else mask + masked_updates = tree_mask(updates, mask_tree) + masked_params = None if params is None else tree_mask(params, mask_tree) + + new_masked_updates, new_inner_state = inner.update( + masked_updates, + state.inner_state, + params=masked_params, + inplace=inplace, + ) + + new_updates = tree_map( + lambda old_u, new_u, m: new_u if m else old_u, + updates, + new_masked_updates, + mask_tree, + ) + return new_updates, MaskedState(inner_state=new_inner_state) + + return GradientTransformation(init_fn, update_fn) + + +masked.flat = _masked_flat # type: ignore[attr-defined] +masked.impl = _masked # type: ignore[attr-defined] + + +AddDecayedWeightsState = EmptyState + + +def add_decayed_weights( + weight_decay: float = 0.0, + mask: OptState | Callable[[Params], OptState] | None = None, +) -> GradientTransformation: + """Add parameter scaled by `weight_decay`. + + Args: + weight_decay (float, optional): A scalar weight decay rate. (default: :const:`0.0`) + mask (tree of Tensor, callable, or None, optional): A tree with same structure as (or a + prefix of) the params tree, or a function that returns such a tree given the + params/updates. The leaves should be booleans, :data:`True` for leaves/subtrees you want + to apply the transformation to, and :data:`False` for those you want to skip. + (default: :data:`None`) + + Returns: + An (init_fn, update_fn) tuple. + """ + return _add_decayed_weights( + weight_decay=weight_decay, + mask=mask, + already_flattened=False, + ) + + +def _add_decayed_weights_flat( + weight_decay: float = 0.0, + mask: OptState | Callable[[Params], OptState] | None = None, +) -> GradientTransformation: + return _add_decayed_weights( + weight_decay=weight_decay, + mask=mask, + already_flattened=True, + ) + + +def _add_decayed_weights( # noqa: C901 + weight_decay: float = 0.0, + mask: OptState | Callable[[Params], OptState] | None = None, + *, + already_flattened: bool = False, +) -> GradientTransformation: + # pylint: disable-next=unneeded-not + if not weight_decay >= 0.0: # pragma: no cover + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + + if weight_decay == 0.0 and mask is None: + return identity() + + if already_flattened: + tree_map = tree_map_flat + tree_map_ = tree_map_flat_ + else: + tree_map = pytree.tree_map # type: ignore[assignment] + tree_map_ = pytree.tree_map_ # type: ignore[assignment] + + def init_fn(params: Params) -> OptState: # pylint: disable=unused-argument + return AddDecayedWeightsState() + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, + inplace: bool = True, + ) -> tuple[Updates, OptState]: + assert params is not None, ( + 'Parameters are required for weight decay. ' + 'Call `update(updates, state, params=params)` instead.' + ) + + if inplace: + + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g + if g.requires_grad: + return g.add_(p, alpha=weight_decay) + return g.add_(p.data, alpha=weight_decay) + + tree_map_(f, params, updates) + + else: + + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.add(p, alpha=weight_decay) if g is not None else g + + updates = tree_map(f, params, updates) + + return updates, state + + # If mask is not `None`, apply mask to the gradient transformation. + # E.g. it is common to skip weight decay on bias units and batch stats. + if mask is not None: + return masked.impl( # type: ignore[attr-defined] + inner=GradientTransformation(init_fn, update_fn), + mask=mask, + already_flattened=already_flattened, + ) + return GradientTransformation(init_fn, update_fn) + + +add_decayed_weights.flat = _add_decayed_weights_flat # type: ignore[attr-defined] +add_decayed_weights.impl = _add_decayed_weights # type: ignore[attr-defined] diff --git a/torchopt/transform/nan_to_num.py b/torchopt/transform/nan_to_num.py new file mode 100644 index 00000000..740df1b0 --- /dev/null +++ b/torchopt/transform/nan_to_num.py @@ -0,0 +1,65 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset transformations that replaces updates with non-finite values to the given numbers.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from torchopt import pytree +from torchopt.base import EmptyState, GradientTransformation + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import OptState, Params, Updates + + +def nan_to_num( + nan: float = 0.0, + posinf: float | None = None, + neginf: float | None = None, +) -> GradientTransformation: + """Replace updates with values ``nan`` / ``+inf`` / ``-inf`` to the given numbers. + + Returns: + An ``(init_fn, update_fn)`` tuple. + """ + + def init_fn(params: Params) -> OptState: # pylint: disable=unused-argument + return EmptyState() + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: + if inplace: + + def f(g: torch.Tensor) -> torch.Tensor: + return g.nan_to_num_(nan=nan, posinf=posinf, neginf=neginf) + + else: + + def f(g: torch.Tensor) -> torch.Tensor: + return g.nan_to_num(nan=nan, posinf=posinf, neginf=neginf) + + new_updates = pytree.tree_map(f, updates) + return new_updates, state + + return GradientTransformation(init_fn, update_fn) diff --git a/torchopt/transform/scale.py b/torchopt/transform/scale.py new file mode 100644 index 00000000..2b492bdf --- /dev/null +++ b/torchopt/transform/scale.py @@ -0,0 +1,113 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset transformation for scaling updates by learning rate.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from torchopt import pytree +from torchopt.base import EmptyState, GradientTransformation +from torchopt.transform.utils import tree_map_flat, tree_map_flat_ + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import OptState, Params, Updates + + +__all__ = ['scale'] + + +ScaleState = EmptyState + + +def scale(step_size: float) -> GradientTransformation: + """Scale updates by some fixed scalar ``step_size``. + + Args: + step_size (float): A scalar corresponding to a fixed scaling factor for updates. + + Returns: + An ``(init_fn, update_fn)`` tuple. + """ + return _scale(step_size=step_size, already_flattened=False) + + +def _scale_flat(step_size: float) -> GradientTransformation: + return _scale(step_size=step_size, already_flattened=True) + + +def _scale( + step_size: float, + *, + already_flattened: bool = False, +) -> GradientTransformation: + if already_flattened: + tree_map = tree_map_flat + tree_map_ = tree_map_flat_ + else: + tree_map = pytree.tree_map # type: ignore[assignment] + tree_map_ = pytree.tree_map_ # type: ignore[assignment] + + def init_fn(params: Params) -> OptState: # pylint: disable=unused-argument + return ScaleState() + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: + if inplace: + + def f(g: torch.Tensor) -> torch.Tensor: + return g.mul_(step_size) + + updates = tree_map_(f, updates) + + else: + + def f(g: torch.Tensor) -> torch.Tensor: + return g.mul(step_size) + + updates = tree_map(f, updates) + + return updates, state + + return GradientTransformation(init_fn, update_fn) + + +scale.flat = _scale_flat # type: ignore[attr-defined] +scale.impl = _scale # type: ignore[attr-defined] diff --git a/torchopt/transform/scale_by_adadelta.py b/torchopt/transform/scale_by_adadelta.py new file mode 100644 index 00000000..6d05e5dd --- /dev/null +++ b/torchopt/transform/scale_by_adadelta.py @@ -0,0 +1,160 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Preset transformations for scaling updates by Adam.""" + +# pylint: disable=invalid-name + +from __future__ import annotations + +from typing import TYPE_CHECKING, NamedTuple + +import torch + +from torchopt import pytree +from torchopt.base import GradientTransformation +from torchopt.transform.utils import tree_map_flat, update_moment + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates + + +__all__ = ['scale_by_adadelta'] + + +class ScaleByAdadeltaState(NamedTuple): + """State for the Adadelta algorithm.""" + + mu: Updates + nu: Updates + + +def scale_by_adadelta( + rho: float = 0.9, + eps: float = 1e-6, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """Rescale updates according to the Adadelta algorithm. + + References: + - Zeiler, 2012: https://arxiv.org/abs/1212.5701 + + Args: + rho (float, optional): Decay rate for the squared grads. + (default: :const:`0.9`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-6`) + moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag + ``requires_grad = True``. (default: :data:`False`) + + Returns: + An (init_fn, update_fn) tuple. + """ + return _scale_by_adadelta( + rho=rho, + eps=eps, + moment_requires_grad=moment_requires_grad, + already_flattened=False, + ) + + +def _scale_by_adadelta_flat( + rho: float = 0.9, + eps: float = 1e-6, + moment_requires_grad: bool = False, +) -> GradientTransformation: + return _scale_by_adadelta( + rho=rho, + eps=eps, + moment_requires_grad=moment_requires_grad, + already_flattened=True, + ) + + +def _scale_by_adadelta( + rho: float = 0.9, + eps: float = 1e-6, + moment_requires_grad: bool = False, + *, + already_flattened: bool = False, +) -> GradientTransformation: + # pylint: disable=unneeded-not + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= rho < 1.0: # pragma: no cover + raise ValueError(f'Invalid rho parameter at index 0: {rho}') + # pylint: enable=unneeded-not + + if already_flattened: # noqa: SIM108 + tree_map = tree_map_flat + else: + tree_map = pytree.tree_map # type: ignore[assignment] + + def init_fn(params: Params) -> OptState: + mu = tree_map( # first moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ) + nu = tree_map( # second moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ) + return ScaleByAdadeltaState(mu=mu, nu=nu) + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: + mu = update_moment.impl( # type: ignore[attr-defined] + updates, + state.mu, + rho, + order=2, + inplace=inplace, + already_flattened=already_flattened, + ) + + if inplace: + + def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.mul_(v.add(eps).div_(m.add(eps)).sqrt_()) if g is not None else g + + else: + + def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.mul(v.add(eps).div_(m.add(eps)).sqrt_()) if g is not None else g + + updates = tree_map(f, mu, state.nu, updates) + + nu = update_moment.impl( # type: ignore[attr-defined] + updates, + state.nu, + rho, + order=2, + inplace=inplace, + already_flattened=already_flattened, + ) + + return updates, ScaleByAdadeltaState(mu=mu, nu=nu) + + return GradientTransformation(init_fn, update_fn) + + +scale_by_adadelta.flat = _scale_by_adadelta_flat # type: ignore[attr-defined] +scale_by_adadelta.impl = _scale_by_adadelta # type: ignore[attr-defined] diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py new file mode 100644 index 00000000..d45d1eb2 --- /dev/null +++ b/torchopt/transform/scale_by_adam.py @@ -0,0 +1,420 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset transformations for scaling updates by Adam.""" + +# pylint: disable=invalid-name + +from __future__ import annotations + +from typing import TYPE_CHECKING, NamedTuple + +import torch + +from torchopt import pytree +from torchopt.accelerated_op import AdamOp +from torchopt.base import GradientTransformation +from torchopt.transform.utils import inc_count, tree_map_flat, update_moment + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates + + +__all__ = ['scale_by_accelerated_adam', 'scale_by_adam'] + + +TRIPLE_PYTREE_SPEC = pytree.tree_structure((0, 1, 2), none_is_leaf=True) # type: ignore[arg-type] + + +class ScaleByAdamState(NamedTuple): + """State for the Adam algorithm.""" + + mu: Updates + nu: Updates + count: OptState + + +def _bias_correction( + moment: Updates, + decay: float, + count: OptState, + *, + already_flattened: bool = False, +) -> Updates: + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + + def f(t: torch.Tensor, c: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name + return t.div(1 - pow(decay, c)) + + if already_flattened: + return tree_map_flat(f, moment, count) + return pytree.tree_map(f, moment, count) + + +def scale_by_adam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """Rescale updates according to the Adam algorithm. + + References: + - Kingma et al., 2014: https://arxiv.org/abs/1412.6980 + + Args: + b1 (float, optional): Decay rate for the exponentially weighted average of grads. + (default: :const:`0.9`) + b2 (float, optional): Decay rate for the exponentially weighted average of squared grads. + (default: :const:`0.999`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-8`) + eps_root (float, optional): Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag + ``requires_grad = True``. (default: :data:`False`) + + Returns: + An (init_fn, update_fn) tuple. + """ + return _scale_by_adam( + b1=b1, + b2=b2, + eps=eps, + eps_root=eps_root, + moment_requires_grad=moment_requires_grad, + already_flattened=False, + ) + + +def _scale_by_adam_flat( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + moment_requires_grad: bool = False, +) -> GradientTransformation: + return _scale_by_adam( + b1=b1, + b2=b2, + eps=eps, + eps_root=eps_root, + moment_requires_grad=moment_requires_grad, + already_flattened=True, + ) + + +# pylint: disable-next=too-many-arguments +def _scale_by_adam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + moment_requires_grad: bool = False, + *, + already_flattened: bool = False, +) -> GradientTransformation: + # pylint: disable=unneeded-not + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= b1 < 1.0: # pragma: no cover + raise ValueError(f'Invalid beta parameter at index 0: {b1}') + if not 0.0 <= b2 < 1.0: # pragma: no cover + raise ValueError(f'Invalid beta parameter at index 1: {b2}') + # pylint: enable=unneeded-not + + if already_flattened: # noqa: SIM108 + tree_map = tree_map_flat + else: + tree_map = pytree.tree_map # type: ignore[assignment] + + def init_fn(params: Params) -> OptState: + zero = tree_map( # count init + lambda t: torch.zeros(1, dtype=torch.int64, device=t.device).squeeze_(), + params, + ) + mu = tree_map( # first moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ) + nu = tree_map( # second moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ) + return ScaleByAdamState(mu=mu, nu=nu, count=zero) + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: + mu = update_moment.impl( # type: ignore[attr-defined] + updates, + state.mu, + b1, + order=1, + inplace=inplace, + already_flattened=already_flattened, + ) + nu = update_moment.impl( # type: ignore[attr-defined] + updates, + state.nu, + b2, + order=2, + inplace=inplace, + already_flattened=already_flattened, + ) + # pylint: disable=line-too-long + count_inc = inc_count.impl(updates, state.count, already_flattened=already_flattened) # type: ignore[attr-defined] + mu_hat = _bias_correction(mu, b1, count_inc, already_flattened=already_flattened) + nu_hat = _bias_correction(nu, b2, count_inc, already_flattened=already_flattened) + + if inplace: + + def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return m.div_(v.add_(eps_root).sqrt_().add(eps)) if g is not None else g + + else: + + def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return m.div(v.add(eps_root).sqrt_().add(eps)) if g is not None else g + + updates = tree_map(f, mu_hat, nu_hat, updates) + return updates, ScaleByAdamState(mu=mu, nu=nu, count=count_inc) + + return GradientTransformation(init_fn, update_fn) + + +scale_by_adam.flat = _scale_by_adam_flat # type: ignore[attr-defined] +scale_by_adam.impl = _scale_by_adam # type: ignore[attr-defined] + + +def scale_by_accelerated_adam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """Rescale updates according to the Adam algorithm. + + This function is accelerated by using some fused accelerated operators. + + References: + - Kingma et al., 2014: https://arxiv.org/abs/1412.6980 + + Args: + b1 (float, optional): Decay rate for the exponentially weighted average of grads. + (default: :const:`0.9`) + b2 (float, optional): Decay rate for the exponentially weighted average of squared grads. + (default: :const:`0.999`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-8`) + eps_root (float, optional): Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag + ``requires_grad = True``. (default: :data:`False`) + + Returns: + An (init_fn, update_fn) tuple. + """ + return _scale_by_accelerated_adam( + b1=b1, + b2=b2, + eps=eps, + eps_root=eps_root, + moment_requires_grad=moment_requires_grad, + already_flattened=False, + ) + + +def _scale_by_accelerated_adam_flat( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + moment_requires_grad: bool = False, +) -> GradientTransformation: + return _scale_by_accelerated_adam( + b1=b1, + b2=b2, + eps=eps, + eps_root=eps_root, + moment_requires_grad=moment_requires_grad, + already_flattened=True, + ) + + +# pylint: disable-next=too-many-arguments +def _scale_by_accelerated_adam( # noqa: C901 + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + moment_requires_grad: bool = False, + *, + already_flattened: bool = False, +) -> GradientTransformation: + # pylint: disable=unneeded-not + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= b1 < 1.0: # pragma: no cover + raise ValueError(f'Invalid beta parameter at index 0: {b1}') + if not 0.0 <= b2 < 1.0: # pragma: no cover + raise ValueError(f'Invalid beta parameter at index 1: {b2}') + # pylint: enable=unneeded-not + + if already_flattened: + tree_map = tree_map_flat + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: + count_inc = inc_count.impl(updates, state.count, already_flattened=True) # type: ignore[attr-defined] + + op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace) + + def op_fn( + mu: torch.Tensor | None, + nu: torch.Tensor | None, + update: torch.Tensor | None, + count: torch.Tensor | None, + ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: + if mu is None: + return (None, None, None) + return op(mu, nu, update, count) # type: ignore[arg-type] + + out = tree_map_flat( + op_fn, + state.mu, + state.nu, + updates, + count_inc, + none_is_leaf=True, + ) + + if len(out) == 0: + new_mu, new_nu, new_updates = (), (), () + else: + new_mu, new_nu, new_updates = tuple(zip(*out)) # transpose + + new_mu, new_nu, new_updates = ( + new if type(new) is type(old) else type(old)(new) + for new, old in ( + (new_mu, state.mu), + (new_nu, state.nu), + (new_updates, updates), + ) + ) + + return new_updates, ScaleByAdamState(mu=new_mu, nu=new_nu, count=count_inc) + + else: + tree_map = pytree.tree_map # type: ignore[assignment] + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: + count_inc = inc_count.impl(updates, state.count, already_flattened=False) # type: ignore[attr-defined] + + new_mu: Updates + new_nu: Updates + new_updates: Updates + + treespec = pytree.tree_structure(updates, none_is_leaf=True) + if treespec.num_leaves > 0: + op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace) + + def op_fn( + mu: torch.Tensor | None, + nu: torch.Tensor | None, + update: torch.Tensor | None, + count: torch.Tensor | None, + ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: + if mu is None: + return (None, None, None) + return op(mu, nu, update, count) # type: ignore[arg-type] + + out = pytree.tree_map( + op_fn, + state.mu, + state.nu, + updates, + count_inc, + none_is_leaf=True, + ) + + new_mu, new_nu, new_updates = pytree.tree_transpose( # type: ignore[misc] + treespec, + TRIPLE_PYTREE_SPEC, + out, + ) + else: + new_mu = pytree.tree_unflatten(treespec, ()) + new_nu = pytree.tree_unflatten(treespec, ()) + new_updates = pytree.tree_unflatten(treespec, ()) + + return new_updates, ScaleByAdamState(mu=new_mu, nu=new_nu, count=count_inc) + + def init_fn(params: Params) -> OptState: + zero = tree_map( # count init + lambda t: torch.zeros(1, dtype=torch.int64, device=t.device).squeeze_(), + params, + ) + mu = tree_map( # first moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ) + nu = tree_map( # second moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ) + return ScaleByAdamState(mu=mu, nu=nu, count=zero) + + return GradientTransformation(init_fn, update_fn) + + +scale_by_accelerated_adam.flat = _scale_by_accelerated_adam_flat # type: ignore[attr-defined] +scale_by_accelerated_adam.impl = _scale_by_accelerated_adam # type: ignore[attr-defined] diff --git a/torchopt/transform/scale_by_adamax.py b/torchopt/transform/scale_by_adamax.py new file mode 100644 index 00000000..cfacbf35 --- /dev/null +++ b/torchopt/transform/scale_by_adamax.py @@ -0,0 +1,161 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Preset transformations for scaling updates by Adamax.""" + +# pylint: disable=invalid-name + +from __future__ import annotations + +from typing import TYPE_CHECKING, NamedTuple + +import torch + +from torchopt import pytree +from torchopt.base import GradientTransformation +from torchopt.transform.utils import tree_map_flat, update_moment + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates + + +__all__ = ['scale_by_adamax'] + + +class ScaleByAdamaxState(NamedTuple): + """State for the Adamax algorithm.""" + + mu: Updates + nu: Updates + t: int + + +def scale_by_adamax( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-6, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """A Adam algorithm variation. + + References: + - Kingma et al., 2014: https://arxiv.org/abs/1412.6980 + + Args: + b1 (float, optional): Decay rate for the exponentially weighted average of grads. + (default: :const:`0.9`) + b2 (float, optional): Decay rate for the exponentially weighted average of squared grads. + (default: :const:`0.999`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-6`) + moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag + ``requires_grad = True``. (default: :data:`False`) + + Returns: + An (init_fn, update_fn) tuple. + """ + return _scale_by_adamax( + b1=b1, + b2=b2, + eps=eps, + moment_requires_grad=moment_requires_grad, + already_flattened=False, + ) + + +def _scale_by_adamax_flat( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-6, + moment_requires_grad: bool = False, +) -> GradientTransformation: + return _scale_by_adamax( + b1=b1, + b2=b2, + eps=eps, + moment_requires_grad=moment_requires_grad, + already_flattened=True, + ) + + +def _scale_by_adamax( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-6, + moment_requires_grad: bool = False, + *, + already_flattened: bool = False, +) -> GradientTransformation: + # pylint: disable=unneeded-not + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= b1 < 1.0: # pragma: no cover + raise ValueError(f'Invalid b1 parameter at index 0: {b1}') + if not 0.0 <= b2 < 1.0: # pragma: no cover + raise ValueError(f'Invalid b1 parameter at index 0: {b2}') + # pylint: enable=unneeded-not + + if already_flattened: # noqa: SIM108 + tree_map = tree_map_flat + else: + tree_map = pytree.tree_map # type: ignore[assignment] + + def init_fn(params: Params) -> OptState: + mu = tree_map( # first moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ) + nu = tree_map( # second moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ) + return ScaleByAdamaxState(mu=mu, nu=nu, t=1) + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: + mu = update_moment.impl( # type: ignore[attr-defined] + updates, + state.mu, + b1, + order=1, + inplace=inplace, + already_flattened=already_flattened, + ) + + def update_nu(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return torch.max(n.mul(b2), g.abs().add_(eps)) if g is not None else g + + nu = tree_map(update_nu, state.nu, updates) + + one_minus_b1_pow_t = 1 - b1**state.t + + def f(m: torch.Tensor, n: torch.Tensor | None) -> torch.Tensor: + return m.div(n).div_(one_minus_b1_pow_t) if n is not None else m + + updates = tree_map(f, mu, nu) + + return updates, ScaleByAdamaxState(mu=mu, nu=nu, t=state.t + 1) + + return GradientTransformation(init_fn, update_fn) + + +scale_by_adamax.flat = _scale_by_adamax_flat # type: ignore[attr-defined] +scale_by_adamax.impl = _scale_by_adamax # type: ignore[attr-defined] diff --git a/torchopt/transform/scale_by_radam.py b/torchopt/transform/scale_by_radam.py new file mode 100644 index 00000000..95f26149 --- /dev/null +++ b/torchopt/transform/scale_by_radam.py @@ -0,0 +1,207 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Preset transformations for scaling updates by RAdam.""" + +# pylint: disable=invalid-name + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING, NamedTuple + +import torch + +from torchopt import pytree +from torchopt.base import GradientTransformation +from torchopt.transform.utils import tree_map_flat, update_moment + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates + + +__all__ = ['scale_by_radam'] + + +class ScaleByRAdamState(NamedTuple): + """State for the RAdam algorithm.""" + + mu: Updates + nu: Updates + t: int + + +def scale_by_radam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-6, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """Rescale updates according to the RAdam algorithm. + + References: + - Liu, 2019: https://arxiv.org/abs/1908.03265 + + Args: + b1 (float, optional): Decay rate for the exponentially weighted average of grads. + (default: :const:`0.9`) + b2 (float, optional): Decay rate for the exponentially weighted average of squared grads. + (default: :const:`0.999`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-6`) + moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag + ``requires_grad = True``. (default: :data:`False`) + + Returns: + An (init_fn, update_fn) tuple. + """ + return _scale_by_radam( + b1=b1, + b2=b2, + eps=eps, + moment_requires_grad=moment_requires_grad, + already_flattened=False, + ) + + +def _scale_by_radam_flat( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-6, + moment_requires_grad: bool = False, +) -> GradientTransformation: + return _scale_by_radam( + b1=b1, + b2=b2, + eps=eps, + moment_requires_grad=moment_requires_grad, + already_flattened=True, + ) + + +def _scale_by_radam( # noqa: C901 + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-6, + moment_requires_grad: bool = False, + *, + already_flattened: bool = False, +) -> GradientTransformation: + # pylint: disable=unneeded-not + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= b1 < 1.0: # pragma: no cover + raise ValueError(f'Invalid b1 parameter at index 0: {b1}') + if not 0.0 <= b2 < 1.0: # pragma: no cover + raise ValueError(f'Invalid b1 parameter at index 0: {b2}') + # pylint: enable=unneeded-not + + if already_flattened: # noqa: SIM108 + tree_map = tree_map_flat + else: + tree_map = pytree.tree_map # type: ignore[assignment] + + def init_fn(params: Params) -> OptState: + mu = tree_map( # first moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ) + nu = tree_map( # second moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ) + return ScaleByRAdamState(mu=mu, nu=nu, t=1) + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: + mu = update_moment.impl( # type: ignore[attr-defined] + updates, + state.mu, + b1, + order=1, + inplace=inplace, + already_flattened=already_flattened, + ) + + nu = update_moment.impl( # type: ignore[attr-defined] + updates, + state.nu, + b2, + order=2, + inplace=inplace, + already_flattened=already_flattened, + ) + + rho_inf = 2 / (1 - b2) - 1 + one_minus_b1_pow_t = 1 - b1**state.t + one_minus_b2_pow_t = 1 - b2**state.t + rho = rho_inf - 2 * state.t * b2**state.t / one_minus_b2_pow_t + + if rho > 5: + numerator = math.sqrt( + one_minus_b2_pow_t + * (rho - 4) + * (rho - 2) + * rho_inf + / ((rho_inf - 4) * (rho_inf - 2) * rho), + ) + if inplace: + + def f( + m: torch.Tensor, + v: torch.Tensor, + ) -> torch.Tensor: + return m.mul(numerator / one_minus_b1_pow_t).div_(v.sqrt().add_(eps)) + + else: + + def f( + m: torch.Tensor, + v: torch.Tensor, + ) -> torch.Tensor: + return m.mul(numerator / one_minus_b1_pow_t).div(v.sqrt().add(eps)) + + else: + if inplace: + + def f( + m: torch.Tensor, + v: torch.Tensor, # pylint: disable=unused-argument + ) -> torch.Tensor: + return m.div(one_minus_b1_pow_t) + + else: + + def f( + m: torch.Tensor, + v: torch.Tensor, # pylint: disable=unused-argument + ) -> torch.Tensor: + return m.div(one_minus_b1_pow_t) + + updates = tree_map(f, mu, nu) + + return updates, ScaleByRAdamState(mu=mu, nu=nu, t=state.t + 1) + + return GradientTransformation(init_fn, update_fn) + + +scale_by_radam.flat = _scale_by_radam_flat # type: ignore[attr-defined] +scale_by_radam.impl = _scale_by_radam # type: ignore[attr-defined] diff --git a/torchopt/transform/scale_by_rms.py b/torchopt/transform/scale_by_rms.py new file mode 100644 index 00000000..f2141388 --- /dev/null +++ b/torchopt/transform/scale_by_rms.py @@ -0,0 +1,160 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset transformations for scaling updates by exponential root mean-squared (RMS).""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, NamedTuple + +import torch + +from torchopt import pytree +from torchopt.base import GradientTransformation +from torchopt.transform.utils import tree_map_flat, tree_map_flat_, update_moment + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates + + +__all__ = ['scale_by_rms'] + + +class ScaleByRmsState(NamedTuple): + """State for exponential root mean-squared (RMS)-normalized updates.""" + + nu: Updates + + +def scale_by_rms( + alpha: float = 0.9, + eps: float = 1e-8, + initial_scale: float = 0.0, +) -> GradientTransformation: + """Rescale updates by the root of the exp. moving avg of the square. + + References: + - Tieleman and Hinton, 2012: http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf + + Args: + alpha (float, optional): Decay rate for the exponentially weighted average of squared grads. + (default: :const:`0.9`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-8`) + initial_scale (float, optional): Initial value for second moment. (default: :const:`0.0`) + + Returns: + An (init_fn, update_fn) tuple. + """ + return _scale_by_rms( + alpha=alpha, + eps=eps, + initial_scale=initial_scale, + already_flattened=False, + ) + + +def _scale_by_rms_flat( + alpha: float = 0.9, + eps: float = 1e-8, + initial_scale: float = 0.0, +) -> GradientTransformation: + return _scale_by_rms( + alpha=alpha, + eps=eps, + initial_scale=initial_scale, + already_flattened=True, + ) + + +def _scale_by_rms( + alpha: float = 0.9, + eps: float = 1e-8, + initial_scale: float = 0.0, + *, + already_flattened: bool = False, +) -> GradientTransformation: + # pylint: disable=unneeded-not + if not alpha >= 0.0: # pragma: no cover + raise ValueError(f'Invalid alpha value: {alpha}') + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + # pylint: enable=unneeded-not + + if already_flattened: + tree_map = tree_map_flat + tree_map_ = tree_map_flat_ + else: + tree_map = pytree.tree_map # type: ignore[assignment] + tree_map_ = pytree.tree_map_ # type: ignore[assignment] + + def init_fn(params: Params) -> OptState: + nu = tree_map(lambda n: torch.full_like(n, initial_scale), params) # second moment + return ScaleByRmsState(nu=nu) + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: + nu = update_moment.impl( # type: ignore[attr-defined] + updates, + state.nu, + alpha, + order=2, + inplace=inplace, + already_flattened=already_flattened, + ) + + if inplace: + # pylint: disable-next=invalid-name + def f(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.div_(n.sqrt().add_(eps)) if g is not None else g + + tree_map_(f, nu, updates) + + else: + # pylint: disable-next=invalid-name + def f(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.div(n.sqrt().add(eps)) if g is not None else g + + updates = tree_map(f, nu, updates) + + return updates, ScaleByRmsState(nu=nu) + + return GradientTransformation(init_fn, update_fn) + + +scale_by_rms.flat = _scale_by_rms_flat # type: ignore[attr-defined] +scale_by_rms.impl = _scale_by_rms # type: ignore[attr-defined] diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py new file mode 100644 index 00000000..642b2e5c --- /dev/null +++ b/torchopt/transform/scale_by_rss.py @@ -0,0 +1,155 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset transformations for scaling updates by the root of the sum of all squared gradients.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, NamedTuple + +import torch + +from torchopt import pytree +from torchopt.base import GradientTransformation +from torchopt.transform.utils import tree_map_flat, update_moment + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates + + +__all__ = ['scale_by_rss'] + + +class ScaleByRssState(NamedTuple): + """State holding the sum of gradient squares to date.""" + + sum_of_squares: Updates + + +def scale_by_rss( + initial_accumulator_value: float = 0.0, + eps: float = 1e-10, +) -> GradientTransformation: + """Rescale updates by the root of the sum of all squared gradients to date. + + References: + - Duchi et al., 2011: https://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf + - McMahan et al., 2010: https://arxiv.org/abs/1002.4908 + + Args: + initial_accumulator_value (float, optional): Starting value for accumulators, must be + ``>= 0``. (default: :const:`0.0`) + eps (float, optional): A small floating point value to avoid zero denominator. + (default: :const:`1e-10`) + + Returns: + An (init_fn, update_fn) tuple. + """ + return _scale_by_rss( + initial_accumulator_value=initial_accumulator_value, + eps=eps, + already_flattened=False, + ) + + +def _scale_by_rss_flat( + initial_accumulator_value: float = 0.0, + eps: float = 1e-10, +) -> GradientTransformation: + return _scale_by_rss( + initial_accumulator_value=initial_accumulator_value, + eps=eps, + already_flattened=True, + ) + + +def _scale_by_rss( + initial_accumulator_value: float = 0.0, + eps: float = 1e-10, + *, + already_flattened: bool = False, +) -> GradientTransformation: + if already_flattened: # noqa: SIM108 + tree_map = tree_map_flat + else: + tree_map = pytree.tree_map # type: ignore[assignment] + + def init_fn(params: Params) -> OptState: + sum_of_squares = tree_map( + lambda t: torch.full_like( + t, + initial_accumulator_value, + memory_format=torch.preserve_format, + ), + params, + ) + return ScaleByRssState(sum_of_squares=sum_of_squares) + + def update_fn( + updates: Updates, + state: OptState, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: + sum_of_squares = update_moment.impl( # type: ignore[attr-defined] + updates, + state.sum_of_squares, + decay=1.0, + order=2, + inplace=inplace, + already_flattened=already_flattened, + ) + + if inplace: + + def f(sos: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return ( + torch.where(sos > 0.0, g.div_(sos.sqrt().add_(eps)), 0.0) + if g is not None + else g + ) + + else: + + def f(sos: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return ( + torch.where(sos > 0.0, g.div(sos.sqrt().add(eps)), 0.0) if g is not None else g + ) + + updates = tree_map(f, sum_of_squares, updates) + return updates, ScaleByRssState(sum_of_squares=sum_of_squares) + + return GradientTransformation(init_fn, update_fn) + + +scale_by_rss.flat = _scale_by_rss_flat # type: ignore[attr-defined] +scale_by_rss.impl = _scale_by_rss # type: ignore[attr-defined] diff --git a/torchopt/transform/scale_by_schedule.py b/torchopt/transform/scale_by_schedule.py new file mode 100644 index 00000000..499e2adb --- /dev/null +++ b/torchopt/transform/scale_by_schedule.py @@ -0,0 +1,136 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset transformation for scaling updates by learning rate schedules.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, NamedTuple + +import torch + +from torchopt import pytree +from torchopt.base import GradientTransformation +from torchopt.transform.utils import inc_count, tree_map_flat, tree_map_flat_ + + +if TYPE_CHECKING: + from torchopt.typing import Numeric, OptState, Params, Schedule, SequenceOfTensors, Updates + + +__all__ = ['scale_by_schedule'] + + +class ScaleByScheduleState(NamedTuple): + """Maintain count for scale scheduling.""" + + count: SequenceOfTensors # type: ignore + + +def scale_by_schedule(step_size_fn: Schedule) -> GradientTransformation: + """Scale updates using a custom schedule for the ``step_size``. + + Args: + step_size_fn (callable): A function that takes an update count as input and proposes the + ``step_size`` to multiply the updates by. + + Returns: + An ``(init_fn, update_fn)`` tuple. + """ + return _scale_by_schedule(step_size_fn=step_size_fn, already_flattened=False) + + +def _scale_by_schedule_flat(step_size_fn: Schedule) -> GradientTransformation: + return _scale_by_schedule(step_size_fn=step_size_fn, already_flattened=True) + + +def _scale_by_schedule( + step_size_fn: Schedule, + *, + already_flattened: bool = False, +) -> GradientTransformation: + if already_flattened: + tree_map = tree_map_flat + tree_map_ = tree_map_flat_ + else: + tree_map = pytree.tree_map # type: ignore[assignment] + tree_map_ = pytree.tree_map_ # type: ignore[assignment] + + def init_fn(params: Params) -> OptState: + zero = tree_map( # count init + lambda t: torch.zeros(1, dtype=torch.int64, device=t.device).squeeze_(), + params, + ) + return ScaleByScheduleState(count=zero) + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: + if inplace: + # pylint: disable-next=invalid-name + def f(c: Numeric, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g + step_size = step_size_fn(c) + return g.mul_(step_size) + + tree_map_(f, state.count, updates) + + else: + # pylint: disable-next=invalid-name + def f(c: Numeric, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g + step_size = step_size_fn(c) + return g.mul(step_size) + + updates = tree_map(f, state.count, updates) + + return ( + updates, + ScaleByScheduleState( + count=inc_count.impl( # type: ignore[attr-defined] + updates, + state.count, + already_flattened=already_flattened, + ), + ), + ) + + return GradientTransformation(init_fn, update_fn) + + +scale_by_schedule.flat = _scale_by_schedule_flat # type: ignore[attr-defined] +scale_by_schedule.impl = _scale_by_schedule # type: ignore[attr-defined] diff --git a/torchopt/transform/scale_by_stddev.py b/torchopt/transform/scale_by_stddev.py new file mode 100644 index 00000000..5a3e6655 --- /dev/null +++ b/torchopt/transform/scale_by_stddev.py @@ -0,0 +1,172 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset transformations for scaling updates by the root of the centered exponential moving average.""" + +# pylint: disable=invalid-name + +from __future__ import annotations + +from typing import TYPE_CHECKING, NamedTuple + +import torch + +from torchopt import pytree +from torchopt.base import GradientTransformation +from torchopt.transform.utils import tree_map_flat, tree_map_flat_, update_moment + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates + + +__all__ = ['scale_by_stddev'] + + +class ScaleByRStdDevState(NamedTuple): + """State for centered exponential moving average of squares of updates.""" + + mu: Updates + nu: Updates + + +def scale_by_stddev( + alpha: float = 0.9, + eps: float = 1e-8, + initial_scale: float = 0.0, +) -> GradientTransformation: + """Rescale updates by the root of the centered exponential moving average of squares. + + References: + - Tieleman and Hinton, 2012: http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf + + Args: + alpha (float, optional): Decay rate for the exponentially weighted average of squared grads. + (default: :const:`0.9`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-8`) + initial_scale (float, optional): Initial value for second moment. (default: :const:`0.0`) + + Returns: + An (init_fn, update_fn) tuple. + """ + return _scale_by_stddev( + alpha=alpha, + eps=eps, + initial_scale=initial_scale, + already_flattened=False, + ) + + +def _scale_by_stddev_flat( + alpha: float = 0.9, + eps: float = 1e-8, + initial_scale: float = 0.0, +) -> GradientTransformation: + return _scale_by_stddev( + alpha=alpha, + eps=eps, + initial_scale=initial_scale, + already_flattened=True, + ) + + +def _scale_by_stddev( + alpha: float = 0.9, + eps: float = 1e-8, + initial_scale: float = 0.0, + *, + already_flattened: bool = False, +) -> GradientTransformation: + # pylint: disable=unneeded-not + if not alpha >= 0.0: # pragma: no cover + raise ValueError(f'Invalid alpha value: {alpha}') + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + # pylint: enable=unneeded-not + + if already_flattened: + tree_map = tree_map_flat + tree_map_ = tree_map_flat_ + else: + tree_map = pytree.tree_map # type: ignore[assignment] + tree_map_ = pytree.tree_map_ # type: ignore[assignment] + + def init_fn(params: Params) -> OptState: + mu = tree_map(torch.zeros_like, params) # first moment + nu = tree_map(lambda n: torch.full_like(n, initial_scale), params) # second moment + return ScaleByRStdDevState(mu=mu, nu=nu) + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: + mu = update_moment.impl( # type: ignore[attr-defined] + updates, + state.mu, + alpha, + order=1, + inplace=inplace, + already_flattened=already_flattened, + ) + nu = update_moment.impl( # type: ignore[attr-defined] + updates, + state.nu, + alpha, + order=2, + inplace=inplace, + already_flattened=already_flattened, + ) + + if inplace: + + def f(m: torch.Tensor, n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.div_(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) if g is not None else g + + tree_map_(f, mu, nu, updates) + + else: + + def f(m: torch.Tensor, n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.div(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) if g is not None else g + + updates = tree_map(f, mu, nu, updates) + + return updates, ScaleByRStdDevState(mu=mu, nu=nu) + + return GradientTransformation(init_fn, update_fn) + + +scale_by_stddev.flat = _scale_by_stddev_flat # type: ignore[attr-defined] +scale_by_stddev.impl = _scale_by_stddev # type: ignore[attr-defined] diff --git a/torchopt/transform/trace.py b/torchopt/transform/trace.py new file mode 100644 index 00000000..219cbbec --- /dev/null +++ b/torchopt/transform/trace.py @@ -0,0 +1,217 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset transformations for scaling updates by Adam.""" + +# pylint: disable=invalid-name + +from __future__ import annotations + +from typing import TYPE_CHECKING, NamedTuple + +import torch + +from torchopt import pytree +from torchopt.base import GradientTransformation, identity +from torchopt.transform.utils import tree_map_flat, tree_map_flat_ + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates + + +__all__ = ['trace'] + + +class TraceState(NamedTuple): + """Hold an aggregation of past updates.""" + + trace: Params + + +def trace( + momentum: float = 0.9, + dampening: float = 0.0, + nesterov: bool = False, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """Compute a trace of past updates. + + Note: `trace` and `ema` have very similar but distinct updates; + `trace = decay * trace + t`, while `ema = decay * ema + (1 - decay) * t`. + Both are frequently found in the optimization literature. + + Args: + momentum (float, optional): The decay rate for the trace of past updates. + (default: :const:`0.9`) + dampening (float, optional): Dampening for momentum. (default: :const:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag + ``requires_grad = True``. (default: :data:`False`) + + Returns: + An (init_fn, update_fn) tuple. + """ + return _trace( + momentum=momentum, + dampening=dampening, + nesterov=nesterov, + moment_requires_grad=moment_requires_grad, + already_flattened=False, + ) + + +def _trace_flat( + momentum: float = 0.9, + dampening: float = 0.0, + nesterov: bool = False, + moment_requires_grad: bool = False, +) -> GradientTransformation: + return _trace( + momentum=momentum, + dampening=dampening, + nesterov=nesterov, + moment_requires_grad=moment_requires_grad, + already_flattened=True, + ) + + +def _trace( # noqa: C901 + momentum: float = 0.9, + dampening: float = 0.0, + nesterov: bool = False, + moment_requires_grad: bool = False, + *, + already_flattened: bool = False, +) -> GradientTransformation: + # pylint: disable=unneeded-not + if not momentum >= 0.0: # pragma: no cover + raise ValueError(f'Invalid momentum value: {momentum}') + if nesterov and (momentum <= 0.0 or dampening != 0.0): # pragma: no cover + raise ValueError('Nesterov momentum requires a momentum and zero dampening') + # pylint: enable=unneeded-not + + if momentum == 0.0: + return identity() + + if already_flattened: + tree_map = tree_map_flat + tree_map_ = tree_map_flat_ + else: + tree_map = pytree.tree_map # type: ignore[assignment] + tree_map_ = pytree.tree_map_ # type: ignore[assignment] + + def init_fn(params: Params) -> OptState: + return TraceState( + trace=tree_map( + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ), + ) + + first_call = True + + def update_fn( # noqa: C901 + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: + nonlocal first_call + + if nesterov: + if inplace: + + def f1(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g + if first_call: + return t.add_(g) + return t.mul_(momentum).add_(g) + + def f2(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.add_(t, alpha=momentum) if g is not None else g + + new_trace = tree_map(f1, state.trace, updates) + tree_map_(f2, new_trace, updates) + + else: + + def f1(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g + if first_call: + return t.add(g) + return t.mul(momentum).add_(g) + + def f2(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.add(t, alpha=momentum) if g is not None else g + + new_trace = tree_map(f1, state.trace, updates) + updates = tree_map(f2, new_trace, updates) + + else: + if inplace: + + def f(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g + if first_call: + return t.add_(g) + return t.mul_(momentum).add_(g, alpha=1.0 - dampening) + + def copy_to_(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.copy_(t) if g is not None else g + + new_trace = tree_map(f, state.trace, updates) + tree_map_(copy_to_, new_trace, updates) + + else: + + def f(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g + if first_call: + return t.add(g) + return t.mul(momentum).add_(g, alpha=1.0 - dampening) + + new_trace = tree_map(f, state.trace, updates) + updates = tree_map(torch.clone, new_trace) + + first_call = False + return updates, TraceState(trace=new_trace) + + return GradientTransformation(init_fn, update_fn) + + +trace.flat = _trace_flat # type: ignore[attr-defined] +trace.impl = _trace # type: ignore[attr-defined] diff --git a/torchopt/transform/utils.py b/torchopt/transform/utils.py new file mode 100644 index 00000000..9b38d561 --- /dev/null +++ b/torchopt/transform/utils.py @@ -0,0 +1,230 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for the preset transformations.""" + +from __future__ import annotations + +from collections import deque +from typing import TYPE_CHECKING, Any, Callable, Sequence + +import torch + +from torchopt import pytree + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree, Updates + + +__all__ = ['inc_count', 'tree_map_flat', 'tree_map_flat_', 'update_moment'] + + +INT64_MAX = torch.iinfo(torch.int64).max + + +def tree_map_flat( + func: Callable, + flat_arg: Sequence[Any], + *flat_args: Any, + none_is_leaf: bool = False, +) -> Sequence[Any]: + """Apply a function to each element of a flattened list.""" + if none_is_leaf: + fn = func + else: + + def fn(x: Any | None, *xs: Any) -> Any | None: + return func(x, *xs) if x is not None else None + + return flat_arg.__class__(map(fn, flat_arg, *flat_args)) # type: ignore[call-arg] + + +def tree_map_flat_( + func: Callable, + flat_arg: Sequence[Any], + *flat_args: Any, + none_is_leaf: bool = False, +) -> Sequence[Any]: + """Apply a function to each element of a flattened list and return the original list.""" + if none_is_leaf: + fn = func + else: + + def fn(x: Any | None, *xs: Any) -> Any | None: + return func(x, *xs) if x is not None else None + + flat_results = map(fn, flat_arg, *flat_args) + deque(flat_results, maxlen=0) # consume and exhaust the iterable + return flat_arg + + +def inc_count(updates: Updates, count: TensorTree) -> TensorTree: + """Increment int counter by one. + + Returns: + A counter incremented by one, or :data:`INT64_MAX` if the maximum precision is reached. + """ + return _inc_count( + updates=updates, + count=count, + already_flattened=False, + ) + + +def _inc_count_flat(updates: Updates, count: TensorTree) -> TensorTree: + return _inc_count( + updates=updates, + count=count, + already_flattened=True, + ) + + +def _inc_count( + updates: Updates, + count: TensorTree, + *, + already_flattened: bool = False, +) -> TensorTree: + def f(c: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor: # pylint: disable=invalid-name + return c + (c != INT64_MAX).to(torch.int64) if g is not None else c + + if already_flattened: + return tree_map_flat(f, count, updates, none_is_leaf=True) + return pytree.tree_map(f, count, updates, none_is_leaf=True) + + +inc_count.flat = _inc_count_flat # type: ignore[attr-defined] +inc_count.impl = _inc_count # type: ignore[attr-defined] + + +def update_moment( + updates: Updates, + moments: TensorTree, + decay: float, + *, + order: int, + inplace: bool = True, +) -> TensorTree: + """Compute the exponential moving average of the ``order``-th moment.""" + return _update_moment( + updates, + moments, + decay, + order=order, + inplace=inplace, + already_flattened=False, + ) + + +def _update_moment_flat( + updates: Updates, + moments: TensorTree, + decay: float, + *, + order: int, + inplace: bool = True, +) -> TensorTree: + return _update_moment( + updates, + moments, + decay, + order=order, + inplace=inplace, + already_flattened=True, + ) + + +# pylint: disable-next=too-many-arguments +def _update_moment( # noqa: C901 + updates: Updates, + moments: TensorTree, + decay: float, + *, + order: int, + inplace: bool = True, + already_flattened: bool = False, +) -> TensorTree: + assert order in (1, 2) + + if inplace: + if order == 2: + if decay != 1.0: + + def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: + return t.mul_(decay).addcmul_(g, g, value=1 - decay) if g is not None else t + + else: + + def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: + return t.addcmul_(g, g) if g is not None else t + + else: + if decay != 1.0: + + def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: + return t.mul_(decay).add_(g, alpha=1 - decay) if g is not None else t + + else: + + def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: + return t.add_(g) if g is not None else t + + else: + if order == 2: + if decay != 1.0: + + def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: + return t.mul(decay).addcmul_(g, g, value=1 - decay) if g is not None else t + + else: + + def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: + return t.addcmul(g, g) if g is not None else t + + else: + if decay != 1.0: + + def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: + return t.mul(decay).add_(g, alpha=1 - decay) if g is not None else t + + else: + + def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: + return t.add(g) if g is not None else t + + if already_flattened: + return tree_map_flat(f, updates, moments, none_is_leaf=True) + return pytree.tree_map(f, updates, moments, none_is_leaf=True) + + +update_moment.flat = _update_moment_flat # type: ignore[attr-defined] +update_moment.impl = _update_moment # type: ignore[attr-defined] diff --git a/torchopt/typing.py b/torchopt/typing.py new file mode 100644 index 00000000..fcd888fb --- /dev/null +++ b/torchopt/typing.py @@ -0,0 +1,149 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Typing utilities.""" + +from __future__ import annotations + +import abc +from typing import ( + Callable, + Dict, + List, + Optional, + Protocol, + Sequence, + Tuple, + TypeVar, + Union, + runtime_checkable, +) +from typing_extensions import TypeAlias # Python 3.10+ + +import torch +import torch.distributed.rpc as rpc +from optree.typing import PyTree, PyTreeTypeVar +from torch import Tensor +from torch.distributions import Distribution +from torch.futures import Future + +from torchopt.base import ( + ChainedGradientTransformation, + EmptyState, + GradientTransformation, + UninitializedState, +) + + +__all__ = [ + 'ChainedGradientTransformation', + 'Device', + 'Distribution', + 'EmptyState', + 'Future', + 'GradientTransformation', + 'LinearSolver', + 'ListOfOptionalTensors', + 'ListOfTensors', + 'ModuleTensorContainers', + 'Numeric', + 'OptState', + 'OptionalTensor', + 'OptionalTensorOrOptionalTensors', + 'OptionalTensorTree', + 'Params', + 'PyTree', + 'Samplable', + 'SampleFunc', + 'Scalar', + 'ScalarOrSchedule', + 'Schedule', + 'SequenceOfOptionalTensors', + 'SequenceOfTensors', + 'Size', + 'Tensor', + 'TensorContainer', + 'TensorOrTensors', + 'TensorTree', + 'TupleOfOptionalTensors', + 'TupleOfTensors', + 'UninitializedState', + 'Updates', +] + +T = TypeVar('T') + +Device: TypeAlias = Union[torch.device, str, int] + +Scalar: TypeAlias = Union[float, int, bool] +Numeric: TypeAlias = Union[Tensor, Scalar] + +Schedule: TypeAlias = Callable[[Numeric], Numeric] +ScalarOrSchedule: TypeAlias = Union[float, Schedule] + +OptionalTensor = Optional[Tensor] + +ListOfTensors = List[Tensor] +TupleOfTensors = Tuple[Tensor, ...] +SequenceOfTensors = Sequence[Tensor] +TensorOrTensors = Union[Tensor, SequenceOfTensors] +TensorTree: TypeAlias = PyTreeTypeVar('TensorTree', Tensor) # type: ignore[valid-type] + +ListOfOptionalTensors = List[OptionalTensor] +TupleOfOptionalTensors = Tuple[OptionalTensor, ...] +SequenceOfOptionalTensors = Sequence[OptionalTensor] +OptionalTensorOrOptionalTensors = Union[OptionalTensor, SequenceOfOptionalTensors] +OptionalTensorTree: TypeAlias = PyTreeTypeVar('OptionalTensorTree', OptionalTensor) # type: ignore[valid-type] + +TensorContainer = Dict[str, Optional[Tensor]] +ModuleTensorContainers = Tuple[TensorContainer, ...] + +# Parameters are arbitrary nests of `torch.Tensor`. +Params: TypeAlias = TensorTree +Updates: TypeAlias = Params # Gradient updates are of the same type as parameters. +OptState: TypeAlias = TensorTree # States are arbitrary nests of `torch.Tensor`. + +if rpc.is_available(): # pragma: no cover + from torch.distributed.rpc import RRef # pylint: disable=ungrouped-imports,unused-import + + __all__ += ['RRef'] +else: # pragma: no cover + # pylint: disable-next=invalid-name + RRef = None # type: ignore[misc,assignment] + +# solver(matvec, b) -> solution +LinearSolver: TypeAlias = Callable[[Callable[[TensorTree], TensorTree], TensorTree], TensorTree] + + +Size = torch.Size + +# sample(sample_shape) -> Tensor +SampleFunc: TypeAlias = Callable[[Size], Union[Tensor, Sequence[Numeric]]] + + +@runtime_checkable +class Samplable(Protocol): # pylint: disable=too-few-public-methods + """Abstract protocol class that supports sampling.""" + + @abc.abstractmethod + def sample( + self, + sample_shape: Size = Size(), # noqa: B008 # pylint: disable=unused-argument + ) -> Tensor | Sequence[Numeric]: + # pylint: disable-next=line-too-long + """Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.""" + raise NotImplementedError + + +Samplable.register(Distribution) diff --git a/torchopt/_src/update.py b/torchopt/update.py similarity index 71% rename from torchopt/_src/update.py rename to torchopt/update.py index 753292d7..3f2d71fe 100644 --- a/torchopt/_src/update.py +++ b/torchopt/update.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,15 +29,26 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Helper functions for applying updates.""" -from torchopt._src import base # pylint: disable=unused-import -from torchopt._src.utils import pytree +from __future__ import annotations +from typing import TYPE_CHECKING -def apply_updates( - params: 'base.Params', updates: 'base.Updates', *, inplace: bool = True -) -> 'base.Params': - """Applies an update to the corresponding parameters. +from torchopt import pytree + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import Params, Updates + + +__all__ = ['apply_updates'] + + +def apply_updates(params: Params, updates: Updates, *, inplace: bool = True) -> Params: + """Apply an update to the corresponding parameters. This is a utility functions that applies an update to a set of parameters, and then returns the updated parameters to the caller. As an example, the update may be a gradient transformed by a @@ -46,25 +57,25 @@ def apply_updates( :func:`tree_map` (e.g. if you want to manipulate updates in custom ways before applying them). Args: - params: A tree of parameters. - updates: - A tree of updates, the tree structure and the shape of the leaf nodes must match that - of ``params``. - inplace: If :data:`True`, will update params in a inplace manner. + params (tree of Tensor): A tree of parameters. + updates (tree of Tensor): A tree of updates, the tree structure and the shape of the leaf + nodes must match that of ``params``. + inplace (bool, optional): If :data:`True`, will update params in a inplace manner. + (default: :data:`True`) Returns: Updated parameters, with same structure, shape and type as ``params``. """ if inplace: - def f(p, u): + def f(p: torch.Tensor, u: torch.Tensor | None) -> torch.Tensor: if u is not None: p.data.add_(u) return p else: - def f(p, u): + def f(p: torch.Tensor, u: torch.Tensor | None) -> torch.Tensor: return p.add(u) if u is not None else p return pytree.tree_map(f, params, updates) diff --git a/torchopt/utils.py b/torchopt/utils.py new file mode 100644 index 00000000..5f9202a3 --- /dev/null +++ b/torchopt/utils.py @@ -0,0 +1,515 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for TorchOpt.""" + +from __future__ import annotations + +import copy +import itertools +from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Sequence, cast, overload +from typing_extensions import TypeAlias # Python 3.10+ + +import torch +import torch.nn as nn + +from torchopt import pytree +from torchopt.typing import Device, ModuleTensorContainers, OptState, TensorContainer, TensorTree + + +if TYPE_CHECKING: + from torchopt.optim.meta.base import MetaOptimizer + + +__all__ = [ + 'ModuleState', + 'extract_state_dict', + 'module_clone', + 'module_detach_', + 'recover_state_dict', + 'stop_gradient', +] + + +class ModuleState(NamedTuple): + """Container for module state.""" + + params: tuple[TensorContainer, ...] + buffers: tuple[TensorContainer, ...] + visual_contents: dict | None = None + detach_buffers: bool = False + + +CopyMode: TypeAlias = Literal['reference', 'copy', 'deepcopy', 'ref', 'clone', 'deepclone'] + + +def stop_gradient(target: ModuleState | nn.Module | MetaOptimizer | TensorTree) -> None: + """Stop the gradient for the input object. + + Since a tensor use ``grad_fn`` to connect itself with the previous computation graph, the + backpropagated gradient will flow over the tensor and continue flow to the tensors that is + connected by ``grad_fn``. Some algorithms requires manually detaching tensors from the + computation graph. + + Note that the :func:`stop_gradient` operation is in-place. + + Args: + target (ModuleState, nn.Module, MetaOptimizer, or tree of Tensor): The target that to be + detached from the computation graph, it could be a :class:`nn.Module`, + :class:`torchopt.MetaOptimizer`, state of the :class:`torchopt.MetaOptimizer`, or just + a plain list of tensors. + """ + # pylint: disable-next=import-outside-toplevel + from torchopt.optim.meta.base import MetaOptimizer + + def fn_(obj: Any) -> None: + if isinstance(obj, torch.Tensor): + requires_grad = obj.requires_grad + obj.detach_().requires_grad_(requires_grad) + + if isinstance(target, ModuleState): + true_target = cast(TensorTree, (target.params, target.buffers)) + elif isinstance(target, nn.Module): + true_target = cast(TensorTree, tuple(target.parameters())) + elif isinstance(target, MetaOptimizer): + true_target = cast(TensorTree, target.state_dict()) + else: + true_target = cast(TensorTree, target) # tree of tensors + + pytree.tree_map_(fn_, true_target) + + +@overload +def extract_state_dict( # pylint: disable=too-many-arguments + target: nn.Module, + *, + by: CopyMode = 'reference', + device: Device | None = None, + with_buffers: bool = True, + detach_buffers: bool = False, + enable_visual: bool = False, + visual_prefix: str = '', +) -> ModuleState: # pragma: no cover + ... + + +@overload +def extract_state_dict( + target: MetaOptimizer, + *, + by: CopyMode = 'reference', + device: Device | None = None, +) -> tuple[OptState, ...]: # pragma: no cover + ... + + +# pylint: disable-next=too-many-arguments,too-many-branches,too-many-locals +def extract_state_dict( # noqa: C901 + target: nn.Module | MetaOptimizer, + *, + by: CopyMode = 'reference', + device: Device | None = None, + with_buffers: bool = True, + detach_buffers: bool = False, + enable_visual: bool = False, + visual_prefix: str = '', +) -> ModuleState | tuple[OptState, ...]: + """Extract target state. + + Since a tensor use ``grad_fn`` to connect itself with the previous computation graph, the + backpropagated gradient will flow over the tensor and continue flow to the tensors that is + connected by ``grad_fn``. Some algorithms requires manually detaching tensors from the + computation graph. + + Note that the extracted state is a reference, which means any in-place operator will affect the + target that the state is extracted from. + + Args: + target (nn.Module or MetaOptimizer): It could be a :class:`nn.Module` or + :class:`torchopt.MetaOptimizer`. + by (str, optional): The extract policy of tensors in the target. (default: :const:`'reference'`) + - :const:`'reference'`: The extracted tensors will be references to the original + tensors. + - :const:`'copy'`: The extracted tensors will be clones of the original tensors. This + makes the copied tensors have ``grad_fn`` to be a ```` function points + to the original tensors. + - :const:`'deepcopy'`: The extracted tensors will be deep-copied from the original + tensors. The deep-copied tensors will detach from the original computation graph. + device (Device or None, optional): If specified, move the extracted state to the specified + device. (default: :const:`None`) + with_buffers (bool, optional): Extract buffer together with parameters, this argument is + only used if the input target is :class:`nn.Module`. (default: :const:`True`) + detach_buffers (bool, optional): Whether to detach the reference to the buffers, this + argument is only used if the input target is :class:`nn.Module` and ``by='reference'``. + (default: :const:`False`) + enable_visual (bool, optional): Add additional annotations, which could be used in + computation graph visualization. Currently, this flag only has effect on + :class:`nn.Module` but we will support :class:`torchopt.MetaOptimizer` later. + (default: :const:`False`) + visual_prefix (str, optional): Prefix for the visualization annotations. + (default: :const:`''`) + + Returns: + State extracted of the input object. + """ + assert by in ('reference', 'copy', 'deepcopy', 'ref', 'clone', 'deepclone') + by = by.replace('clone', 'copy') + by = 'reference' if by == 'ref' else by + + # pylint: disable=import-outside-toplevel + from torchopt.optim.meta.base import MetaOptimizer + + if device is not None: + target_device = torch.device(device) + + def reference(t: torch.Tensor) -> torch.Tensor: + return t.to(device=target_device) + + def clone(t: torch.Tensor) -> torch.Tensor: + return t.clone().to(device=target_device) + + def clone_detach_(t: torch.Tensor) -> torch.Tensor: + if isinstance(t, nn.Parameter): + return nn.Parameter( + t.clone().to(device=target_device).detach_(), + requires_grad=t.requires_grad, + ) + return t.clone().to(device=target_device).detach_().requires_grad_(t.requires_grad) + + else: + + def reference(t: torch.Tensor) -> torch.Tensor: + return t + + def clone(t: torch.Tensor) -> torch.Tensor: + return t.clone() + + def clone_detach_(t: torch.Tensor) -> torch.Tensor: + if isinstance(t, nn.Parameter): + return nn.Parameter(t.clone().detach_(), requires_grad=t.requires_grad) + return t.clone().detach_().requires_grad_(t.requires_grad) + + if by == 'reference': + replicate = reference + elif by == 'copy': + replicate = clone + else: + replicate = clone_detach_ + + if isinstance(target, nn.Module): # pylint: disable=no-else-return + if enable_visual: + visual_contents = {} + + for k, v in target.named_parameters(): # pylint: disable=invalid-name + if v.grad_fn is not None: + visual_contents.update({v.grad_fn: (visual_prefix + k, v)}) + else: + visual_contents.update({v: visual_prefix + k}) # type: ignore[dict-item] + else: + visual_contents = None + + params: list[TensorContainer] = [] + buffers: list[TensorContainer] = [] + memo: set[nn.Module] = set() + + def update_params(container: TensorContainer) -> None: + if len(container) > 0: + params.append( + type(container)( + (k, replicate(v)) + for k, v in container.items() + if isinstance(v, torch.Tensor) + ), + ) + + def update_buffers(container: TensorContainer) -> None: + if len(container) > 0: + fn = clone_detach_ if detach_buffers else replicate + buffers.append( + type(container)( + (k, fn(v)) for k, v in container.items() if isinstance(v, torch.Tensor) + ), + ) + + # pylint: disable=protected-access + update_params(target._parameters) # type: ignore[arg-type] + if with_buffers: + update_buffers(target._buffers) + memo.add(target) + for submodule in target.modules(): + if submodule in memo: + continue + update_params(submodule._parameters) # type: ignore[arg-type] + if with_buffers: + update_buffers(submodule._buffers) + memo.add(submodule) + + return ModuleState( + params=tuple(params), + buffers=tuple(buffers), + visual_contents=visual_contents, + detach_buffers=detach_buffers, + ) + + if isinstance(target, MetaOptimizer): + state = target.state_dict() + + def get_variable(t: torch.Tensor | None) -> torch.Tensor | None: + if isinstance(t, torch.Tensor): + return replicate(t) + return t + + return pytree.tree_map(get_variable, state) # type: ignore[arg-type,return-value] + + raise TypeError(f'Unexpected class of {target}') + + +def extract_module_containers( + module: nn.Module, + with_buffers: bool = True, +) -> tuple[ModuleTensorContainers, ModuleTensorContainers]: + """Extract the references to the containers of parameters and buffers from a module.""" + if isinstance(module, nn.Module): + params: list[TensorContainer] = [] + buffers: list[TensorContainer] = [] + memo: set[nn.Module] = set() + + def update_container(container: list[TensorContainer], items: TensorContainer) -> None: + if len(items) > 0: + container.append(items) # we need references to original dictionaries + + # pylint: disable=protected-access + update_container(params, module._parameters) # type: ignore[arg-type] + if with_buffers: + update_container(buffers, module._buffers) + memo.add(module) + for submodule in module.modules(): + if submodule in memo: + continue + update_container(params, submodule._parameters) # type: ignore[arg-type] + if with_buffers: + update_container(buffers, submodule._buffers) + memo.add(submodule) + return tuple(params), tuple(buffers) + + raise RuntimeError(f'Unexpected class of {module}') + + +def recover_state_dict( + target: nn.Module | MetaOptimizer, + state: ModuleState | Sequence[OptState], +) -> None: + """Recover state. + + This function is compatible for the ``extract_state``. + + Note that the recovering process is not in-place, so the tensors of the object will not be + modified. + + Args: + target (nn.Module or MetaOptimizer): Target that need to recover. + state (ModuleState or sequence of tree of Tensor): The recovering state. + """ + # pylint: disable-next=import-outside-toplevel + from torchopt.optim.meta.base import MetaOptimizer + + if isinstance(target, nn.Module): + params, buffers, *_ = state = cast(ModuleState, state) + params_containers, buffers_containers = extract_module_containers(target, with_buffers=True) + + if state.detach_buffers: + + def clone_detach_(t: torch.Tensor) -> torch.Tensor: + if isinstance(t, nn.Parameter): + return nn.Parameter(t.clone().detach_(), requires_grad=t.requires_grad) + return t.clone().detach_().requires_grad_(t.requires_grad) + + buffers = pytree.tree_map(clone_detach_, buffers) # type: ignore[assignment,arg-type] + + for tgt, src in itertools.chain( + zip(params_containers, params), + zip(buffers_containers, buffers), + ): + tgt.update(src) + elif isinstance(target, MetaOptimizer): + state = cast(Sequence[OptState], state) + target.load_state_dict(state) + else: + raise TypeError(f'Unexpected class of {target}') + + +@overload +def module_clone( + target: nn.Module, + *, + by: CopyMode = 'reference', + detach_buffers: bool = False, + device: Device | None = None, +) -> nn.Module: # pragma: no cover + ... + + +@overload +def module_clone( + target: MetaOptimizer, + *, + by: CopyMode = 'reference', + detach_buffers: bool = False, + device: Device | None = None, +) -> MetaOptimizer: # pragma: no cover + ... + + +@overload +def module_clone( + target: TensorTree, + *, + by: CopyMode = 'reference', + detach_buffers: bool = False, + device: Device | None = None, +) -> TensorTree: # pragma: no cover + ... + + +# pylint: disable-next=too-many-locals +def module_clone( # noqa: C901 + target: nn.Module | MetaOptimizer | TensorTree, + *, + by: CopyMode = 'reference', + detach_buffers: bool = False, + device: Device | None = None, +) -> nn.Module | MetaOptimizer | TensorTree: + """Clone a module. + + Args: + target (nn.Module, MetaOptimizer, or tree of Tensor): The target to be cloned. + by (str, optional): The extract policy of tensors in the target. (default: :const:`'reference'`) + - :const:`'reference'`: The extracted tensors will be references to the original + tensors. + - :const:`'copy'`: The extracted tensors will be clones of the original tensors. This + makes the copied tensors have ``grad_fn`` to be a ```` function points + to the original tensors. + - :const:`'deepcopy'`: The extracted tensors will be deep-copied from the original + tensors. The deep-copied tensors will detach from the original computation graph. + detach_buffers (bool, optional): Whether to detach the reference to the buffers, this + argument is only used if the input target is :class:`nn.Module` and ``by='reference'``. + (default: :const:`False`) + device (Device or None, optional): If specified, move the cloned module to the specified + device. (default: :const:`None`) + + Returns: + The cloned module. + """ + assert by in ('reference', 'copy', 'deepcopy', 'ref', 'clone', 'deepclone') + by = by.replace('clone', 'copy') + by = 'reference' if by == 'ref' else by + if device is not None: + device = torch.device(device) + + # pylint: disable-next=import-outside-toplevel + from torchopt.optim.meta.base import MetaOptimizer + + if isinstance(target, (nn.Module, MetaOptimizer)): + if isinstance(target, nn.Module): + containers = cast(TensorTree, extract_module_containers(target, with_buffers=True)) + else: + containers = cast(TensorTree, target.state_dict()) + tensors = pytree.tree_leaves(containers) + memo = {id(t): t for t in tensors} + cloned = copy.deepcopy(target, memo=memo) + state = extract_state_dict( # type: ignore[call-overload] + target, + by=by, + with_buffers=True, + detach_buffers=detach_buffers, + device=device, + ) + recover_state_dict(cloned, state) + return cloned + + # Tree of tensors + if device is not None: + target_device = torch.device(device) + + def reference(t: torch.Tensor) -> torch.Tensor: + return t.to(device=target_device) + + def clone(t: torch.Tensor) -> torch.Tensor: + return t.clone().to(device=target_device) + + def clone_detach_(t: torch.Tensor) -> torch.Tensor: + if isinstance(t, nn.Parameter): + return nn.Parameter( + t.clone().to(device=target_device).detach_(), + requires_grad=t.requires_grad, + ) + return t.clone().to(device=target_device).detach_().requires_grad_(t.requires_grad) + + else: + + def reference(t: torch.Tensor) -> torch.Tensor: + return t + + def clone(t: torch.Tensor) -> torch.Tensor: + return t.clone() + + def clone_detach_(t: torch.Tensor) -> torch.Tensor: + if isinstance(t, nn.Parameter): + return nn.Parameter(t.clone().detach_(), requires_grad=t.requires_grad) + return t.clone().detach_().requires_grad_(t.requires_grad) + + if by == 'reference': + replicate = reference + elif by == 'copy': + replicate = clone + else: + replicate = clone_detach_ + + return pytree.tree_map(replicate, cast(TensorTree, target)) + + +@overload +def module_detach_(target: ModuleState) -> ModuleState: # pragma: no cover + ... + + +@overload +def module_detach_(target: nn.Module) -> nn.Module: # pragma: no cover + ... + + +@overload +def module_detach_(target: MetaOptimizer) -> MetaOptimizer: # pragma: no cover + ... + + +@overload +def module_detach_(target: TensorTree) -> TensorTree: # pragma: no cover + ... + + +def module_detach_( + target: ModuleState | nn.Module | MetaOptimizer | TensorTree, +) -> ModuleState | nn.Module | MetaOptimizer | TensorTree: + """Detach a module from the computation graph. + + Args: + target (ModuleState, nn.Module, MetaOptimizer, or tree of Tensor): The + target to be detached. + + Returns: + The detached module. + """ + stop_gradient(target) + return target diff --git a/torchopt/version.py b/torchopt/version.py index b79568e7..9fdcac9b 100644 --- a/torchopt/version.py +++ b/torchopt/version.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,4 +14,38 @@ # ============================================================================== """TorchOpt: a high-performance optimizer library built upon PyTorch.""" -__version__ = '0.5.0' +__version__ = '0.7.3' +__license__ = 'Apache License, Version 2.0' +__author__ = 'TorchOpt Contributors' +__release__ = False + +if not __release__: + import os + import subprocess + + try: + prefix, sep, suffix = ( + subprocess.check_output( # noqa: S603 + ['git', 'describe', '--abbrev=7'], # noqa: S607 + cwd=os.path.dirname(os.path.abspath(__file__)), + stderr=subprocess.DEVNULL, + text=True, + ) + .strip() + .lstrip('v') + .replace('-', '.dev', 1) + .replace('-', '+', 1) + .partition('.dev') + ) + if sep: + version_prefix, dot, version_tail = prefix.rpartition('.') + prefix = f'{version_prefix}{dot}{int(version_tail) + 1}' + __version__ = f'{prefix}{sep}{suffix}' + del version_prefix, dot, version_tail + else: + __version__ = prefix + del prefix, sep, suffix + except (OSError, subprocess.CalledProcessError): + pass + + del os, subprocess diff --git a/torchopt/_src/visual.py b/torchopt/visual.py similarity index 62% rename from torchopt/_src/visual.py rename to torchopt/visual.py index edf052bc..7638d7ec 100644 --- a/torchopt/_src/visual.py +++ b/torchopt/visual.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,24 +15,32 @@ # This file is modified from: # https://github.com/szagoruyko/pytorchviz/blob/master/torchviz/dot.py # ============================================================================== +"""Computation graph visualization.""" -import warnings -from collections import namedtuple -from typing import Dict, Generator +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Generator, Iterable, Mapping, cast import torch from graphviz import Digraph -from pkg_resources import parse_version + +from torchopt import pytree +from torchopt.utils import ModuleState + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree -Node = namedtuple('Node', ('name', 'inputs', 'attr', 'op')) +__all__ = ['make_dot', 'resize_graph'] + # Saved attrs for grad_fn (incl. saved variables) begin with `._saved_*` SAVED_PREFIX = '_saved_' -def get_fn_name(fn, show_attrs, max_attr_chars): - """Returns function name.""" +def get_fn_name(fn: Any, show_attrs: bool, max_attr_chars: int) -> str: + """Return function name.""" name = str(type(fn).__name__) if not show_attrs: return name @@ -42,9 +50,9 @@ def get_fn_name(fn, show_attrs, max_attr_chars): continue val = getattr(fn, attr) attr = attr[len(SAVED_PREFIX) :] - if torch.is_tensor(val): + if isinstance(val, torch.Tensor): attrs[attr] = '[saved tensor]' - elif isinstance(val, tuple) and any(torch.is_tensor(t) for t in val): + elif isinstance(val, tuple) and any(isinstance(t, torch.Tensor) for t in val): attrs[attr] = '[saved tensors]' else: attrs[attr] = str(val) @@ -56,25 +64,34 @@ def get_fn_name(fn, show_attrs, max_attr_chars): sep = '-' * max(col1width + col2width + 2, len(name)) attrstr = '%-' + str(col1width) + 's: %' + str(col2width) + 's' - def truncate(s): # pylint: disable=invalid-name + def truncate(s: str) -> str: # pylint: disable=invalid-name return s[: col2width - 3] + '...' if len(s) > col2width else s params = '\n'.join(attrstr % (k, truncate(str(v))) for (k, v) in attrs.items()) return name + '\n' + sep + '\n' + params -# mypy: ignore-errors # pylint: disable-next=too-many-branches,too-many-statements,too-many-locals -def make_dot( - var: torch.Tensor, params=None, show_attrs=False, show_saved=False, max_attr_chars=50 +def make_dot( # noqa: C901 + var: TensorTree, + params: ( + Mapping[str, torch.Tensor] + | ModuleState + | Generator + | Iterable[Mapping[str, torch.Tensor] | ModuleState | Generator] + | None + ) = None, + show_attrs: bool = False, + show_saved: bool = False, + max_attr_chars: int = 50, ) -> Digraph: - """Produces Graphviz representation of PyTorch autograd graph. + """Produce Graphviz representation of PyTorch autograd graph. If a node represents a backward function, it is gray. Otherwise, the node represents a tensor and is either blue, orange, or green: - **Blue** - Reachable leaf tensors that requires grad (tensors whose :attr:`grad` fields will be + Reachable leaf tensors that requires grad (tensors whose ``grad`` fields will be populated during :meth:`backward`). - **Orange** Saved tensors of custom autograd functions as well as those saved by built-in backward @@ -85,71 +102,62 @@ def make_dot( If any output is a view, we represent its base tensor with a dark green node. Args: - var: Output tensor. - params: ([dict of (name, tensor) or state_dict]) - Parameters to add names to node that requires grad. - show_attrs: Whether to display non-tensor attributes of backward nodes - (Requires PyTorch version >= 1.9) - show_saved: Whether to display saved tensor nodes that are not by custom autograd - functions. Saved tensor nodes for custom functions, if present, are always displayed. - (Requires PyTorch version >= 1.9) - max_attr_chars: If ``show_attrs`` is :data:`True`, sets max number of characters to display - for any given attribute. + var (Tensor or sequence of Tensor): Output tensor. + params: (dict[str, Tensor], ModuleState, iterable of tuple[str, Tensor], or None, optional): + Parameters to add names to node that requires grad. (default: :data:`None`) + show_attrs (bool, optional): Whether to display non-tensor attributes of backward nodes. + (default: :data:`False`) + show_saved (bool, optional): Whether to display saved tensor nodes that are not by custom + autograd functions. Saved tensor nodes for custom functions, if present, are always + displayed. (default: :data:`False`) + max_attr_chars (int, optional): If ``show_attrs`` is :data:`True`, sets max number of + characters to display for any given attribute. (default: :const:`50`) """ - if parse_version(torch.__version__) < parse_version('1.9') and (show_attrs or show_saved): - warnings.warn( - 'make_dot: showing grad_fn attributes and saved variables ' - 'requires PyTorch version >= 1.9. (This does NOT apply to ' - 'saved tensors saved by custom autograd functions.)' - ) - param_map = {} if params is not None: - from torchopt._src.utils import _ModuleState # pylint: disable=import-outside-toplevel - - if isinstance(params, _ModuleState): + if isinstance(params, ModuleState) and params.visual_contents is not None: param_map.update(params.visual_contents) - elif isinstance(params, Dict): + elif isinstance(params, Mapping): param_map.update({v: k for k, v in params.items()}) elif isinstance(params, Generator): param_map.update({v: k for k, v in params}) else: for param in params: - if isinstance(param, _ModuleState): + if isinstance(param, ModuleState) and param.visual_contents is not None: param_map.update(param.visual_contents) elif isinstance(param, Generator): param_map.update({v: k for k, v in param}) else: - param_map.update({v: k for k, v in param.items()}) - - node_attr = dict( - style='filled', - shape='box', - align='left', - fontsize='10', - ranksep='0.1', - height='0.2', - fontname='monospace', - ) - dot = Digraph(node_attr=node_attr, graph_attr=dict(size='12,12')) + param_map.update({v: k for k, v in cast(Mapping, param).items()}) + + node_attr = { + 'style': 'filled', + 'shape': 'box', + 'align': 'left', + 'fontsize': '10', + 'ranksep': '0.1', + 'height': '0.2', + 'fontname': 'monospace', + } + dot = Digraph(node_attr=node_attr, graph_attr={'size': '12,12'}) seen = set() - def size_to_str(size): + def size_to_str(size: tuple[int, ...]) -> str: return '(' + (', ').join(map(str, size)) + ')' - def get_var_name(var, name=None): + def get_var_name(var: torch.Tensor, name: str | None = None) -> str: if not name: - name = param_map[var] if var in param_map else '' + name = param_map.get(var, '') return f'{name}\n{size_to_str(var.size())}' - def get_var_name_with_flag(var): + def get_var_name_with_flag(var: torch.Tensor) -> str | None: if var in param_map: return f'{param_map[var][0]}\n{size_to_str(param_map[var][1].size())}' return None - def add_nodes(fn): - assert not torch.is_tensor(fn) + def add_nodes(fn: Any) -> None: # noqa: C901 # pylint: disable=too-many-branches + assert not isinstance(fn, torch.Tensor) if fn in seen: return seen.add(fn) @@ -161,12 +169,12 @@ def add_nodes(fn): val = getattr(fn, attr) seen.add(val) attr = attr[len(SAVED_PREFIX) :] - if torch.is_tensor(val): + if isinstance(val, torch.Tensor): dot.edge(str(id(fn)), str(id(val)), dir='none') dot.node(str(id(val)), get_var_name(val, attr), fillcolor='orange') if isinstance(val, tuple): for i, t in enumerate(val): - if torch.is_tensor(t): + if isinstance(t, torch.Tensor): name = f'{attr}[{i}]' dot.edge(str(id(fn)), str(id(t)), dir='none') dot.node(str(id(t)), get_var_name(t, name), fillcolor='orange') @@ -203,32 +211,31 @@ def add_nodes(fn): dot.edge(str(id(t)), str(id(fn))) dot.node(str(id(t)), get_var_name(t), fillcolor='orange') - def add_base_tensor(var, color='darkolivegreen1'): - if var in seen: + def add_base_tensor( + v: torch.Tensor, # pylint: disable=invalid-name + color: str = 'darkolivegreen1', + ) -> None: + if v in seen: return - seen.add(var) - dot.node(str(id(var)), get_var_name(var), fillcolor=color) - if var.grad_fn: - add_nodes(var.grad_fn) - dot.edge(str(id(var.grad_fn)), str(id(var))) + seen.add(v) + dot.node(str(id(v)), get_var_name(v), fillcolor=color) + if v.grad_fn: + add_nodes(v.grad_fn) + dot.edge(str(id(v.grad_fn)), str(id(v))) # pylint: disable=protected-access - if var._is_view(): - add_base_tensor(var._base, color='darkolivegreen3') - dot.edge(str(id(var._base)), str(id(var)), style='dotted') + if v._is_view(): + add_base_tensor(v._base, color='darkolivegreen3') # type: ignore[arg-type] + dot.edge(str(id(v._base)), str(id(v)), style='dotted') # handle multiple outputs - if isinstance(var, tuple): - for v in var: # pylint: disable=invalid-name - add_base_tensor(v) - else: - add_base_tensor(var) + pytree.tree_map_(add_base_tensor, var) resize_graph(dot) return dot -def resize_graph(dot, size_per_element=0.5, min_size=12): +def resize_graph(dot: Digraph, size_per_element: float = 0.5, min_size: float = 12.0) -> None: """Resize the graph according to how much content it contains. Modify the graph in place. diff --git a/tutorials/1_Functional_Optimizer.ipynb b/tutorials/1_Functional_Optimizer.ipynb index f4194835..afc55f38 100644 --- a/tutorials/1_Functional_Optimizer.ipynb +++ b/tutorials/1_Functional_Optimizer.ipynb @@ -1,5 +1,5 @@ { - "cells": [ + "cells": [ { "cell_type": "markdown", "metadata": {}, @@ -11,14 +11,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "[](https://colab.research.google.com/drive/1yfi-ETyIptlIM7WFYWF_IFhX4WF3LldP?usp=sharing)" + "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/1_Functional_Optimizer.ipynb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "In this tutorial, we will introduce how TorchOpt can be treated as functional optimizer to conduct normal optimization with functional programing style. We will also illustrate how to conduct differentiable optimization with functional programing in PyTorch." + "In this tutorial, we will introduce how TorchOpt can be treated as functional optimizer to conduct normal optimization with functional programming style. We will also illustrate how to conduct differentiable optimization with functional programming in PyTorch." ] }, { @@ -70,7 +70,7 @@ "source": [ "### 1.1 Original JAX implementation\n", "\n", - "The first example is JAX implementation coupled with [Optax](https://github.com/deepmind/optax), which belongs to functional programing style." + "The first example is JAX implementation coupled with [Optax](https://github.com/deepmind/optax), which belongs to functional programming style." ] }, { @@ -88,7 +88,7 @@ " return jnp.matmul(x, params['weight']) + params['bias']\n", "\n", " # Obtain the `opt_state` that contains statistics for the optimizer\n", - " learning_rate = 1.\n", + " learning_rate = 1.0\n", " optimizer = optax.adam(learning_rate)\n", " opt_state = optimizer.init(params)\n", "\n", @@ -116,14 +116,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "Parameters before update: {\n", - " 'weight': DeviceArray([[1.]], dtype=float32)),\n", - " 'bias': DeviceArray([0.], dtype=float32)\n", - "}\n", - "Parameters after update: {\n", - " 'weight': DeviceArray([[6.735325e-06]], dtype=float32),\n", - " 'bias': DeviceArray([-0.99999326], dtype=float32)\n", - "}" + "Parameters before update:\n", + "OrderedDict([\n", + " ('weight', DeviceArray([[1.]], dtype=float32)),\n", + " ('bias', DeviceArray([0.], dtype=float32))\n", + "])\n", + "Parameters after update:\n", + "OrderedDict([\n", + " ('weight', DeviceArray([[6.735325e-06]], dtype=float32)),\n", + " ('bias', DeviceArray([-0.99999326], dtype=float32))\n", + "])\n" ] } ], @@ -153,7 +155,7 @@ " model, params = functorch.make_functional(net) # get the functional version of the model\n", "\n", " # Obtain the `opt_state` that contains statistics for the optimizer\n", - " learning_rate = 1.\n", + " learning_rate = 1.0\n", " optimizer = torchopt.adam(learning_rate)\n", " opt_state = optimizer.init(params)\n", "\n", @@ -165,7 +167,7 @@ "\n", " grads = torch.autograd.grad(loss, params)\n", " updates, opt_state = optimizer.update(grads, opt_state)\n", - " \n", + "\n", " print('Parameters before update:', params)\n", " params = torchopt.apply_updates(params, updates)\n", " print('Parameters after update:', params)" @@ -180,14 +182,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "Parameters before update: (\n", + "Parameters before update:\n", + "(\n", " Parameter containing: tensor([[1.]], requires_grad=True),\n", " Parameter containing: tensor([0.], requires_grad=True)\n", ")\n", - "Parameters after update: (\n", - " Parameter containing: tensor([[0.]], requires_grad=True),\n", - " Parameter containing: tensor([-1.], requires_grad=True)\n", - ")" + "Parameters after update:\n", + "(\n", + " Parameter containing: tensor([[6.6757e-06]], requires_grad=True),\n", + " Parameter containing: tensor([-1.0000], requires_grad=True)\n", + ")\n" ] } ], @@ -195,18 +199,77 @@ "interact_with_functorch()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "TorchOpt also offers a wrapper `torchopt.FuncOptimizer` to make it easier to maintain the optimizer states." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def interact_with_functorch_with_wrapper():\n", + " batch_size = 1\n", + " dim = 1\n", + " net = Net(dim)\n", + " model, params = functorch.make_functional(net) # get the functional version of the model\n", + "\n", + " learning_rate = 1.0\n", + " optimizer = torchopt.FuncOptimizer(torchopt.adam(learning_rate))\n", + "\n", + " xs = 2 * torch.ones((batch_size, dim))\n", + " ys = torch.ones((batch_size, 1))\n", + "\n", + " pred = model(params, xs)\n", + " loss = mse(pred, ys)\n", + "\n", + " print('Parameters before update:', params)\n", + " params = optimizer.step(loss, params)\n", + " print('Parameters after update:', params)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Parameters before update:\n", + "(\n", + " Parameter containing: tensor([[1.]], requires_grad=True),\n", + " Parameter containing: tensor([0.], requires_grad=True)\n", + ")\n", + "Parameters after update:\n", + "(\n", + " tensor([[6.6757e-06]], grad_fn=),\n", + " tensor([-1.0000], grad_fn=)\n", + ")\n" + ] + } + ], + "source": [ + "interact_with_functorch_with_wrapper()" + ] + }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.3 Full TorchOpt\n", "\n", - "The third example is to illustrate that TorchOpt can also directly replace `torch.optim` with exactly the same usage. Note the API difference happens between `torchopt.adam()` and `torchopt.Adam()`." + "`torchopt.Optimizer` is the base class for our PyTorch-like optimizer. Combined with the functional optimizer `torchopt.sgd` and `torchopt.adam`, we can define our high-level API `torchopt.SGD` and `torchopt.Adam`. The third example is to illustrate that TorchOpt can also directly replace `torch.optim` with exactly the same usage. Note the API difference happens between `torchopt.adam()` and `torchopt.Adam()`." ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -215,8 +278,11 @@ " dim = 1\n", " net = Net(dim)\n", "\n", - " learning_rate = 1.\n", + " learning_rate = 1.0\n", + " # High-level API\n", " optim = torchopt.Adam(net.parameters(), lr=learning_rate)\n", + " # Low-level API\n", + " optim = torchopt.Optimizer(net.parameters(), torchopt.adam(lr=learning_rate))\n", "\n", " xs = 2 * torch.ones((batch_size, dim))\n", " ys = torch.ones((batch_size, 1))\n", @@ -233,21 +299,23 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Parameters before update: {\n", + "Parameters before update:\n", + "{\n", " 'fc.weight': Parameter containing: tensor([[1.]], requires_grad=True),\n", " 'fc.bias': Parameter containing: tensor([0.], requires_grad=True)\n", "}\n", - "Parameters after update: {\n", - " 'fc.weight': Parameter containing: tensor([[0.]], requires_grad=True),\n", - " 'fc.bias': Parameter containing: tensor([-1.], requires_grad=True)\n", - "}" + "Parameters after update:\n", + "{\n", + " 'fc.weight': Parameter containing: tensor([[6.6757e-06]], requires_grad=True),\n", + " 'fc.bias': Parameter containing: tensor([-1.0000], requires_grad=True)\n", + "}\n" ] } ], @@ -266,7 +334,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -275,7 +343,7 @@ " dim = 1\n", " net = Net(dim)\n", "\n", - " learning_rate = 1.\n", + " learning_rate = 1.0\n", " optim = torch.optim.Adam(net.parameters(), lr=learning_rate)\n", "\n", " xs = 2 * torch.ones((batch_size, dim))\n", @@ -293,21 +361,23 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Parameters before update: {\n", + "Parameters before update:\n", + "{\n", " 'fc.weight': Parameter containing: tensor([[1.]], requires_grad=True),\n", " 'fc.bias': Parameter containing: tensor([0.], requires_grad=True)\n", "}\n", - "Parameters after update: {\n", + "Parameters after update:\n", + "{\n", " 'fc.weight': Parameter containing: tensor([[1.1921e-07]], requires_grad=True),\n", " 'fc.bias': Parameter containing: tensor([-1.0000], requires_grad=True)\n", - "}" + "}\n" ] } ], @@ -321,14 +391,14 @@ "source": [ "## 2. Differentiable Optimization with Functional Optimizer\n", "\n", - "Coupled with functional optimizer, you can conduct differentiable optimization by setting the `inplace` flag as `False` in update and `apply_updates` function. (which might be helpful for meta-learning algorithm implementation with functional programing style). \n", + "Coupled with functional optimizer, you can conduct differentiable optimization by setting the `inplace` flag as `False` in update and `apply_updates` function. (which might be helpful for meta-learning algorithm implementation with functional programming style). \n", "\n", "Note that `torchopt.SGD` and `torchopt.Adam` do not support differentiable optimization. Refer to the Meta-Optimizer notebook for PyTorch-like differentiable optimizers." ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -342,7 +412,7 @@ " meta_param = nn.Parameter(torch.ones(1))\n", "\n", " # SGD example\n", - " learning_rate = 1.\n", + " learning_rate = 1.0\n", " optimizer = torchopt.sgd(learning_rate)\n", " opt_state = optimizer.init(params)\n", "\n", @@ -356,7 +426,8 @@ "\n", " grads = torch.autograd.grad(loss, params, create_graph=True)\n", " updates, opt_state = optimizer.update(grads, opt_state, inplace=False)\n", - " params = torchopt.apply_updates(params, updates, inplace=False) # update parameters with single step SGD update\n", + " # Update parameters with single step SGD update\n", + " params = torchopt.apply_updates(params, updates, inplace=False)\n", "\n", " pred = model(params, xs)\n", " loss = mse(pred, ys)\n", @@ -367,7 +438,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -393,29 +464,29 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ - "optim = torchopt.adam(lr=1., moment_requires_grad=False)" + "optim = torchopt.adam(lr=1.0, moment_requires_grad=False)" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ - "optim = torchopt.adam(lr=1., moment_requires_grad=True)" + "optim = torchopt.adam(lr=1.0, moment_requires_grad=True)" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ - "optim = torchopt.sgd(lr=1., momentum=0.8, moment_requires_grad=True)" + "optim = torchopt.sgd(lr=1.0, momentum=0.8, moment_requires_grad=True)" ] }, { @@ -436,7 +507,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -453,7 +524,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -470,27 +541,27 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "net = Net(1).cuda()\n", - "optim = torchopt.Adam(net.parameters(), lr=1., use_accelerated_op=True)" + "optim = torchopt.Adam(net.parameters(), lr=1.0, use_accelerated_op=True)" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ - "optim = torchopt.adam(lr=1., use_accelerated_op=True)" + "optim = torchopt.adam(lr=1.0, use_accelerated_op=True)" ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.13 64-bit", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -504,7 +575,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.9.15" }, "vscode": { "interpreter": { diff --git a/tutorials/2_Visualization.ipynb b/tutorials/2_Visualization.ipynb index f1af008f..dd58c48d 100644 --- a/tutorials/2_Visualization.ipynb +++ b/tutorials/2_Visualization.ipynb @@ -1,5 +1,5 @@ { - "cells": [ + "cells": [ { "cell_type": "markdown", "metadata": {}, @@ -11,14 +11,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "[](https://colab.research.google.com/drive/1Uoo2epqZKmJNQOiO0EU8DGd33AVKBlAq?usp=sharing)" + "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/2_Visualization.ipynb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "In [PyTorch](https://pytorch.org), if the attribute `requires_grad` a tensor is `True`, the computation graph will be created if we use the tensor to do any operations. The computation graph is implemented likes link-list -- `Tensor`s are nodes and they are linked by their attribute `gran_fn`. [PyTorchViz](https://github.com/szagoruyko/pytorchviz) is a Python package that uses [Graphviz](https://graphviz.org) as a backend for plotting computation graphs. TorchOpt use PyTorchViz as the blueprint and provide more easy-to-use visualization functions on the premise of supporting all its functions." + "In [PyTorch](https://pytorch.org), if the attribute `requires_grad` of a tensor is `True`, the computation graph will be created if we use the tensor to do any operations. The computation graph is implemented like link-list -- `Tensor`s are nodes and they are linked by their attribute `gran_fn`. [PyTorchViz](https://github.com/szagoruyko/pytorchviz) is a Python package that uses [Graphviz](https://graphviz.org) as a backend for plotting computation graphs. TorchOpt use PyTorchViz as the blueprint and provide more easy-to-use visualization functions on the premise of supporting all its functions." ] }, { @@ -37,12 +37,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "\n" + "\n" ] }, { "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n139996637621680\n\ny\n ()\n\n\n\n139993377217744\n\nMulBackward0\n\n\n\n139993377217744->139996637621680\n\n\n\n\n\n139993377217840\n\nAccumulateGrad\n\n\n\n139993377217840->139993377217744\n\n\n\n\n\n139996637619360\n\nx\n ()\n\n\n\n139996637619360->139993377217840\n\n\n\n\n\n" + "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140534064715952\n\ny\n()\n\n\n\n140534064838304\n\nMulBackward0\n\n\n\n140534064838304->140534064715952\n\n\n\n\n\n140534064837776\n\nAccumulateGrad\n\n\n\n140534064837776->140534064838304\n\n\n\n\n\n140534064714832\n\nx\n()\n\n\n\n140534064714832->140534064837776\n\n\n\n\n\n" }, "metadata": {}, "output_type": "display_data" @@ -58,7 +58,7 @@ "import torchopt\n", "\n", "\n", - "x = torch.tensor(1., requires_grad=True)\n", + "x = torch.tensor(1.0, requires_grad=True)\n", "y = 2 * x\n", "display(torchopt.visual.make_dot(y, params={'x': x, 'y': y}))" ] @@ -86,12 +86,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "\n" + "\n" ] }, { "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n139993376880096\n\nloss\n ()\n\n\n\n139996875678480\n\nMseLossBackward0\n\n\n\n139996875678480->139993376880096\n\n\n\n\n\n139996875677952\n\nAddmmBackward0\n\n\n\n139996875677952->139996875678480\n\n\n\n\n\n139996875678336\n\nAccumulateGrad\n\n\n\n139996875678336->139996875677952\n\n\n\n\n\n139993376879696\n\nfc.bias\n (1)\n\n\n\n139993376879696->139996875678336\n\n\n\n\n\n139996875678912\n\nTBackward0\n\n\n\n139996875678912->139996875677952\n\n\n\n\n\n139996875679152\n\nAccumulateGrad\n\n\n\n139996875679152->139996875678912\n\n\n\n\n\n139993376879616\n\nfc.weight\n (1, 5)\n\n\n\n139993376879616->139996875679152\n\n\n\n\n\n" + "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140534659780336\n\nloss\n()\n\n\n\n140531595570768\n\nMseLossBackward0\n\n\n\n140531595570768->140534659780336\n\n\n\n\n\n140531595570576\n\nAddmmBackward0\n\n\n\n140531595570576->140531595570768\n\n\n\n\n\n140531595570528\n\nAccumulateGrad\n\n\n\n140531595570528->140531595570576\n\n\n\n\n\n140531595583632\n\nfc.bias\n(1)\n\n\n\n140531595583632->140531595570528\n\n\n\n\n\n140531595571104\n\nTBackward0\n\n\n\n140531595571104->140531595570576\n\n\n\n\n\n140531595570432\n\nAccumulateGrad\n\n\n\n140531595570432->140531595571104\n\n\n\n\n\n140531595582816\n\nfc.weight\n(1, 5)\n\n\n\n140531595582816->140531595570432\n\n\n\n\n\n" }, "metadata": {}, "output_type": "display_data" @@ -122,7 +122,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The computation graph of meta learning algorithms will be much more complex. Our visualization tool allows users take as input the extracted network state for better visualization." + "The computation graph of meta-learning algorithms will be much more complex. Our visualization tool allows users take as input the extracted network state for better visualization." ] }, { @@ -134,12 +134,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "\n" + "\n" ] }, { "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n139993376892384\n\nloss\n ()\n\n\n\n139993376862752\n\nMseLossBackward0\n\n\n\n139993376862752->139993376892384\n\n\n\n\n\n139993376862800\n\nAddBackward0\n\n\n\n139993376862800->139993376862752\n\n\n\n\n\n139993376862896\n\nAddmmBackward0\n\n\n\n139993376862896->139993376862800\n\n\n\n\n\n139993377217840\n\nAddBackward0\n step1.fc.bias\n (1)\n\n\n\n139993377217840->139993376862896\n\n\n\n\n\n139993376863136\n\nAccumulateGrad\n\n\n\n139993376863136->139993377217840\n\n\n\n\n\n139993376863664\n\nAddmmBackward0\n\n\n\n139993376863136->139993376863664\n\n\n\n\n\n139993376891904\n\nstep0.fc.bias\n (1)\n\n\n\n139993376891904->139993376863136\n\n\n\n\n\n139993376863088\n\nMulBackward0\n\n\n\n139993376863088->139993377217840\n\n\n\n\n\n139993376863184\n\nViewBackward0\n\n\n\n139993376863184->139993376863088\n\n\n\n\n\n139993376863376\n\nSumBackward1\n\n\n\n139993376863376->139993376863184\n\n\n\n\n\n139993376863472\n\nMseLossBackwardBackward0\n\n\n\n139993376863472->139993376863376\n\n\n\n\n\n139993376864000\n\nTBackward0\n\n\n\n139993376863472->139993376864000\n\n\n\n\n\n139993376863568\n\nAddBackward0\n\n\n\n139993376863568->139993376863472\n\n\n\n\n\n139993376863664->139993376863568\n\n\n\n\n\n139993376863760\n\nTBackward0\n\n\n\n139993376863760->139993376863664\n\n\n\n\n\n139993376863856\n\nAccumulateGrad\n\n\n\n139993376863856->139993376863760\n\n\n\n\n\n139993377218464\n\nAddBackward0\n step1.fc.weight\n (1, 5)\n\n\n\n139993376863856->139993377218464\n\n\n\n\n\n139993376891664\n\nstep0.fc.weight\n (1, 5)\n\n\n\n139993376891664->139993376863856\n\n\n\n\n\n139993376862848\n\nAccumulateGrad\n\n\n\n139993376862848->139993376862800\n\n\n\n\n\n139993376862848->139993376863568\n\n\n\n\n\n139996637619600\n\nmeta_param\n ()\n\n\n\n139996637619600->139993376862848\n\n\n\n\n\n139993376863040\n\nTBackward0\n\n\n\n139993376863040->139993376862896\n\n\n\n\n\n139993377218464->139993376863040\n\n\n\n\n\n139993376863424\n\nMulBackward0\n\n\n\n139993376863424->139993377218464\n\n\n\n\n\n139993376863616\n\nTBackward0\n\n\n\n139993376863616->139993376863424\n\n\n\n\n\n139993376863808\n\nTBackward0\n\n\n\n139993376863808->139993376863616\n\n\n\n\n\n139993376863904\n\nMmBackward0\n\n\n\n139993376863904->139993376863808\n\n\n\n\n\n139993376864000->139993376863904\n\n\n\n\n\n" + "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140531595614064\n\nloss\n()\n\n\n\n140531595567168\n\nMseLossBackward0\n\n\n\n140531595567168->140531595614064\n\n\n\n\n\n140531595569232\n\nAddBackward0\n\n\n\n140531595569232->140531595567168\n\n\n\n\n\n140531595568800\n\nAddmmBackward0\n\n\n\n140531595568800->140531595569232\n\n\n\n\n\n140534660247264\n\nAddBackward0\nstep1.fc.bias\n(1)\n\n\n\n140534660247264->140531595568800\n\n\n\n\n\n140534553595376\n\nAccumulateGrad\n\n\n\n140534553595376->140534660247264\n\n\n\n\n\n140534553592832\n\nAddmmBackward0\n\n\n\n140534553595376->140534553592832\n\n\n\n\n\n140534064448352\n\nstep0.fc.bias\n(1)\n\n\n\n140534064448352->140534553595376\n\n\n\n\n\n140534553595616\n\nMulBackward0\n\n\n\n140534553595616->140534660247264\n\n\n\n\n\n140534553594848\n\nViewBackward0\n\n\n\n140534553594848->140534553595616\n\n\n\n\n\n140534553594992\n\nSumBackward1\n\n\n\n140534553594992->140534553594848\n\n\n\n\n\n140534553594800\n\nMseLossBackwardBackward0\n\n\n\n140534553594800->140534553594992\n\n\n\n\n\n140531595617904\n\nTBackward0\n\n\n\n140534553594800->140531595617904\n\n\n\n\n\n140534553593072\n\nAddBackward0\n\n\n\n140534553593072->140534553594800\n\n\n\n\n\n140534553592832->140534553593072\n\n\n\n\n\n140534553593456\n\nTBackward0\n\n\n\n140534553593456->140534553592832\n\n\n\n\n\n140534553593888\n\nAccumulateGrad\n\n\n\n140534553593888->140534553593456\n\n\n\n\n\n140531595572368\n\nAddBackward0\nstep1.fc.weight\n(1, 5)\n\n\n\n140534553593888->140531595572368\n\n\n\n\n\n140531595612944\n\nstep0.fc.weight\n(1, 5)\n\n\n\n140531595612944->140534553593888\n\n\n\n\n\n140531595567888\n\nAccumulateGrad\n\n\n\n140531595567888->140531595569232\n\n\n\n\n\n140531595567888->140534553593072\n\n\n\n\n\n140531595613184\n\nmeta_param\n()\n\n\n\n140531595613184->140531595567888\n\n\n\n\n\n140534553594272\n\nTBackward0\n\n\n\n140534553594272->140531595568800\n\n\n\n\n\n140531595572368->140534553594272\n\n\n\n\n\n140534553593504\n\nMulBackward0\n\n\n\n140534553593504->140531595572368\n\n\n\n\n\n140534553592976\n\nTBackward0\n\n\n\n140534553592976->140534553593504\n\n\n\n\n\n140534553593216\n\nTBackward0\n\n\n\n140534553593216->140534553592976\n\n\n\n\n\n140534553593552\n\nMmBackward0\n\n\n\n140534553593552->140534553593216\n\n\n\n\n\n140531595617904->140534553593552\n\n\n\n\n\n" }, "metadata": {}, "output_type": "display_data" @@ -163,7 +163,7 @@ "ys = torch.ones((batch_size, 1))\n", "\n", "optimizer = torchopt.MetaSGD(net, lr=1e-3)\n", - "meta_param = torch.tensor(1., requires_grad=True)\n", + "meta_param = torch.tensor(1.0, requires_grad=True)\n", "\n", "# Set enable_visual\n", "net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')\n", @@ -179,13 +179,17 @@ "loss = F.mse_loss(pred, torch.ones_like(pred))\n", "\n", "# Draw computation graph\n", - "display(torchopt.visual.make_dot(loss, [net_state_0, net_state_1, {'meta_param': meta_param, 'loss': loss}]))" + "display(\n", + " torchopt.visual.make_dot(\n", + " loss, [net_state_0, net_state_1, {'meta_param': meta_param, 'loss': loss}]\n", + " )\n", + ")" ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.13 ('torchopt')", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -199,7 +203,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.9.15" }, "vscode": { "interpreter": { diff --git a/tutorials/3_Meta_Optimizer.ipynb b/tutorials/3_Meta_Optimizer.ipynb index aaca9e3f..69be77ed 100644 --- a/tutorials/3_Meta_Optimizer.ipynb +++ b/tutorials/3_Meta_Optimizer.ipynb @@ -1,5 +1,5 @@ { - "cells": [ + "cells": [ { "cell_type": "markdown", "metadata": {}, @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "[](https://colab.research.google.com/drive/1lo9q2gQz073urYln-4Yub5s8APUoHvQJ?usp=sharing)" + "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/3_Meta_Optimizer.ipynb)" ] }, { @@ -34,7 +34,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Assume a tensor $x$ is a meta parameter and $a$ is a normal parameters (such as network parameters). We have inner loss $\\mathcal{L}^{\\textrm{in}} = a_0 \\cdot x^2$ and we update $a$ use the gradient $\\frac{\\partial \\mathcal{L}^{\\textrm{in}}}{\\partial a_0} = x^2$ and $a_1 = a_0 - \\eta \\, \\frac{\\partial \\mathcal{L}^{\\textrm{in}}}{\\partial a_0} = a_0 - \\eta \\, x^2$. Then we compute the outer loss $\\mathcal{L}^{\\textrm{out}} = a_1 \\cdot x^2$. So the gradient of outer loss to $x$ would be:\n", + "Assume a tensor $x$ is a meta-parameter and $a$ is a normal parameters (such as network parameters). We have inner loss $\\mathcal{L}^{\\textrm{in}} = a_0 \\cdot x^2$ and we update $a$ use the gradient $\\frac{\\partial \\mathcal{L}^{\\textrm{in}}}{\\partial a_0} = x^2$ and $a_1 = a_0 - \\eta \\, \\frac{\\partial \\mathcal{L}^{\\textrm{in}}}{\\partial a_0} = a_0 - \\eta \\, x^2$. Then we compute the outer loss $\\mathcal{L}^{\\textrm{out}} = a_1 \\cdot x^2$. So the gradient of outer loss to $x$ would be:\n", "\n", "$$\n", "\\begin{split}\n", @@ -73,17 +73,17 @@ "class Net(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", - " self.a = nn.Parameter(torch.tensor(1.), requires_grad=True)\n", - " \n", + " self.a = nn.Parameter(torch.tensor(1.0), requires_grad=True)\n", + "\n", " def forward(self, x):\n", - " return self.a * (x ** 2)" + " return self.a * (x**2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Then we declare the network (parameterized by `a`) and the meta parameter `x`. Do not forget to set flag `requires_grad=True` for `x`." + "Then we declare the network (parameterized by `a`) and the meta-parameter `x`. Do not forget to set flag `requires_grad=True` for `x`." ] }, { @@ -93,20 +93,40 @@ "outputs": [], "source": [ "net = Net()\n", - "x = nn.Parameter(torch.tensor(2.), requires_grad=True)" + "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Next we declare the meta optimizer. The meta optimizer takes as input the network and use method `step` to update the network (parameterized by `a`)." + "Next we declare the meta-optimizer. Here we show two equivalent ways of defining the meta-optimizer. " ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, + "outputs": [], + "source": [ + "# Low-level API\n", + "optim = torchopt.MetaOptimizer(net, torchopt.sgd(lr=1.0))\n", + "\n", + "# High-level API\n", + "optim = torchopt.MetaSGD(net, lr=1.0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The meta-optimizer takes the network as input and use method `step` to update the network (parameterized by `a`). Finally, we show how a bi-level process works." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -117,8 +137,6 @@ } ], "source": [ - "optim = torchopt.MetaSGD(net, lr=1.)\n", - "\n", "inner_loss = net(x)\n", "optim.step(inner_loss)\n", "\n", @@ -137,7 +155,7 @@ "source": [ "### 1.1 Track the Gradient of Momentum\n", "\n", - "Note that most modern optimizers involve moment term in the gradient update (basically only SGD with `momentum = 0` does not involve). We provide an option for user to choose whether to also track the meta-gradient through moment term. The default option is `moment_requires_grad=True`." + "Note that most modern optimizers involve moment term in the gradient update (basically only SGD with `momentum=0` does not involve). We provide an option for user to choose whether to also track the meta-gradient through moment term. The default option is `moment_requires_grad=True`." ] }, { @@ -149,19 +167,19 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\n" + "\n" ] }, { "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140393111569088\n\nouter_loss\n ()\n\n\n\n140393111544592\n\nMseLossBackward0\n\n\n\n140393111544592->140393111569088\n\n\n\n\n\n140393111544736\n\nMulBackward0\n\n\n\n140393111544736->140393111544592\n\n\n\n\n\n140396237940576\n\nAddBackward0\n step1.a\n ()\n\n\n\n140396237940576->140393111544736\n\n\n\n\n\n140393111545216\n\nAccumulateGrad\n\n\n\n140393111545216->140396237940576\n\n\n\n\n\n140393111545984\n\nMulBackward0\n\n\n\n140393111545216->140393111545984\n\n\n\n\n\n140393111534464\n\nstep0.a\n ()\n\n\n\n140393111534464->140393111545216\n\n\n\n\n\n140393111544112\n\nMulBackward0\n\n\n\n140393111544112->140396237940576\n\n\n\n\n\n140393111545168\n\nDivBackward0\n\n\n\n140393111545168->140393111544112\n\n\n\n\n\n140393111545408\n\nDivBackward0\n\n\n\n140393111545408->140393111545168\n\n\n\n\n\n140393111545552\n\nAddBackward0\n\n\n\n140393111545552->140393111545408\n\n\n\n\n\n140393111545648\n\nPowBackward0\n\n\n\n140393111545648->140393111545552\n\n\n\n\n\n140393111545744\n\nMulBackward0\n\n\n\n140393111545744->140393111545648\n\n\n\n\n\n140393111546272\n\nPowBackward0\n\n\n\n140393111545744->140393111546272\n\n\n\n\n\n140393111545840\n\nMseLossBackwardBackward0\n\n\n\n140393111545840->140393111545744\n\n\n\n\n\n140393111545984->140393111545840\n\n\n\n\n\n140393111545792\n\nPowBackward0\n\n\n\n140393111545792->140393111545744\n\n\n\n\n\n140393111545792->140393111545984\n\n\n\n\n\n140393111546128\n\nAccumulateGrad\n\n\n\n140393111546128->140393111545792\n\n\n\n\n\n140393111545024\n\nPowBackward0\n\n\n\n140393111546128->140393111545024\n\n\n\n\n\n140393111534624\n\nx\n ()\n\n\n\n140393111534624->140393111546128\n\n\n\n\n\n140393111545360\n\nAddBackward0\n\n\n\n140393111545360->140393111545168\n\n\n\n\n\n140393111545696\n\nSqrtBackward0\n\n\n\n140393111545696->140393111545360\n\n\n\n\n\n140393111545936\n\nAddBackward0\n\n\n\n140393111545936->140393111545696\n\n\n\n\n\n140393111545888\n\nDivBackward0\n\n\n\n140393111545888->140393111545936\n\n\n\n\n\n140393111546176\n\nAddBackward0\n\n\n\n140393111546176->140393111545888\n\n\n\n\n\n140393111546272->140393111546176\n\n\n\n\n\n140393111545024->140393111544736\n\n\n\n\n\n" + "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140447553047184\n\nouter_loss\n()\n\n\n\n140447553041216\n\nMseLossBackward0\n\n\n\n140447553041216->140447553047184\n\n\n\n\n\n140447553042896\n\nMulBackward0\n\n\n\n140447553042896->140447553041216\n\n\n\n\n\n140447553019088\n\nAddBackward0\nstep1.a\n()\n\n\n\n140447553019088->140447553042896\n\n\n\n\n\n140447553041072\n\nAccumulateGrad\n\n\n\n140447553041072->140447553019088\n\n\n\n\n\n140447553043664\n\nMulBackward0\n\n\n\n140447553041072->140447553043664\n\n\n\n\n\n140447553045344\n\nstep0.a\n()\n\n\n\n140447553045344->140447553041072\n\n\n\n\n\n140447553041120\n\nMulBackward0\n\n\n\n140447553041120->140447553019088\n\n\n\n\n\n140447553043040\n\nDivBackward0\n\n\n\n140447553043040->140447553041120\n\n\n\n\n\n140447553043184\n\nDivBackward0\n\n\n\n140447553043184->140447553043040\n\n\n\n\n\n140447553043328\n\nAddBackward0\n\n\n\n140447553043328->140447553043184\n\n\n\n\n\n140447553043424\n\nMulBackward0\n\n\n\n140447553043424->140447553043328\n\n\n\n\n\n140447553043856\n\nAddcmulBackward0\n\n\n\n140447553043424->140447553043856\n\n\n\n\n\n140447553043424->140447553043856\n\n\n\n\n\n140447553043520\n\nMseLossBackwardBackward0\n\n\n\n140447553043520->140447553043424\n\n\n\n\n\n140447553043664->140447553043520\n\n\n\n\n\n140447553043472\n\nPowBackward0\n\n\n\n140447553043472->140447553043424\n\n\n\n\n\n140447553043472->140447553043664\n\n\n\n\n\n140447553043808\n\nAccumulateGrad\n\n\n\n140447553043808->140447553043472\n\n\n\n\n\n140447553041264\n\nPowBackward0\n\n\n\n140447553043808->140447553041264\n\n\n\n\n\n140447553045584\n\nx\n()\n\n\n\n140447553045584->140447553043808\n\n\n\n\n\n140447553043136\n\nAddBackward0\n\n\n\n140447553043136->140447553043040\n\n\n\n\n\n140447553043232\n\nSqrtBackward0\n\n\n\n140447553043232->140447553043136\n\n\n\n\n\n140447553043760\n\nAddBackward0\n\n\n\n140447553043760->140447553043232\n\n\n\n\n\n140447553043904\n\nDivBackward0\n\n\n\n140447553043904->140447553043760\n\n\n\n\n\n140447553043856->140447553043904\n\n\n\n\n\n140447553041264->140447553042896\n\n\n\n\n\n" }, "metadata": {}, "output_type": "display_data" @@ -169,10 +187,10 @@ ], "source": [ "net = Net()\n", - "x = nn.Parameter(torch.tensor(2.), requires_grad=True)\n", - "y = torch.tensor(1.)\n", + "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", + "y = torch.tensor(1.0)\n", "\n", - "optim = torchopt.MetaAdam(net, lr=1., moment_requires_grad=False)\n", + "optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=False)\n", "\n", "net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')\n", "inner_loss = F.mse_loss(net(x), y)\n", @@ -180,7 +198,11 @@ "net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.')\n", "\n", "outer_loss = F.mse_loss(net(x), y)\n", - "display(torchopt.visual.make_dot(outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]))" + "display(\n", + " torchopt.visual.make_dot(\n", + " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", + " )\n", + ")" ] }, { @@ -192,19 +214,19 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\n" + "\n" ] }, { "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140393102737552\n\nouter_loss\n ()\n\n\n\n140393111544400\n\nMseLossBackward0\n\n\n\n140393111544400->140393102737552\n\n\n\n\n\n140393111544304\n\nMulBackward0\n\n\n\n140393111544304->140393111544400\n\n\n\n\n\n140396584753232\n\nAddBackward0\n step1.a\n ()\n\n\n\n140396584753232->140393111544304\n\n\n\n\n\n140393111544016\n\nAccumulateGrad\n\n\n\n140393111544016->140396584753232\n\n\n\n\n\n140393111547280\n\nMulBackward0\n\n\n\n140393111544016->140393111547280\n\n\n\n\n\n140393111570848\n\nstep0.a\n ()\n\n\n\n140393111570848->140393111544016\n\n\n\n\n\n140393111544256\n\nMulBackward0\n\n\n\n140393111544256->140396584753232\n\n\n\n\n\n140393111544160\n\nDivBackward0\n\n\n\n140393111544160->140393111544256\n\n\n\n\n\n140393111546512\n\nDivBackward0\n\n\n\n140393111546512->140393111544160\n\n\n\n\n\n140393111544112\n\nAddBackward0\n\n\n\n140393111544112->140393111546512\n\n\n\n\n\n140393111546368\n\nMulBackward0\n\n\n\n140393111546368->140393111544112\n\n\n\n\n\n140393111547040\n\nAccumulateGrad\n\n\n\n140393111547040->140393111546368\n\n\n\n\n\n140393111569408\n\n ()\n\n\n\n140393111569408->140393111547040\n\n\n\n\n\n140393111546272\n\nPowBackward0\n\n\n\n140393111546272->140393111544112\n\n\n\n\n\n140393111547088\n\nMulBackward0\n\n\n\n140393111547088->140393111546272\n\n\n\n\n\n140393111547328\n\nPowBackward0\n\n\n\n140393111547088->140393111547328\n\n\n\n\n\n140393111547184\n\nMseLossBackwardBackward0\n\n\n\n140393111547184->140393111547088\n\n\n\n\n\n140393111547280->140393111547184\n\n\n\n\n\n140393111546944\n\nPowBackward0\n\n\n\n140393111546944->140393111547088\n\n\n\n\n\n140393111546944->140393111547280\n\n\n\n\n\n140393111546320\n\nAccumulateGrad\n\n\n\n140393111546320->140393111546944\n\n\n\n\n\n140393111544208\n\nPowBackward0\n\n\n\n140393111546320->140393111544208\n\n\n\n\n\n140393111571168\n\nx\n ()\n\n\n\n140393111571168->140393111546320\n\n\n\n\n\n140393111546848\n\nAddBackward0\n\n\n\n140393111546848->140393111544160\n\n\n\n\n\n140393111547136\n\nSqrtBackward0\n\n\n\n140393111547136->140393111546848\n\n\n\n\n\n140393111547232\n\nAddBackward0\n\n\n\n140393111547232->140393111547136\n\n\n\n\n\n140393111545360\n\nDivBackward0\n\n\n\n140393111545360->140393111547232\n\n\n\n\n\n140393111547424\n\nAddBackward0\n\n\n\n140393111547424->140393111545360\n\n\n\n\n\n140393111547520\n\nMulBackward0\n\n\n\n140393111547520->140393111547424\n\n\n\n\n\n140393111547616\n\nAccumulateGrad\n\n\n\n140393111547616->140393111547520\n\n\n\n\n\n140393111570288\n\n ()\n\n\n\n140393111570288->140393111547616\n\n\n\n\n\n140393111547328->140393111547424\n\n\n\n\n\n140393111544208->140393111544304\n\n\n\n\n\n" + "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140447553148704\n\nouter_loss\n()\n\n\n\n140447553041024\n\nMseLossBackward0\n\n\n\n140447553041024->140447553148704\n\n\n\n\n\n140447553043424\n\nMulBackward0\n\n\n\n140447553043424->140447553041024\n\n\n\n\n\n140450536407152\n\nAddBackward0\nstep1.a\n()\n\n\n\n140450536407152->140447553043424\n\n\n\n\n\n140447553041264\n\nAccumulateGrad\n\n\n\n140447553041264->140450536407152\n\n\n\n\n\n140447553019232\n\nMulBackward0\n\n\n\n140447553041264->140447553019232\n\n\n\n\n\n140447553148064\n\nstep0.a\n()\n\n\n\n140447553148064->140447553041264\n\n\n\n\n\n140447553041216\n\nMulBackward0\n\n\n\n140447553041216->140450536407152\n\n\n\n\n\n140447553041312\n\nDivBackward0\n\n\n\n140447553041312->140447553041216\n\n\n\n\n\n140447553041408\n\nDivBackward0\n\n\n\n140447553041408->140447553041312\n\n\n\n\n\n140447553043376\n\nAddBackward0\n\n\n\n140447553043376->140447553041408\n\n\n\n\n\n140447553041168\n\nMulBackward0\n\n\n\n140447553041168->140447553043376\n\n\n\n\n\n140447553042272\n\nAccumulateGrad\n\n\n\n140447553042272->140447553041168\n\n\n\n\n\n140450290826352\n\n()\n\n\n\n140450290826352->140447553042272\n\n\n\n\n\n140447553044432\n\nMulBackward0\n\n\n\n140447553044432->140447553043376\n\n\n\n\n\n140447553018320\n\nAddcmulBackward0\n\n\n\n140447553044432->140447553018320\n\n\n\n\n\n140447553044432->140447553018320\n\n\n\n\n\n140447553042080\n\nMseLossBackwardBackward0\n\n\n\n140447553042080->140447553044432\n\n\n\n\n\n140447553019232->140447553042080\n\n\n\n\n\n140447553019088\n\nPowBackward0\n\n\n\n140447553019088->140447553044432\n\n\n\n\n\n140447553019088->140447553019232\n\n\n\n\n\n140447553018464\n\nAccumulateGrad\n\n\n\n140447553018464->140447553019088\n\n\n\n\n\n140447553043328\n\nPowBackward0\n\n\n\n140447553018464->140447553043328\n\n\n\n\n\n140447553148144\n\nx\n()\n\n\n\n140447553148144->140447553018464\n\n\n\n\n\n140447553041456\n\nAddBackward0\n\n\n\n140447553041456->140447553041312\n\n\n\n\n\n140447553041360\n\nSqrtBackward0\n\n\n\n140447553041360->140447553041456\n\n\n\n\n\n140447553015920\n\nAddBackward0\n\n\n\n140447553015920->140447553041360\n\n\n\n\n\n140447553018560\n\nDivBackward0\n\n\n\n140447553018560->140447553015920\n\n\n\n\n\n140447553018320->140447553018560\n\n\n\n\n\n140447553018272\n\nMulBackward0\n\n\n\n140447553018272->140447553018320\n\n\n\n\n\n140447553018944\n\nAccumulateGrad\n\n\n\n140447553018944->140447553018272\n\n\n\n\n\n140450290824272\n\n()\n\n\n\n140450290824272->140447553018944\n\n\n\n\n\n140447553043328->140447553043424\n\n\n\n\n\n" }, "metadata": {}, "output_type": "display_data" @@ -212,10 +234,10 @@ ], "source": [ "net = Net()\n", - "x = nn.Parameter(torch.tensor(2.), requires_grad=True)\n", - "y = torch.tensor(1.)\n", + "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", + "y = torch.tensor(1.0)\n", "\n", - "optim = torchopt.MetaAdam(net, lr=1., moment_requires_grad=True)\n", + "optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=True)\n", "\n", "net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')\n", "inner_loss = F.mse_loss(net(x), y)\n", @@ -223,14 +245,18 @@ "net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.')\n", "\n", "outer_loss = F.mse_loss(net(x), y)\n", - "display(torchopt.visual.make_dot(outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]))" + "display(\n", + " torchopt.visual.make_dot(\n", + " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", + " )\n", + ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We can see that the additional moment terms are added into the computational graph when we set `moment_requires_grad = True`." + "We can see that the additional moment terms are added into the computational graph when we set `moment_requires_grad=True`." ] }, { @@ -248,36 +274,42 @@ "\n", "We observe that how to reinitialize the inner-loop parameter in a new bi-level process vary in different meta-learning algorithms. For instance, in algorithm like Model-Agnostic Meta-Learning (MAML) ([arXiv:1703.03400](https://arxiv.org/abs/1703.03400)), every time a new task comes, we need to reset the parameters to the initial ones. In other cases such as Meta-Gradient Reinforcement Learning (MGRL) ([arXiv:1805.09801](https://arxiv.org/abs/1805.09801)), the inner-loop network parameter just inherit previous updated parameter to continue the new bi-level process.\n", "\n", - "We provide the `torchopt.extract_state_dict` and `torchopt.recover_state_dict` functions to extract and restore the state of network and optimizer. By default, the extracted state dictionary is a reference (this design is for accumulating gradient of multi-task batch training, MAML for example). You can also set `copy=True` to extract the copy of state dictionary." + "We provide the `torchopt.extract_state_dict` and `torchopt.recover_state_dict` functions to extract and restore the state of network and optimizer. By default, the extracted state dictionary is a reference (this design is for accumulating gradient of multi-task batch training, MAML for example). You can also set `by='copy'` to extract the copy of the state dictionary or set `by='deepcopy'` to have a detached copy." ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "a = tensor(-1., grad_fn=)\n", - "a = tensor(-1., grad_fn=)\n" + "a = tensor(-1.0000, grad_fn=)\n", + "a = tensor(-1.0000, grad_fn=)\n" ] } ], "source": [ "net = Net()\n", - "x = nn.Parameter(torch.tensor(2.), requires_grad=True)\n", + "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", "\n", - "optim = torchopt.MetaAdam(net, lr=1.)\n", + "optim = torchopt.MetaAdam(net, lr=1.0)\n", "\n", "# Get the reference of state dictionary\n", - "init_net_state = torchopt.extract_state_dict(net)\n", - "init_optim_state = torchopt.extract_state_dict(optim)\n", + "init_net_state = torchopt.extract_state_dict(net, by='reference')\n", + "init_optim_state = torchopt.extract_state_dict(optim, by='reference')\n", + "# If set `detach_buffers=True`, the parameters are referenced as references while buffers are detached copies\n", + "init_net_state = torchopt.extract_state_dict(net, by='reference', detach_buffers=True)\n", + "\n", + "# Set `copy` to get the copy of the state dictionary\n", + "init_net_state_copy = torchopt.extract_state_dict(net, by='copy')\n", + "init_optim_state_copy = torchopt.extract_state_dict(optim, by='copy')\n", "\n", - "# Set `copy=True` to get the copy of state dictionary\n", - "init_net_state_copy = torchopt.extract_state_dict(net, copy=True)\n", - "init_optim_state_copy = torchopt.extract_state_dict(optim, copy=True)\n", + "# Set `deepcopy` to get the detached copy of state dictionary\n", + "init_net_state_deepcopy = torchopt.extract_state_dict(net, by='deepcopy')\n", + "init_optim_state_deepcopy = torchopt.extract_state_dict(optim, by='deepcopy')\n", "\n", "# Conduct 2 inner-loop optimization\n", "for i in range(2):\n", @@ -303,9 +335,9 @@ "source": [ "### 2.2 Multi-task Example with `extract_state_dict` and `recover_state_dict`\n", "\n", - "Let's move to another more complex setting. Meta Learning algorithms always fix network on several different tasks and accumulate outer loss of each task to the meta gradient.\n", + "Let's move to another more complex setting. Meta-Learning algorithms always fix network on several different tasks and accumulate outer loss of each task to the meta-gradient.\n", "\n", - "Assume $x$ is a meta parameter and $a$ is a normal parameter. We firstly update $a$ use inner loss $\\mathcal{L}_1^{\\textrm{in}} = a_0 \\cdot x^2$ to $a_1$. Then we use $a_1$ to compute the outer loss $\\mathcal{L}_1^{\\textrm{out}} = a_1 \\cdot x^2$ and back-propagate it. Then we use $a_0$ to compute the inner loss $\\mathcal{L}_2^{\\textrm{in}} = a_0 \\cdot x$ and update $a_0$ to $a_2 = a_0 - \\eta \\, \\frac{\\partial \\mathcal{L}_2^{\\textrm{in}}}{\\partial a_0} = a_0 - \\eta \\, x$. Then we compute outer loss $\\mathcal{L}_2^{\\textrm{out}} = a_2 \\cdot x$ and back-propagate it. So the accumulated meta gradient would be:\n", + "Assume $x$ is a meta-parameter and $a$ is a normal parameter. We firstly update $a$ use inner loss $\\mathcal{L}_1^{\\textrm{in}} = a_0 \\cdot x^2$ to $a_1$. Then we use $a_1$ to compute the outer loss $\\mathcal{L}_1^{\\textrm{out}} = a_1 \\cdot x^2$ and backpropagate it. Then we use $a_0$ to compute the inner loss $\\mathcal{L}_2^{\\textrm{in}} = a_0 \\cdot x$ and update $a_0$ to $a_2 = a_0 - \\eta \\, \\frac{\\partial \\mathcal{L}_2^{\\textrm{in}}}{\\partial a_0} = a_0 - \\eta \\, x$. Then we compute outer loss $\\mathcal{L}_2^{\\textrm{out}} = a_2 \\cdot x$ and backpropagate it. So the accumulated meta-gradient would be:\n", "\n", "$$\n", "\\begin{split}\n", @@ -328,26 +360,26 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "class Net2Tasks(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", - " self.a = nn.Parameter(torch.tensor(1.), requires_grad=True)\n", - " \n", + " self.a = nn.Parameter(torch.tensor(1.0), requires_grad=True)\n", + "\n", " def task1(self, x):\n", - " return self.a * x ** 2\n", - " \n", + " return self.a * x**2\n", + "\n", " def task2(self, x):\n", " return self.a * x\n", "\n", "\n", "net = Net2Tasks()\n", - "x = nn.Parameter(torch.tensor(2.), requires_grad=True)\n", + "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", "\n", - "optim = torchopt.MetaSGD(net, lr=1.)" + "optim = torchopt.MetaSGD(net, lr=1.0)" ] }, { @@ -359,14 +391,14 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "init_optim_state = ((EmptyState(), EmptyState()),)\n", + "init_optim_state = ((EmptyState(),),)\n", "Task 1: x.grad = tensor(-28.)\n", "Accumulated: x.grad = tensor(-31.)\n" ] @@ -374,8 +406,8 @@ ], "source": [ "# Get the reference of state dictionary\n", - "init_net_state = torchopt.extract_state_dict(net)\n", - "init_optim_state = torchopt.extract_state_dict(optim)\n", + "init_net_state = torchopt.extract_state_dict(net, by='reference')\n", + "init_optim_state = torchopt.extract_state_dict(optim, by='reference')\n", "# The `state_dict` is empty for vanilla SGD optimizer\n", "print(f'init_optim_state = {init_optim_state!r}')\n", "\n", @@ -430,7 +462,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -443,9 +475,12 @@ ], "source": [ "net = Net()\n", - "x = nn.Parameter(torch.tensor(2.), requires_grad=True)\n", + "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", "\n", - "optim_impl = torchopt.combine.chain(torchopt.clip.clip_grad_norm(max_norm=2.), torchopt.sgd(lr=1., moment_requires_grad=True))\n", + "optim_impl = torchopt.combine.chain(\n", + " torchopt.clip.clip_grad_norm(max_norm=2.0),\n", + " torchopt.sgd(lr=1.0, moment_requires_grad=True),\n", + ")\n", "optim = torchopt.MetaOptimizer(net, optim_impl)\n", "\n", "inner_loss = net(x)\n", @@ -465,9 +500,45 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 4. Accelerated Optimizer\n", + "## 4. Learning Rate Scheduler\n", + "\n", + "TorchOpt also provides implementation of learning rate scheduler, which can be used as:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "functional_adam = torchopt.adam(\n", + " lr=torchopt.schedule.linear_schedule(\n", + " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", + " )\n", + ")\n", + "\n", + "adam = torchopt.Adam(\n", + " net.parameters(),\n", + " lr=torchopt.schedule.linear_schedule(\n", + " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", + " ),\n", + ")\n", + "\n", + "meta_adam = torchopt.MetaAdam(\n", + " net,\n", + " lr=torchopt.schedule.linear_schedule(\n", + " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Accelerated Optimizer\n", "\n", - "Users can use accelerated optimizer by setting the `use_accelerated_op` as `True`. Currently we only support the Adam optimizer." + "Users can use accelerated optimizer by setting the `use_accelerated_op=True`. Currently we only support the Adam optimizer." ] }, { @@ -479,7 +550,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -496,7 +567,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -513,19 +584,19 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\n" + "\n" ] }, { "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140393102828544\n\nouter_loss\n ()\n\n\n\n140393111546128\n\nMseLossBackward0\n\n\n\n140393111546128->140393102828544\n\n\n\n\n\n140393111546032\n\nMulBackward0\n\n\n\n140393111546032->140393111546128\n\n\n\n\n\n140396237940288\n\nAddBackward0\n step1.a\n ()\n\n\n\n140396237940288->140393111546032\n\n\n\n\n\n140393111546464\n\nAccumulateGrad\n\n\n\n140393111546464->140396237940288\n\n\n\n\n\n140393102725760\n\nMulBackward0\n\n\n\n140393111546464->140393102725760\n\n\n\n\n\n140393102827744\n\nstep0.a\n ()\n\n\n\n140393102827744->140393111546464\n\n\n\n\n\n140393102725232\n\nMulBackward0\n\n\n\n140393102725232->140396237940288\n\n\n\n\n\n140393112318976\n\nUpdatesOpBackward\n\n\n\n140393112318976->140393102725232\n\n\n\n\n\n140396647894368\n\nMuOpBackward\n\n\n\n140396647894368->140393112318976\n\n\n\n\n\n140393102725472\n\nMulBackward0\n\n\n\n140393102725472->140396647894368\n\n\n\n\n\n140393112318736\n\nNuOpBackward\n\n\n\n140393102725472->140393112318736\n\n\n\n\n\n140393102725616\n\nMseLossBackwardBackward0\n\n\n\n140393102725616->140393102725472\n\n\n\n\n\n140393102725760->140393102725616\n\n\n\n\n\n140393102725568\n\nPowBackward0\n\n\n\n140393102725568->140393102725472\n\n\n\n\n\n140393102725568->140393102725760\n\n\n\n\n\n140393102725904\n\nAccumulateGrad\n\n\n\n140393102725904->140393102725568\n\n\n\n\n\n140393111543968\n\nPowBackward0\n\n\n\n140393102725904->140393111543968\n\n\n\n\n\n140393111485872\n\nx\n ()\n\n\n\n140393111485872->140393102725904\n\n\n\n\n\n140393102725328\n\nAccumulateGrad\n\n\n\n140393102725328->140396647894368\n\n\n\n\n\n140393111534224\n\n ()\n\n\n\n140393111534224->140396647894368\n\n\n\n\n\n140393111534224->140393102725328\n\n\n\n\n\n140393111531904\n\n ()\n\n\n\n140393111531904->140396647894368\n\n\n\n\n\n140393111531904->140393112318736\n\n\n\n\n\n140393112318736->140393112318976\n\n\n\n\n\n140393102725712\n\nAccumulateGrad\n\n\n\n140393102725712->140393112318736\n\n\n\n\n\n140393102827824\n\n ()\n\n\n\n140393102827824->140393112318736\n\n\n\n\n\n140393102827824->140393102725712\n\n\n\n\n\n140393102828784\n\n ()\n\n\n\n140393102828784->140393112318976\n\n\n\n\n\n140393102828144\n\n ()\n\n\n\n140393102828144->140393112318976\n\n\n\n\n\n140393102828224\n\n ()\n\n\n\n140393102828224->140393112318976\n\n\n\n\n\n140393111543968->140393111546032\n\n\n\n\n\n" + "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140450290825712\n\nouter_loss\n()\n\n\n\n140450533650240\n\nMseLossBackward0\n\n\n\n140450533650240->140450290825712\n\n\n\n\n\n140450533648560\n\nMulBackward0\n\n\n\n140450533648560->140450533650240\n\n\n\n\n\n140450533647456\n\nAddBackward0\nstep1.a\n()\n\n\n\n140450533647456->140450533648560\n\n\n\n\n\n140447435136640\n\nAccumulateGrad\n\n\n\n140447435136640->140450533647456\n\n\n\n\n\n140450533648416\n\nMulBackward0\n\n\n\n140447435136640->140450533648416\n\n\n\n\n\n140447435236512\n\nstep0.a\n()\n\n\n\n140447435236512->140447435136640\n\n\n\n\n\n140447435136688\n\nMulBackward0\n\n\n\n140447435136688->140450533647456\n\n\n\n\n\n140447554132144\n\nUpdatesOpBackward\n\n\n\n140447554132144->140447435136688\n\n\n\n\n\n140447554131664\n\nMuOpBackward\n\n\n\n140447554131664->140447554132144\n\n\n\n\n\n140447435134816\n\nMulBackward0\n\n\n\n140447435134816->140447554131664\n\n\n\n\n\n140447554131904\n\nNuOpBackward\n\n\n\n140447435134816->140447554131904\n\n\n\n\n\n140450533648992\n\nMseLossBackwardBackward0\n\n\n\n140450533648992->140447435134816\n\n\n\n\n\n140450533648416->140450533648992\n\n\n\n\n\n140450533646448\n\nPowBackward0\n\n\n\n140450533646448->140447435134816\n\n\n\n\n\n140450533646448->140450533648416\n\n\n\n\n\n140447553018176\n\nAccumulateGrad\n\n\n\n140447553018176->140450533646448\n\n\n\n\n\n140447435135536\n\nPowBackward0\n\n\n\n140447553018176->140447435135536\n\n\n\n\n\n140447553045424\n\nx\n()\n\n\n\n140447553045424->140447553018176\n\n\n\n\n\n140447435136592\n\nAccumulateGrad\n\n\n\n140447435136592->140447554131664\n\n\n\n\n\n140447552973856\n\n()\n\n\n\n140447552973856->140447554131664\n\n\n\n\n\n140447552973856->140447435136592\n\n\n\n\n\n140447553044544\n\n()\n\n\n\n140447553044544->140447554131664\n\n\n\n\n\n140447553044544->140447554131904\n\n\n\n\n\n140447554131904->140447554132144\n\n\n\n\n\n140450533648896\n\nAccumulateGrad\n\n\n\n140450533648896->140447554131904\n\n\n\n\n\n140447435236752\n\n()\n\n\n\n140447435236752->140447554131904\n\n\n\n\n\n140447435236752->140450533648896\n\n\n\n\n\n140447553045904\n\n()\n\n\n\n140447553045904->140447554132144\n\n\n\n\n\n140447435237152\n\n()\n\n\n\n140447435237152->140447554132144\n\n\n\n\n\n140447435237232\n\n()\n\n\n\n140447435237232->140447554132144\n\n\n\n\n\n140447435135536->140450533648560\n\n\n\n\n\n" }, "metadata": {}, "output_type": "display_data" @@ -533,24 +604,89 @@ ], "source": [ "net = Net().to(device='cuda')\n", - "x = nn.Parameter(torch.tensor(2., device=torch.device('cuda')), requires_grad=True)\n", - "y = torch.tensor(1., device=torch.device('cuda'))\n", + "x = nn.Parameter(torch.tensor(2.0, device=torch.device('cuda')), requires_grad=True)\n", + "y = torch.tensor(1.0, device=torch.device('cuda'))\n", "\n", - "optim = torchopt.MetaAdam(net, lr=1., moment_requires_grad=True, use_accelerated_op=True)\n", + "optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=True, use_accelerated_op=True)\n", "\n", - "net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')\n", + "net_state_0 = torchopt.extract_state_dict(\n", + " net, by='reference', enable_visual=True, visual_prefix='step0.'\n", + ")\n", "inner_loss = F.mse_loss(net(x), y)\n", "optim.step(inner_loss)\n", - "net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.')\n", + "net_state_1 = torchopt.extract_state_dict(\n", + " net, by='reference', enable_visual=True, visual_prefix='step1.'\n", + ")\n", "\n", "outer_loss = F.mse_loss(net(x), y)\n", - "display(torchopt.visual.make_dot(outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]))" + "display(\n", + " torchopt.visual.make_dot(\n", + " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Known Issues\n", + "\n", + "Here we record some common issues faced by users when using the meta-optimizer." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**1. Get `NaN` error when using `MetaAdam` or other meta-optimizers.**\n", + "\n", + "The `NaN` error is because of the numerical instability of the `Adam` in meta-learning. There exist an `sqrt` operation in `Adam`'s computation process. Backpropogating through the `Adam` operator introduces the second derivation of the `sqrt` operation, which is not numerical stable, i.e. ${\\left. \\frac{d^2 \\sqrt{x}}{{dx}^2} \\right\\rvert}_{x = 0} = \\texttt{NaN}$. You can also refer to issue [facebookresearch/higher#125](https://github.com/facebookresearch/higher/issues/125).\n", + "\n", + "For this problem, TorchOpt have two recommended solutions.\n", + "\n", + "* Put the `sqrt` operation into the whole equation, and compute the derivation of the output to the input manually. The second derivation of the `sqrt` operation will be eliminated. You can achieve this by setting the flag `use_accelerated_op=True`, you can follow the instructions in notebook [Functional Optimizer](1_Functional_Optimizer.ipynb) and Meta-Optimizer." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "inner_optim = torchopt.MetaAdam(net, lr=1.0, use_accelerated_op=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "* Register hook to the first-order gradients. During the backpropagation, the NaN gradients will be set to 0, which will have a similar effect to the first solution but much slower. " + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "impl = torchopt.chain(torchopt.hook.register_hook(torchopt.hook.zero_nan_hook), torchopt.adam(1e-1))\n", + "inner_optim = torchopt.MetaOptimizer(net, impl)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**2. Get `Trying to backward through the graph a second time` error when conducting multiple meta-optimization.**\n", + "\n", + "Please refer to the tutorial notebook [Stop Gradient](4_Stop_Gradient.ipynb) for more guidance." ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.13 ('torchopt')", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -564,7 +700,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.9.15" }, "vscode": { "interpreter": { diff --git a/tutorials/4_Stop_Gradient.ipynb b/tutorials/4_Stop_Gradient.ipynb index 604196ca..d8c24bc6 100644 --- a/tutorials/4_Stop_Gradient.ipynb +++ b/tutorials/4_Stop_Gradient.ipynb @@ -1,5 +1,5 @@ { - "cells": [ + "cells": [ { "cell_type": "markdown", "metadata": {}, @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "[](https://colab.research.google.com/drive/1jp_oPHIG6aaQMYGNxG72FSuWjABk1DHo?usp=sharing)" + "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/4_Stop_Gradient.ipynb)" ] }, { @@ -40,10 +40,11 @@ " def __init__(self, dim):\n", " super().__init__()\n", " self.fc = nn.Linear(dim, 1, bias=True)\n", - " \n", + "\n", " def forward(self, x):\n", " return self.fc(x)\n", "\n", + "\n", "loss_fn = F.mse_loss" ] }, @@ -81,7 +82,7 @@ "metadata": {}, "outputs": [], "source": [ - "meta_parameter = nn.Parameter(torch.tensor(1.), requires_grad=True)\n", + "meta_parameter = nn.Parameter(torch.tensor(1.0), requires_grad=True)\n", "\n", "optim = torchopt.MetaSGD(net, lr=1e-1)\n", "meta_optim = torch.optim.Adam([meta_parameter], lr=1e-1)" @@ -103,13 +104,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "inner loss: 0.5540\n", - "\n" + "inner loss: 0.3472\n", + "\n" ] }, { "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n139978828415600\n\ninner_loss\n ()\n\n\n\n139978603488640\n\nMseLossBackward0\n\n\n\n139978603488640->139978828415600\n\n\n\n\n\n139978603489744\n\nAddmmBackward0\n\n\n\n139978603489744->139978603488640\n\n\n\n\n\n139978603490800\n\nAccumulateGrad\n\n\n\n139978603490800->139978603489744\n\n\n\n\n\n139975938634512\n\nstep0.fc.bias\n (1)\n\n\n\n139975938634512->139978603490800\n\n\n\n\n\n139978603490224\n\nTBackward0\n\n\n\n139978603490224->139978603489744\n\n\n\n\n\n139978603490368\n\nAccumulateGrad\n\n\n\n139978603490368->139978603490224\n\n\n\n\n\n139975938634432\n\nstep0.fc.weight\n (1, 16)\n\n\n\n139975938634432->139978603490368\n\n\n\n\n\n" + "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140025091550880\n\ninner_loss\n()\n\n\n\n140028156253184\n\nMseLossBackward0\n\n\n\n140028156253184->140025091550880\n\n\n\n\n\n140028156436736\n\nAddmmBackward0\n\n\n\n140028156436736->140028156253184\n\n\n\n\n\n140025091526416\n\nAccumulateGrad\n\n\n\n140025091526416->140028156436736\n\n\n\n\n\n140028155952000\n\nstep0.fc.bias\n(1)\n\n\n\n140028155952000->140025091526416\n\n\n\n\n\n140025091525408\n\nTBackward0\n\n\n\n140025091525408->140028156436736\n\n\n\n\n\n140025091526224\n\nAccumulateGrad\n\n\n\n140025091526224->140025091525408\n\n\n\n\n\n140028155952880\n\nstep0.fc.weight\n(1, 16)\n\n\n\n140028155952880->140025091526224\n\n\n\n\n\n" }, "metadata": {}, "output_type": "display_data" @@ -122,12 +123,7 @@ "inner_loss = loss_fn(net(x), y)\n", "\n", "print(f'inner loss: {inner_loss:.4f}')\n", - "display(\n", - " torchopt.visual.make_dot(\n", - " inner_loss,\n", - " params=(init_net_state, {'inner_loss': inner_loss})\n", - " )\n", - ")" + "display(torchopt.visual.make_dot(inner_loss, params=(init_net_state, {'inner_loss': inner_loss})))" ] }, { @@ -168,13 +164,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "outer loss: 0.2297\n", - "\n" + "outer loss: 0.2039\n", + "\n" ] }, { "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n139975938634752\n\nouter_loss\n ()\n\n\n\n139975938188288\n\nMseLossBackward0\n\n\n\n139975938188288->139975938634752\n\n\n\n\n\n139975938188336\n\nAddmmBackward0\n\n\n\n139975938188336->139975938188288\n\n\n\n\n\n139975938188096\n\nAddBackward0\n step1.fc.bias\n (1)\n\n\n\n139975938188096->139975938188336\n\n\n\n\n\n139978603490800\n\nAccumulateGrad\n\n\n\n139978603490800->139975938188096\n\n\n\n\n\n139978603489744\n\nAddmmBackward0\n\n\n\n139978603490800->139978603489744\n\n\n\n\n\n139975938634512\n\nstep0.fc.bias\n (1)\n\n\n\n139975938634512->139978603490800\n\n\n\n\n\n139975938188480\n\nMulBackward0\n\n\n\n139975938188480->139975938188096\n\n\n\n\n\n139975938188144\n\nViewBackward0\n\n\n\n139975938188144->139975938188480\n\n\n\n\n\n139975938187664\n\nSumBackward1\n\n\n\n139975938187664->139975938188144\n\n\n\n\n\n139975938188720\n\nMseLossBackwardBackward0\n\n\n\n139975938188720->139975938187664\n\n\n\n\n\n139975938189200\n\nTBackward0\n\n\n\n139975938188720->139975938189200\n\n\n\n\n\n139975938188816\n\nMulBackward0\n\n\n\n139975938188816->139975938188720\n\n\n\n\n\n139975938188912\n\nAccumulateGrad\n\n\n\n139975938188912->139975938188816\n\n\n\n\n\n139975938635072\n\nmeta_parameter\n ()\n\n\n\n139975938635072->139975938188912\n\n\n\n\n\n139978603489744->139975938188720\n\n\n\n\n\n139978603490224\n\nTBackward0\n\n\n\n139978603490224->139978603489744\n\n\n\n\n\n139978603490368\n\nAccumulateGrad\n\n\n\n139978603490368->139978603490224\n\n\n\n\n\n139975938187808\n\nAddBackward0\n step1.fc.weight\n (1, 16)\n\n\n\n139978603490368->139975938187808\n\n\n\n\n\n139975938634432\n\nstep0.fc.weight\n (1, 16)\n\n\n\n139975938634432->139978603490368\n\n\n\n\n\n139975938188384\n\nTBackward0\n\n\n\n139975938188384->139975938188336\n\n\n\n\n\n139975938187808->139975938188384\n\n\n\n\n\n139975938188672\n\nMulBackward0\n\n\n\n139975938188672->139975938187808\n\n\n\n\n\n139975938189008\n\nTBackward0\n\n\n\n139975938189008->139975938188672\n\n\n\n\n\n139975938189104\n\nTBackward0\n\n\n\n139975938189104->139975938189008\n\n\n\n\n\n139975938188864\n\nMmBackward0\n\n\n\n139975938188864->139975938189104\n\n\n\n\n\n139975938189200->139975938188864\n\n\n\n\n\n" + "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140027829238416\n\nouter_loss\n()\n\n\n\n140025091525072\n\nMseLossBackward0\n\n\n\n140025091525072->140027829238416\n\n\n\n\n\n140025091525216\n\nAddmmBackward0\n\n\n\n140025091525216->140025091525072\n\n\n\n\n\n140025091526128\n\nAddBackward0\nstep1.fc.bias\n(1)\n\n\n\n140025091526128->140025091525216\n\n\n\n\n\n140025091526416\n\nAccumulateGrad\n\n\n\n140025091526416->140025091526128\n\n\n\n\n\n140028156436736\n\nAddmmBackward0\n\n\n\n140025091526416->140028156436736\n\n\n\n\n\n140028155952000\n\nstep0.fc.bias\n(1)\n\n\n\n140028155952000->140025091526416\n\n\n\n\n\n140025091524976\n\nMulBackward0\n\n\n\n140025091524976->140025091526128\n\n\n\n\n\n140025091526560\n\nViewBackward0\n\n\n\n140025091526560->140025091524976\n\n\n\n\n\n140025091525456\n\nSumBackward1\n\n\n\n140025091525456->140025091526560\n\n\n\n\n\n140025091524112\n\nMseLossBackwardBackward0\n\n\n\n140025091524112->140025091525456\n\n\n\n\n\n140024973742672\n\nTBackward0\n\n\n\n140025091524112->140024973742672\n\n\n\n\n\n140024973742288\n\nMulBackward0\n\n\n\n140024973742288->140025091524112\n\n\n\n\n\n140024973742384\n\nAccumulateGrad\n\n\n\n140024973742384->140024973742288\n\n\n\n\n\n140025091549440\n\nmeta_parameter\n()\n\n\n\n140025091549440->140024973742384\n\n\n\n\n\n140028156436736->140025091524112\n\n\n\n\n\n140025091525408\n\nTBackward0\n\n\n\n140025091525408->140028156436736\n\n\n\n\n\n140025091526224\n\nAccumulateGrad\n\n\n\n140025091526224->140025091525408\n\n\n\n\n\n140025091524928\n\nAddBackward0\nstep1.fc.weight\n(1, 16)\n\n\n\n140025091526224->140025091524928\n\n\n\n\n\n140028155952880\n\nstep0.fc.weight\n(1, 16)\n\n\n\n140028155952880->140025091526224\n\n\n\n\n\n140025091524448\n\nTBackward0\n\n\n\n140025091524448->140025091525216\n\n\n\n\n\n140025091524928->140025091524448\n\n\n\n\n\n140025091525600\n\nMulBackward0\n\n\n\n140025091525600->140025091524928\n\n\n\n\n\n140024973742144\n\nTBackward0\n\n\n\n140024973742144->140025091525600\n\n\n\n\n\n140024973742576\n\nTBackward0\n\n\n\n140024973742576->140024973742144\n\n\n\n\n\n140024973742480\n\nMmBackward0\n\n\n\n140024973742480->140024973742576\n\n\n\n\n\n140024973742672->140024973742480\n\n\n\n\n\n" }, "metadata": {}, "output_type": "display_data" @@ -191,7 +187,11 @@ "display(\n", " torchopt.visual.make_dot(\n", " outer_loss,\n", - " params=(init_net_state, one_step_net_state, {'meta_parameter': meta_parameter, 'outer_loss': outer_loss})\n", + " params=(\n", + " init_net_state,\n", + " one_step_net_state,\n", + " {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n", + " ),\n", " )\n", ")" ] @@ -200,7 +200,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Then we backward the loss to conduct outer-loop meta optimization." + "Then we backward the loss to conduct outer-loop meta-optimization." ] }, { @@ -212,8 +212,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "meta_parameter.grad = tensor(-0.2464)\n", - "meta_parameter = Parameter containing: tensor(1.1000, requires_grad=True)\n" + "meta_parameter.grad = tensor(-0.1205)\n", + "meta_parameter = Parameter containing:\n", + "tensor(1.1000, requires_grad=True)\n" ] } ], @@ -236,11 +237,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In general, the back-propagation only frees saved tensors (often used as auxiliary data for computing the gradient) but the computation graph remains. Once the outer iteration is finished, if you want to use any intermediate network parameters produced by the inner loop for the next bi-level iteration, you should detach them from the computation graph.\n", + "In general, the backpropagation only frees saved tensors (often used as auxiliary data for computing the gradient) but the computation graph remains. Once the outer iteration is finished, if you want to use any intermediate network parameters produced by the inner loop for the next bi-level iteration, you should detach them from the computation graph.\n", "\n", "There are two main reasons:\n", "\n", - "- The network parameters are still connected to the previous computation graph (`.grad_fn` is not `None`). If later the gradient back-propagate to these parameters, the PyTorch backward engine will try to back-propagate through the previous computation graph. This will raise a `RuntimeError`: Trying to backward through the graph a second time...\n", + "- The network parameters are still connected to the previous computation graph (`.grad_fn` is not `None`). If later the gradient backpropagate to these parameters, the PyTorch backward engine will try to backpropagate through the previous computation graph. This will raise a `RuntimeError`: Trying to backward through the graph a second time...\n", "- If we do not detach the computation graph, the computation graph connected to these parameters can not be freed by GC (Garbage Collector) until these parameters are collected by GC." ] }, @@ -260,12 +261,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "\n" + "\n" ] }, { "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n139978828415600\n\nouter_loss\n ()\n\n\n\n139975938626944\n\nMseLossBackward0\n\n\n\n139975938626944->139978828415600\n\n\n\n\n\n139975938626656\n\nAddmmBackward0\n\n\n\n139975938626656->139975938626944\n\n\n\n\n\n139975938188624\n\nAddBackward0\n\n\n\n139975938188624->139975938626656\n\n\n\n\n\n139975938188096\n\nAddBackward0\n step1.fc.bias\n (1)\n\n\n\n139975938188096->139975938188624\n\n\n\n\n\n139975938188144\n\nAddmmBackward0\n\n\n\n139975938188096->139975938188144\n\n\n\n\n\n139975938187424\n\nAccumulateGrad\n\n\n\n139975938187424->139975938188096\n\n\n\n\n\n139975938188912\n\nAddmmBackward0\n\n\n\n139975938187424->139975938188912\n\n\n\n\n\n139975938634512\n\nstep0.fc.bias\n (1)\n\n\n\n139975938634512->139975938187424\n\n\n\n\n\n139975938187856\n\nMulBackward0\n\n\n\n139975938187856->139975938188096\n\n\n\n\n\n139975938188768\n\nViewBackward0\n\n\n\n139975938188768->139975938187856\n\n\n\n\n\n139975938189200\n\nSumBackward1\n\n\n\n139975938189200->139975938188768\n\n\n\n\n\n139975938189008\n\nMseLossBackwardBackward0\n\n\n\n139975938189008->139975938189200\n\n\n\n\n\n139975938189728\n\nTBackward0\n\n\n\n139975938189008->139975938189728\n\n\n\n\n\n139975938188864\n\nMulBackward0\n\n\n\n139975938188864->139975938189008\n\n\n\n\n\n139975938187952\n\nAccumulateGrad\n\n\n\n139975938187952->139975938188864\n\n\n\n\n\n139975938187712\n\nMulBackward0\n\n\n\n139975938187952->139975938187712\n\n\n\n\n\n139975938635072\n\nmeta_parameter\n ()\n\n\n\n139975938635072->139975938187952\n\n\n\n\n\n139975938188912->139975938189008\n\n\n\n\n\n139975938188480\n\nTBackward0\n\n\n\n139975938188480->139975938188912\n\n\n\n\n\n139975938188384\n\nAccumulateGrad\n\n\n\n139975938188384->139975938188480\n\n\n\n\n\n139975938187808\n\nAddBackward0\n step1.fc.weight\n (1, 16)\n\n\n\n139975938188384->139975938187808\n\n\n\n\n\n139975938634432\n\nstep0.fc.weight\n (1, 16)\n\n\n\n139975938634432->139975938188384\n\n\n\n\n\n139975938187520\n\nMulBackward0\n\n\n\n139975938187520->139975938188624\n\n\n\n\n\n139975938189296\n\nViewBackward0\n\n\n\n139975938189296->139975938187520\n\n\n\n\n\n139975938188576\n\nSumBackward1\n\n\n\n139975938188576->139975938189296\n\n\n\n\n\n139975938188720\n\nMseLossBackwardBackward0\n\n\n\n139975938188720->139975938188576\n\n\n\n\n\n139975938189824\n\nTBackward0\n\n\n\n139975938188720->139975938189824\n\n\n\n\n\n139975938187712->139975938188720\n\n\n\n\n\n139975938188144->139975938188720\n\n\n\n\n\n139975938188816\n\nTBackward0\n\n\n\n139975938188816->139975938188144\n\n\n\n\n\n139975938187808->139975938188816\n\n\n\n\n\n139975938189104\n\nAddBackward0\n\n\n\n139975938187808->139975938189104\n\n\n\n\n\n139975938189248\n\nMulBackward0\n\n\n\n139975938189248->139975938187808\n\n\n\n\n\n139975938189344\n\nTBackward0\n\n\n\n139975938189344->139975938189248\n\n\n\n\n\n139975938189536\n\nTBackward0\n\n\n\n139975938189536->139975938189344\n\n\n\n\n\n139975938189440\n\nMmBackward0\n\n\n\n139975938189440->139975938189536\n\n\n\n\n\n139975938189728->139975938189440\n\n\n\n\n\n139975938187904\n\nTBackward0\n\n\n\n139975938187904->139975938626656\n\n\n\n\n\n139975938189104->139975938187904\n\n\n\n\n\n139975938188240\n\nMulBackward0\n\n\n\n139975938188240->139975938189104\n\n\n\n\n\n139975938188048\n\nTBackward0\n\n\n\n139975938188048->139975938188240\n\n\n\n\n\n139975938188528\n\nTBackward0\n\n\n\n139975938188528->139975938188048\n\n\n\n\n\n139975938189584\n\nMmBackward0\n\n\n\n139975938189584->139975938188528\n\n\n\n\n\n139975938189824->139975938189584\n\n\n\n\n\n" + "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140024973755152\n\nouter_loss\n()\n\n\n\n140027829363232\n\nMseLossBackward0\n\n\n\n140027829363232->140024973755152\n\n\n\n\n\n140027829363616\n\nAddmmBackward0\n\n\n\n140027829363616->140027829363232\n\n\n\n\n\n140027829366544\n\nAddBackward0\n\n\n\n140027829366544->140027829363616\n\n\n\n\n\n140025091526128\n\nAddBackward0\nstep1.fc.bias\n(1)\n\n\n\n140025091526128->140027829366544\n\n\n\n\n\n140025091725152\n\nAddmmBackward0\n\n\n\n140025091526128->140025091725152\n\n\n\n\n\n140025091526416\n\nAccumulateGrad\n\n\n\n140025091526416->140025091526128\n\n\n\n\n\n140028156436736\n\nAddmmBackward0\n\n\n\n140025091526416->140028156436736\n\n\n\n\n\n140028155952000\n\nstep0.fc.bias\n(1)\n\n\n\n140028155952000->140025091526416\n\n\n\n\n\n140025091524976\n\nMulBackward0\n\n\n\n140025091524976->140025091526128\n\n\n\n\n\n140025091526560\n\nViewBackward0\n\n\n\n140025091526560->140025091524976\n\n\n\n\n\n140025091525456\n\nSumBackward1\n\n\n\n140025091525456->140025091526560\n\n\n\n\n\n140025091524112\n\nMseLossBackwardBackward0\n\n\n\n140025091524112->140025091525456\n\n\n\n\n\n140024973742672\n\nTBackward0\n\n\n\n140025091524112->140024973742672\n\n\n\n\n\n140024973742288\n\nMulBackward0\n\n\n\n140024973742288->140025091524112\n\n\n\n\n\n140024973742384\n\nAccumulateGrad\n\n\n\n140024973742384->140024973742288\n\n\n\n\n\n140025091726064\n\nMulBackward0\n\n\n\n140024973742384->140025091726064\n\n\n\n\n\n140025091549440\n\nmeta_parameter\n()\n\n\n\n140025091549440->140024973742384\n\n\n\n\n\n140028156436736->140025091524112\n\n\n\n\n\n140025091525408\n\nTBackward0\n\n\n\n140025091525408->140028156436736\n\n\n\n\n\n140025091526224\n\nAccumulateGrad\n\n\n\n140025091526224->140025091525408\n\n\n\n\n\n140025091524928\n\nAddBackward0\nstep1.fc.weight\n(1, 16)\n\n\n\n140025091526224->140025091524928\n\n\n\n\n\n140028155952880\n\nstep0.fc.weight\n(1, 16)\n\n\n\n140028155952880->140025091526224\n\n\n\n\n\n140025091726784\n\nMulBackward0\n\n\n\n140025091726784->140027829366544\n\n\n\n\n\n140025091726688\n\nViewBackward0\n\n\n\n140025091726688->140025091726784\n\n\n\n\n\n140025091725680\n\nSumBackward1\n\n\n\n140025091725680->140025091726688\n\n\n\n\n\n140025091726112\n\nMseLossBackwardBackward0\n\n\n\n140025091726112->140025091725680\n\n\n\n\n\n140025091726880\n\nTBackward0\n\n\n\n140025091726112->140025091726880\n\n\n\n\n\n140025091726064->140025091726112\n\n\n\n\n\n140025091725152->140025091726112\n\n\n\n\n\n140025091725824\n\nTBackward0\n\n\n\n140025091725824->140025091725152\n\n\n\n\n\n140025091524928->140025091725824\n\n\n\n\n\n140025091726016\n\nAddBackward0\n\n\n\n140025091524928->140025091726016\n\n\n\n\n\n140025091525600\n\nMulBackward0\n\n\n\n140025091525600->140025091524928\n\n\n\n\n\n140024973742144\n\nTBackward0\n\n\n\n140024973742144->140025091525600\n\n\n\n\n\n140024973742576\n\nTBackward0\n\n\n\n140024973742576->140024973742144\n\n\n\n\n\n140024973742480\n\nMmBackward0\n\n\n\n140024973742480->140024973742576\n\n\n\n\n\n140024973742672->140024973742480\n\n\n\n\n\n140027829365632\n\nTBackward0\n\n\n\n140027829365632->140027829363616\n\n\n\n\n\n140025091726016->140027829365632\n\n\n\n\n\n140025091726544\n\nMulBackward0\n\n\n\n140025091726544->140025091726016\n\n\n\n\n\n140025091726448\n\nTBackward0\n\n\n\n140025091726448->140025091726544\n\n\n\n\n\n140025091725584\n\nTBackward0\n\n\n\n140025091725584->140025091726448\n\n\n\n\n\n140025091727024\n\nMmBackward0\n\n\n\n140025091727024->140025091725584\n\n\n\n\n\n140025091726880->140025091727024\n\n\n\n\n\n" }, "metadata": {}, "output_type": "display_data" @@ -273,67 +274,103 @@ { "data": { "text/html": [ - "
╭──────────────────────────── Traceback (most recent call last) ────────────────────────────╮\n",
-       " <ipython-input-8-5906690e2182>:17 in <cell line: 17>                                      \n",
-       " /home/TorchOpt/Miniconda3/envs/torchopt/lib/python3.8/site-packages/torch/_tensor.py:396  \n",
-       " in backward                                                                               \n",
-       "                                                                                           \n",
-       "    393 │   │   │   │   retain_graph=retain_graph,                                         \n",
-       "    394 │   │   │   │   create_graph=create_graph,                                         \n",
-       "    395 │   │   │   │   inputs=inputs)                                                     \n",
-       "  396 │   │   torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs \n",
-       "    397 │                                                                                  \n",
-       "    398 │   def register_hook(self, hook):                                                 \n",
-       "    399 │   │   r\"\"\"Registers a backward hook.                                             \n",
-       "                                                                                           \n",
-       " /home/TorchOpt/Miniconda3/envs/torchopt/lib/python3.8/site-packages/torch/autograd/__init \n",
-       " __.py:173 in backward                                                                     \n",
-       "                                                                                           \n",
-       "   170 │   # The reason we repeat same the comment below is that                           \n",
-       "   171 │   # some Python versions print out the first line of a multi-line function        \n",
-       "   172 │   # calls in the traceback and some print out the last line                       \n",
-       " 173 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run th \n",
-       "   174 │   │   tensors, grad_tensors_, retain_graph, create_graph, inputs,                 \n",
-       "   175 │   │   allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine  \n",
-       "   176                                                                                     \n",
-       "╰───────────────────────────────────────────────────────────────────────────────────────────╯\n",
-       "RuntimeError: Trying to backward through the graph a second time (or directly access saved \n",
-       "tensors after they have already been freed). Saved intermediate values of the graph are freed\n",
-       "when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to \n",
-       "backward through the graph a second time or if you need to access saved tensors after calling\n",
-       "backward.\n",
+       "
╭─────────────────────────────────────── Traceback (most recent call last) ───────────────────────────────────────╮\n",
+       " /tmp/ipykernel_3962266/4178930003.py:21 in <module>                                                             \n",
+       "                                                                                                                 \n",
+       " [Errno 2] No such file or directory: '/tmp/ipykernel_3962266/4178930003.py'                                     \n",
+       "                                                                                                                 \n",
+       " /home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/torch/_tensor.py:487 in backward           \n",
+       "                                                                                                                 \n",
+       "    484 │   │   │   │   create_graph=create_graph,                                                               \n",
+       "    485 │   │   │   │   inputs=inputs,                                                                           \n",
+       "    486 │   │   │   )                                                                                            \n",
+       "  487 │   │   torch.autograd.backward(                                                                         \n",
+       "    488 │   │   │   self, gradient, retain_graph, create_graph, inputs=inputs                                    \n",
+       "    489 │   │   )                                                                                                \n",
+       "    490                                                                                                          \n",
+       "                                                                                                                 \n",
+       " ╭───────────────────────── locals ──────────────────────────╮                                                   \n",
+       "  create_graph = False                                                                                         \n",
+       "      gradient = None                                                                                          \n",
+       "        inputs = None                                                                                          \n",
+       "  retain_graph = None                                                                                          \n",
+       "          self = tensor(0.1203, grad_fn=<MseLossBackward0>)                                                    \n",
+       " ╰───────────────────────────────────────────────────────────╯                                                   \n",
+       "                                                                                                                 \n",
+       " /home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/torch/autograd/__init__.py:197 in backward \n",
+       "                                                                                                                 \n",
+       "   194 │   # The reason we repeat same the comment below is that                                                 \n",
+       "   195 │   # some Python versions print out the first line of a multi-line function                              \n",
+       "   196 │   # calls in the traceback and some print out the last line                                             \n",
+       " 197 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the ba                   \n",
+       "   198 │   │   tensors, grad_tensors_, retain_graph, create_graph, inputs,                                       \n",
+       "   199 │   │   allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to r                   \n",
+       "   200                                                                                                           \n",
+       "                                                                                                                 \n",
+       " ╭──────────────────────────── locals ────────────────────────────╮                                              \n",
+       "    create_graph = False                                                                                       \n",
+       "    grad_tensors = None                                                                                        \n",
+       "   grad_tensors_ = (tensor(1.),)                                                                               \n",
+       "  grad_variables = None                                                                                        \n",
+       "          inputs = ()                                                                                          \n",
+       "    retain_graph = False                                                                                       \n",
+       "         tensors = (tensor(0.1203, grad_fn=<MseLossBackward0>),)                                               \n",
+       " ╰────────────────────────────────────────────────────────────────╯                                              \n",
+       "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
+       "RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have \n",
+       "already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad().\n",
+       "Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved \n",
+       "tensors after calling backward.\n",
        "
\n" ], "text/plain": [ - "\u001b[91m╭─\u001b[0m\u001b[91m─────────────────────────── \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[91m ───────────────────────────\u001b[0m\u001b[91m─╮\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[33m\u001b[0m:\u001b[94m17\u001b[0m in \u001b[92m\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2;33m/home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.8/site-packages/torch/\u001b[0m\u001b[1;33m_tensor.py\u001b[0m:\u001b[94m396\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m in \u001b[92mbackward\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m 393 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mretain_graph=retain_graph, \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m 394 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mcreate_graph=create_graph, \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m 395 \u001b[0m\u001b[2m│ │ │ │ \u001b[0minputs=inputs) \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[31m❱ \u001b[0m 396 \u001b[2m│ │ \u001b[0mtorch.autograd.backward(\u001b[96mself\u001b[0m, gradient, retain_graph, create_graph, inputs \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m 397 \u001b[0m\u001b[2m│ \u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m 398 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mregister_hook\u001b[0m(\u001b[96mself\u001b[0m, hook): \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m 399 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[33mr\u001b[0m\u001b[33m\"\"\"Registers a backward hook.\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2;33m/home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.8/site-packages/torch/autograd/\u001b[0m\u001b[1;33m__ini\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[1;33mt__.py\u001b[0m:\u001b[94m173\u001b[0m in \u001b[92mbackward\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m170 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# The reason we repeat same the comment below is that\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m171 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# some Python versions print out the first line of a multi-line function\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m172 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# calls in the traceback and some print out the last line\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[31m❱ \u001b[0m173 \u001b[2m│ \u001b[0mVariable._execution_engine.run_backward( \u001b[2m# Calls into the C++ engine to run th\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m174 \u001b[0m\u001b[2m│ │ \u001b[0mtensors, grad_tensors_, retain_graph, create_graph, inputs, \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m175 \u001b[0m\u001b[2m│ │ \u001b[0mallow_unreachable=\u001b[94mTrue\u001b[0m, accumulate_grad=\u001b[94mTrue\u001b[0m) \u001b[2m# Calls into the C++ engine \u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m176 \u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m╰───────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", - "\u001b[1;91mRuntimeError: \u001b[0mTrying to backward through the graph a second time \u001b[1m(\u001b[0mor directly access saved \n", - "tensors after they have already been freed\u001b[1m)\u001b[0m. Saved intermediate values of the graph are freed\n", - "when you call \u001b[1;35m.backward\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m or \u001b[1;35mautograd.grad\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m. Specify \u001b[33mretain_graph\u001b[0m=\u001b[3;92mTrue\u001b[0m if you need to \n", - "backward through the graph a second time or if you need to access saved tensors after calling\n", - "backward.\n" + "\u001b[31m╭─\u001b[0m\u001b[31m────────────────────────────────────── \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m ──────────────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/tmp/ipykernel_3962266/\u001b[0m\u001b[1;33m4178930003.py\u001b[0m:\u001b[94m21\u001b[0m in \u001b[92m\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[3;31m[Errno 2] No such file or directory: '/tmp/ipykernel_3962266/4178930003.py'\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/torch/\u001b[0m\u001b[1;33m_tensor.py\u001b[0m:\u001b[94m487\u001b[0m in \u001b[92mbackward\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 484 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mcreate_graph=create_graph, \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 485 \u001b[0m\u001b[2m│ │ │ │ \u001b[0minputs=inputs, \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 486 \u001b[0m\u001b[2m│ │ │ \u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 487 \u001b[2m│ │ \u001b[0mtorch.autograd.backward( \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 488 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[96mself\u001b[0m, gradient, retain_graph, create_graph, inputs=inputs \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 489 \u001b[0m\u001b[2m│ │ \u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 490 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m╭─\u001b[0m\u001b[33m──────────────────────── locals ─────────────────────────\u001b[0m\u001b[33m─╮\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m create_graph = \u001b[94mFalse\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m gradient = \u001b[94mNone\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m inputs = \u001b[94mNone\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m retain_graph = \u001b[94mNone\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m self = \u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[94m0.1203\u001b[0m, \u001b[33mgrad_fn\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mMseLossBackward0\u001b[0m\u001b[1m>\u001b[0m\u001b[1m)\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m╰───────────────────────────────────────────────────────────╯\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/torch/autograd/\u001b[0m\u001b[1;33m__init__.py\u001b[0m:\u001b[94m197\u001b[0m in \u001b[92mbackward\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m194 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# The reason we repeat same the comment below is that\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m195 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# some Python versions print out the first line of a multi-line function\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m196 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# calls in the traceback and some print out the last line\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m197 \u001b[2m│ \u001b[0mVariable._execution_engine.run_backward( \u001b[2m# Calls into the C++ engine to run the ba\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m198 \u001b[0m\u001b[2m│ │ \u001b[0mtensors, grad_tensors_, retain_graph, create_graph, inputs, \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m199 \u001b[0m\u001b[2m│ │ \u001b[0mallow_unreachable=\u001b[94mTrue\u001b[0m, accumulate_grad=\u001b[94mTrue\u001b[0m) \u001b[2m# Calls into the C++ engine to r\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m200 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m╭─\u001b[0m\u001b[33m─────────────────────────── locals ───────────────────────────\u001b[0m\u001b[33m─╮\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m create_graph = \u001b[94mFalse\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m grad_tensors = \u001b[94mNone\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m grad_tensors_ = \u001b[1m(\u001b[0m\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[94m1\u001b[0m.\u001b[1m)\u001b[0m,\u001b[1m)\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m grad_variables = \u001b[94mNone\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m inputs = \u001b[1m(\u001b[0m\u001b[1m)\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m retain_graph = \u001b[94mFalse\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m tensors = \u001b[1m(\u001b[0m\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[94m0.1203\u001b[0m, \u001b[33mgrad_fn\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mMseLossBackward0\u001b[0m\u001b[1m>\u001b[0m\u001b[1m)\u001b[0m,\u001b[1m)\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m╰────────────────────────────────────────────────────────────────╯\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", + "\u001b[1;91mRuntimeError: \u001b[0mTrying to backward through the graph a second time \u001b[1m(\u001b[0mor directly access saved tensors after they have \n", + "already been freed\u001b[1m)\u001b[0m. Saved intermediate values of the graph are freed when you call \u001b[1;35m.backward\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m or \u001b[1;35mautograd.grad\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m.\n", + "Specify \u001b[33mretain_graph\u001b[0m=\u001b[3;92mTrue\u001b[0m if you need to backward through the graph a second time or if you need to access saved \n", + "tensors after calling backward.\n" ] }, "metadata": {}, @@ -351,7 +388,11 @@ "display(\n", " torchopt.visual.make_dot(\n", " outer_loss,\n", - " params=(init_net_state, one_step_net_state, {'meta_parameter': meta_parameter, 'outer_loss': outer_loss})\n", + " params=(\n", + " init_net_state,\n", + " one_step_net_state,\n", + " {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n", + " ),\n", " )\n", ")\n", "\n", @@ -397,14 +438,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "meta_parameter.grad = tensor(-0.0914)\n", - "meta_parameter = Parameter containing: tensor(1.1887, requires_grad=True)\n", - "\n" + "meta_parameter.grad = tensor(-0.0635)\n", + "meta_parameter = Parameter containing:\n", + "tensor(1.1940, requires_grad=True)\n", + "\n" ] }, { "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n139975938621248\n\nouter_loss\n ()\n\n\n\n139975251126352\n\nMseLossBackward0\n\n\n\n139975251126352->139975938621248\n\n\n\n\n\n139975251126592\n\nAddmmBackward0\n\n\n\n139975251126592->139975251126352\n\n\n\n\n\n139975251125920\n\nAddBackward0\n\n\n\n139975251125920->139975251126592\n\n\n\n\n\n139975251126400\n\nAccumulateGrad\n\n\n\n139975251126400->139975251125920\n\n\n\n\n\n139975251127120\n\nAddmmBackward0\n\n\n\n139975251126400->139975251127120\n\n\n\n\n\n139975938636032\n\nstep1.detached.fc.bias\n (1)\n\n\n\n139975938636032->139975251126400\n\n\n\n\n\n139975251126304\n\nMulBackward0\n\n\n\n139975251126304->139975251125920\n\n\n\n\n\n139975251127072\n\nViewBackward0\n\n\n\n139975251127072->139975251126304\n\n\n\n\n\n139975251128080\n\nSumBackward1\n\n\n\n139975251128080->139975251127072\n\n\n\n\n\n139975251126448\n\nMseLossBackwardBackward0\n\n\n\n139975251126448->139975251128080\n\n\n\n\n\n139975251127456\n\nTBackward0\n\n\n\n139975251126448->139975251127456\n\n\n\n\n\n139975251127312\n\nMulBackward0\n\n\n\n139975251127312->139975251126448\n\n\n\n\n\n139975251126016\n\nAccumulateGrad\n\n\n\n139975251126016->139975251127312\n\n\n\n\n\n139975938635072\n\nmeta_parameter\n ()\n\n\n\n139975938635072->139975251126016\n\n\n\n\n\n139975251127120->139975251126448\n\n\n\n\n\n139975251126880\n\nTBackward0\n\n\n\n139975251126880->139975251127120\n\n\n\n\n\n139975251126544\n\nAccumulateGrad\n\n\n\n139975251126544->139975251126880\n\n\n\n\n\n139975251128272\n\nAddBackward0\n\n\n\n139975251126544->139975251128272\n\n\n\n\n\n139975938635552\n\nstep1.detached.fc.weight\n (1, 16)\n\n\n\n139975938635552->139975251126544\n\n\n\n\n\n139975251126256\n\nTBackward0\n\n\n\n139975251126256->139975251126592\n\n\n\n\n\n139975251128272->139975251126256\n\n\n\n\n\n139975251127744\n\nMulBackward0\n\n\n\n139975251127744->139975251128272\n\n\n\n\n\n139975251126112\n\nTBackward0\n\n\n\n139975251126112->139975251127744\n\n\n\n\n\n139975251126640\n\nTBackward0\n\n\n\n139975251126640->139975251126112\n\n\n\n\n\n139975251126976\n\nMmBackward0\n\n\n\n139975251126976->139975251126640\n\n\n\n\n\n139975251127456->139975251126976\n\n\n\n\n\n" + "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140024973754912\n\nouter_loss\n()\n\n\n\n140024956770528\n\nMseLossBackward0\n\n\n\n140024956770528->140024973754912\n\n\n\n\n\n140024956772112\n\nAddmmBackward0\n\n\n\n140024956772112->140024956770528\n\n\n\n\n\n140024956770720\n\nAddBackward0\n\n\n\n140024956770720->140024956772112\n\n\n\n\n\n140024962101312\n\nAccumulateGrad\n\n\n\n140024962101312->140024956770720\n\n\n\n\n\n140024973745552\n\nAddmmBackward0\n\n\n\n140024962101312->140024973745552\n\n\n\n\n\n140025091547520\n\nstep1.detached.fc.bias\n(1)\n\n\n\n140025091547520->140024962101312\n\n\n\n\n\n140024971586864\n\nMulBackward0\n\n\n\n140024971586864->140024956770720\n\n\n\n\n\n140024973742528\n\nViewBackward0\n\n\n\n140024973742528->140024971586864\n\n\n\n\n\n140024973743968\n\nSumBackward1\n\n\n\n140024973743968->140024973742528\n\n\n\n\n\n140024973742768\n\nMseLossBackwardBackward0\n\n\n\n140024973742768->140024973743968\n\n\n\n\n\n140024973744400\n\nTBackward0\n\n\n\n140024973742768->140024973744400\n\n\n\n\n\n140024973744688\n\nMulBackward0\n\n\n\n140024973744688->140024973742768\n\n\n\n\n\n140024973745264\n\nAccumulateGrad\n\n\n\n140024973745264->140024973744688\n\n\n\n\n\n140025091549440\n\nmeta_parameter\n()\n\n\n\n140025091549440->140024973745264\n\n\n\n\n\n140024973745552->140024973742768\n\n\n\n\n\n140024973745168\n\nTBackward0\n\n\n\n140024973745168->140024973745552\n\n\n\n\n\n140024973744256\n\nAccumulateGrad\n\n\n\n140024973744256->140024973745168\n\n\n\n\n\n140024973745984\n\nAddBackward0\n\n\n\n140024973744256->140024973745984\n\n\n\n\n\n140027828983424\n\nstep1.detached.fc.weight\n(1, 16)\n\n\n\n140027828983424->140024973744256\n\n\n\n\n\n140024956771632\n\nTBackward0\n\n\n\n140024956771632->140024956772112\n\n\n\n\n\n140024973745984->140024956771632\n\n\n\n\n\n140024973743728\n\nMulBackward0\n\n\n\n140024973743728->140024973745984\n\n\n\n\n\n140024973743344\n\nTBackward0\n\n\n\n140024973743344->140024973743728\n\n\n\n\n\n140024973745312\n\nTBackward0\n\n\n\n140024973745312->140024973743344\n\n\n\n\n\n140024973743200\n\nMmBackward0\n\n\n\n140024973743200->140024973745312\n\n\n\n\n\n140024973744400->140024973743200\n\n\n\n\n\n" }, "metadata": {}, "output_type": "display_data" @@ -414,7 +456,9 @@ "# Stop gradient and make them become the leaf node\n", "torchopt.stop_gradient(net)\n", "torchopt.stop_gradient(optim)\n", - "one_step_net_state_detached = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.detached.')\n", + "one_step_net_state_detached = torchopt.extract_state_dict(\n", + " net, enable_visual=True, visual_prefix='step1.detached.'\n", + ")\n", "\n", "# Inner update\n", "inner_loss = loss_fn(net(x), y)\n", @@ -432,7 +476,10 @@ "display(\n", " torchopt.visual.make_dot(\n", " outer_loss,\n", - " params=(one_step_net_state_detached, {'meta_parameter': meta_parameter, 'outer_loss': outer_loss})\n", + " params=(\n", + " one_step_net_state_detached,\n", + " {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n", + " ),\n", " )\n", ")" ] @@ -447,7 +494,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.13 ('torchopt')", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -461,7 +508,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.9.15" }, "vscode": { "interpreter": { diff --git a/tutorials/5_Implicit_Differentiation.ipynb b/tutorials/5_Implicit_Differentiation.ipynb new file mode 100644 index 00000000..23407801 --- /dev/null +++ b/tutorials/5_Implicit_Differentiation.ipynb @@ -0,0 +1,576 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8850c832-3b54-4971-8ee0-2cd64b585ea8", + "metadata": {}, + "source": [ + "# TorchOpt for Implicit Differentiation" + ] + }, + { + "cell_type": "markdown", + "id": "2b547376", + "metadata": {}, + "source": [ + "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/5_Implicit_Differentiation.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "8d7f9865-dc02-43d4-be90-da1160c4e4dd", + "metadata": {}, + "source": [ + "By treating the solution $\\phi^{\\star}$ as an implicit function of $\\theta$, the idea of implicit differentiation is to directly get analytical best-response derivatives $\\partial \\phi^{\\star}(\\theta)/ \\partial \\theta$ by implicit function theorem. This is suitable for algorithms when the inner-level optimal solution is achieved ${\\left. \\frac{\\partial F (\\phi, \\theta)}{\\partial \\phi} \\right\\rvert}_{\\phi = \\phi^{\\star}} = 0$ or reaches some stationary conditions $F (\\phi^{\\star}, \\theta) = 0$, such as [iMAML](https://arxiv.org/abs/1909.04630) and [DEQ](https://arxiv.org/abs/1909.01377)." + ] + }, + { + "cell_type": "markdown", + "id": "d7e4b9e1-115f-45ad-a9b3-ea338bcfe6dd", + "metadata": {}, + "source": [ + "In this tutorial, we will introduce how TorchOpt can be used to conduct implicit differentiation." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "8f13ae67-e328-409f-84a8-1fc425c03a66", + "metadata": {}, + "outputs": [], + "source": [ + "import functorch\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import torchopt" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "0cdaac49-4b94-4900-9bb5-a39057ac8b21", + "metadata": {}, + "source": [ + "## 1. Functional API\n", + "\n", + "The basic functional API is `torchopt.diff.implicit.custom_root`, which is used as the decorator for the forward process implicit gradient procedures. Users are required to implement the stationary conditions for the inner-loop process, which will be used as the input of custom_root decorator. We show the pseudo code in the following part.\n", + "\n", + "```python\n", + "# Functional API for implicit gradient\n", + "def stationary(params, meta_params, data):\n", + " # stationary condition construction\n", + " return stationary condition\n", + "\n", + "# Decorator that wraps the function\n", + "# Optionally specify the linear solver (conjugate gradient or Neumann series)\n", + "@torchopt.diff.implicit.custom_root(stationary, solve=linear_solver)\n", + "def solve(params, meta_params, data):\n", + " # Forward optimization process for params\n", + " return optimal_params\n", + "\n", + "# Define params, meta_params and get data\n", + "params, meta_prams, data = ..., ..., ...\n", + "optimal_params = solve(params, meta_params, data)\n", + "loss = outer_loss(optimal_params)\n", + "\n", + "meta_grads = torch.autograd.grad(loss, meta_params)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "dbef87df-2164-4f1d-8919-37a6fbdc5011", + "metadata": {}, + "source": [ + "Here we use the example of [iMAML](https://arxiv.org/abs/1909.04630) as a real example. For iMAML, the inner-loop objective is described by the following equation.\n", + "\n", + "$$\n", + "{\\mathcal{Alg}}^{\\star} \\left( \\boldsymbol{\\theta}, \\mathcal{D}_{i}^{\\text{tr}} \\right) = \\underset{\\phi'}{\\operatorname{\\arg \\min}} ~ G \\left( \\boldsymbol{\\phi}', \\boldsymbol{\\theta} \\right) \\triangleq \\mathcal{L} \\left( \\boldsymbol{\\phi}', \\mathcal{D}_{i}^{\\text{tr}} \\right) + \\frac{\\lambda}{2} {\\left\\| \\boldsymbol{\\phi}' - \\boldsymbol{\\theta} \\right\\|}^{2}\n", + "$$\n", + "\n", + "According to this function, we can define the forward function `inner_solver`, where we solve this equation based on sufficient gradient descents. For such inner-loop process, the optimality condition is that the gradient w.r.t inner-loop parameter is $0$.\n", + "\n", + "$$\n", + "{\\left. \\nabla_{\\boldsymbol{\\phi}'} G \\left( \\boldsymbol{\\phi}', \\boldsymbol{\\theta} \\right) \\right\\rvert}_{\\boldsymbol{\\phi}' = \\boldsymbol{\\phi}^{\\star}} = 0\n", + "$$\n", + "\n", + "Thus we can define the optimality function by defining `imaml_objective` and make it first-order gradient w.r.t the inner-loop parameter as $0$. We achieve so by calling out `functorch.grad(imaml_objective, argnums=0)`. Finally, the forward function is decorated by the `@torchopt.diff.implicit.custom_root` decorator and the optimality condition we define." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8d623b2f-48ee-4df6-a2ce-cf306b4c9067", + "metadata": {}, + "outputs": [], + "source": [ + "# Inner-loop objective function\n", + "# The optimality function: grad(imaml_objective)\n", + "def imaml_objective(params, meta_params, data):\n", + " x, y, fmodel = data\n", + " y_pred = fmodel(params, x)\n", + " regularization_loss = 0.0\n", + " for p1, p2 in zip(params, meta_params):\n", + " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", + " loss = F.mse_loss(y_pred, y) + regularization_loss\n", + " return loss\n", + "\n", + "\n", + "# Optimality Condition is: the gradient w.r.t inner-loop optimal params is 0 (we achieve so by\n", + "# specifying argnums=0 in functorch.grad) the argnums=1 specify which meta-parameter we want to\n", + "# backpropogate, in this case we want to backpropogate to the initial parameters so we set it as 1.\n", + "# You can also set argnums as (1, 2) if you want to backpropogate through multiple meta-parameters\n", + "\n", + "\n", + "# Here we pass argnums=1 to the custom_root. That means we want to compute the gradient of\n", + "# optimal_params w.r.t. the 1-indexed argument in inner_solver, i.e., params.\n", + "# torchopt.linear_solve.solve_normal_cg specify that we use the conjugate gradient based linear solver\n", + "@torchopt.diff.implicit.custom_root(\n", + " functorch.grad(imaml_objective, argnums=0), # optimality function\n", + " argnums=1,\n", + " solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", + ")\n", + "def inner_solver(params, meta_params, data):\n", + " # Initial functional optimizer based on TorchOpt\n", + " x, y, fmodel = data\n", + " optimizer = torchopt.sgd(lr=2e-2)\n", + " opt_state = optimizer.init(params)\n", + " with torch.enable_grad():\n", + " # Temporarily enable gradient computation for conducting the optimization\n", + " for i in range(100):\n", + " pred = fmodel(params, x)\n", + " loss = F.mse_loss(pred, y) # compute loss\n", + "\n", + " # Compute regularization loss\n", + " regularization_loss = 0.0\n", + " for p1, p2 in zip(params, meta_params):\n", + " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", + " final_loss = loss + regularization_loss\n", + "\n", + " grads = torch.autograd.grad(final_loss, params) # compute gradients\n", + " updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates\n", + " params = torchopt.apply_updates(params, updates, inplace=True)\n", + "\n", + " optimal_params = params\n", + " return optimal_params\n", + "\n", + "\n", + "# torchopt.linear_solve.solve_inv specify that we use the Neumann Series inversion linear solver\n", + "@torchopt.diff.implicit.custom_root(\n", + " functorch.grad(imaml_objective, argnums=0), # optimality function\n", + " argnums=1,\n", + " solve=torchopt.linear_solve.solve_inv(ns=True, maxiter=100, alpha=0.1),\n", + ")\n", + "def inner_solver_inv_ns(params, meta_params, data):\n", + " # Initial functional optimizer based on TorchOpt\n", + " x, y, fmodel = data\n", + " optimizer = torchopt.sgd(lr=2e-2)\n", + " opt_state = optimizer.init(params)\n", + " with torch.enable_grad():\n", + " # Temporarily enable gradient computation for conducting the optimization\n", + " for i in range(100):\n", + " pred = fmodel(params, x)\n", + " loss = F.mse_loss(pred, y) # compute loss\n", + "\n", + " # Compute regularization loss\n", + " regularization_loss = 0.0\n", + " for p1, p2 in zip(params, meta_params):\n", + " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", + " final_loss = loss + regularization_loss\n", + "\n", + " grads = torch.autograd.grad(final_loss, params) # compute gradients\n", + " updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates\n", + " params = torchopt.apply_updates(params, updates, inplace=True)\n", + "\n", + " optimal_params = params\n", + " return optimal_params" + ] + }, + { + "cell_type": "markdown", + "id": "32a75c81-d479-4120-a73d-5b2b488358d0", + "metadata": {}, + "source": [ + "In the next step, we consider a specific case for one layer neural network to fit the linear data." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "fb95538b-1fd9-4ec8-9f57-6360bedc05b7", + "metadata": {}, + "outputs": [], + "source": [ + "torch.manual_seed(0)\n", + "x = torch.randn(20, 4)\n", + "w = torch.randn(4, 1)\n", + "b = torch.randn(1)\n", + "y = x @ w + b + 0.5 * torch.randn(20, 1)" + ] + }, + { + "cell_type": "markdown", + "id": "eeb1823a-2231-4471-bb68-cce7724f2578", + "metadata": {}, + "source": [ + "We instantiate an one layer neural network, where the weights and bias are initialized with constant." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d50a7bfe-ac69-4089-8cf8-3cbd69d6d4e7", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "class Net(nn.Module):\n", + " def __init__(self, dim):\n", + " super().__init__()\n", + " self.fc = nn.Linear(dim, 1, bias=True)\n", + " nn.init.ones_(self.fc.weight)\n", + " nn.init.zeros_(self.fc.bias)\n", + "\n", + " def forward(self, x):\n", + " return self.fc(x)\n", + "\n", + "\n", + "model = Net(4)\n", + "fmodel, meta_params = functorch.make_functional(model)\n", + "data = (x, y, fmodel)\n", + "\n", + "\n", + "# Clone function for parameters\n", + "def clone(params):\n", + " cloned = []\n", + " for item in params:\n", + " if isinstance(item, torch.Tensor):\n", + " cloned.append(item.clone().detach_().requires_grad_(True))\n", + " else:\n", + " cloned.append(item)\n", + " return tuple(cloned)" + ] + }, + { + "cell_type": "markdown", + "id": "065c36c4-89e2-4a63-8213-63db6ee3b08e", + "metadata": {}, + "source": [ + "We take the forward process by calling out the forward function, then we pass the optimal params into the outer-loop loss function." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "115e79c6-911f-4743-a2ed-e50a71c3a813", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "optimal_params = inner_solver(clone(meta_params), meta_params, data)\n", + "\n", + "outer_loss = fmodel(optimal_params, x).mean()" + ] + }, + { + "cell_type": "markdown", + "id": "e2812351-f635-496e-9732-c80831ac04a6", + "metadata": {}, + "source": [ + "Finally, we can get the meta-gradient as shown below." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6bdcbe8d-2336-4f80-b124-eb43c5a2fc0a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" + ] + } + ], + "source": [ + "torch.autograd.grad(outer_loss, meta_params)" + ] + }, + { + "cell_type": "markdown", + "id": "926ae8bb", + "metadata": {}, + "source": [ + "Also we can switch to the Neumann Series inversion linear solver." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "43df0374", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" + ] + } + ], + "source": [ + "optimal_params = inner_solver_inv_ns(clone(meta_params), meta_params, data)\n", + "outer_loss = fmodel(optimal_params, x).mean()\n", + "torch.autograd.grad(outer_loss, meta_params)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "c92e67ea-b220-4a14-a1ea-4eb3c5f52b6b", + "metadata": {}, + "source": [ + "## 2. OOP API\n", + "\n", + "The basic OOP class is the class `ImplicitMetaGradientModule`. We make the network as an `nn.Module` following a classical PyTorch style. Users need to define the stationary condition/objective function and the inner-loop solve function to enable implicit gradient computation. We show the pseudo code in the following part.\n", + "\n", + "```python\n", + "from torchopt.nn import ImplicitMetaGradientModule\n", + "\n", + "# Inherited from the class ImplicitMetaGradientModule\n", + "# Optionally specify the linear solver (conjugate gradient or Neumann series)\n", + "class InnerNet(ImplicitMetaGradientModule, linear_solve=linear_solver):\n", + " def __init__(self, meta_module):\n", + " ...\n", + "\n", + " def forward(self, batch):\n", + " # Forward process\n", + " ...\n", + "\n", + " def optimality(self, batch, labels):\n", + " # Stationary condition construction for calculating implicit gradient\n", + " # NOTE: If this method is not implemented, it will be automatically derived from the\n", + " # gradient of the `objective` function.\n", + " ...\n", + "\n", + " def objective(self, batch, labels):\n", + " # Define the inner-loop optimization objective\n", + " # NOTE: This method is optional if method `optimality` is implemented.\n", + " ...\n", + "\n", + " def solve(self, batch, labels):\n", + " # Conduct the inner-loop optimization\n", + " ...\n", + " return self # optimized module\n", + "\n", + "# Get meta_params and data\n", + "meta_params, data = ..., ...\n", + "inner_net = InnerNet()\n", + "\n", + "# Solve for inner-loop process related to the meta-parameters\n", + "optimal_inner_net = inner_net.solve(meta_params, *data)\n", + "\n", + "# Get outer-loss and solve for meta-gradient\n", + "loss = outer_loss(optimal_inner_net)\n", + "meta_grad = torch.autograd.grad(loss, meta_params)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "62fbe520-11d0-41ff-9b0a-c6508b1d01cf", + "metadata": {}, + "source": [ + "The class `ImplicitMetaGradientModule` is to enable the gradient flow from `self.parameters()` to `self.meta_parameters()`. In `__init__` function, users need to define the inner parameters and meta-parameters. By default, `ImplicitMetaGradientModule` treats all tensors and modules from input as `self.meta_parameters()`, and all tensors and modules defined in the `__init__` are regarded as `self.parameters()`. Users can also register `self.parameters()` and `self.meta_parameters()` by calling `self.register_parameter()` and `self.register_meta_parameter()` respectively." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c3999684-f4d3-4bc0-86ab-a7e803b2fe80", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" + ] + } + ], + "source": [ + "class InnerNet(\n", + " torchopt.nn.ImplicitMetaGradientModule,\n", + " linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", + "):\n", + " def __init__(self, meta_net, n_inner_iter, reg_param):\n", + " super().__init__()\n", + " # Declaration of the meta-parameter\n", + " self.meta_net = meta_net\n", + " # Get a deepcopy, register inner-parameter\n", + " self.net = torchopt.module_clone(meta_net, by='deepcopy', detach_buffers=True)\n", + " self.n_inner_iter = n_inner_iter\n", + " self.reg_param = reg_param\n", + "\n", + " def forward(self, x):\n", + " return self.net(x)\n", + "\n", + " def objective(self, x, y):\n", + " # We do not implement the optimality conditions, so it will be automatically derived from\n", + " # the gradient of the `objective` function.\n", + " y_pred = self(x)\n", + " loss = F.mse_loss(y_pred, y)\n", + " regularization_loss = 0\n", + " for p1, p2 in zip(\n", + " self.parameters(), # parameters of `self.net`\n", + " self.meta_parameters(), # parameters of `self.meta_net`\n", + " ):\n", + " regularization_loss += (\n", + " 0.5 * self.reg_param * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", + " )\n", + " return loss + regularization_loss\n", + "\n", + " def solve(self, x, y):\n", + " params = tuple(self.parameters())\n", + " inner_optim = torchopt.SGD(params, lr=2e-2)\n", + " with torch.enable_grad():\n", + " # Temporarily enable gradient computation for conducting the optimization\n", + " for _ in range(self.n_inner_iter):\n", + " loss = self.objective(x, y)\n", + " inner_optim.zero_grad()\n", + " # NOTE: The parameter inputs should be explicitly specified in `backward` function\n", + " # as argument `inputs`. Otherwise, if not provided, the gradient is accumulated into\n", + " # all the leaf Tensors (including the meta-parameters) that were used to compute the\n", + " # objective output. Alternatively, please use `torch.autograd.grad` instead.\n", + " loss.backward(inputs=params) # backward pass in inner-loop\n", + " inner_optim.step() # update inner parameters\n", + " return self\n", + "\n", + "\n", + "# Initialize the meta-network\n", + "meta_net = Net(4)\n", + "inner_net = InnerNet(meta_net, 100, reg_param=1)\n", + "\n", + "# Solve for inner-loop\n", + "optimal_inner_net = inner_net.solve(x, y)\n", + "outer_loss = optimal_inner_net(x).mean()\n", + "\n", + "# Derive the meta-gradient\n", + "torch.autograd.grad(outer_loss, meta_net.parameters())" + ] + }, + { + "cell_type": "markdown", + "id": "2b69a5d6-b5e4-4f08-af0a-40afc2382b45", + "metadata": {}, + "source": [ + "We also show an example on how to implement implicit gradient calculation when the inner-level optimal solution reaches some stationary conditions $F (\\phi^{\\star}, \\theta) = 0$, such as [DEQ](https://arxiv.org/abs/1909.01377), based on the OOP API. " + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "de87c308-d847-4491-9aa1-bc393e6dd1d8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(\n", + "│ tensor([[ 0.0272, 0.0031, -0.0156, -0.0238],\n", + "│ │ [ 0.1004, 0.0113, -0.0573, -0.0878],\n", + "│ │ [ 0.0666, 0.0075, -0.0380, -0.0583],\n", + "│ │ [ 0.1446, 0.0163, -0.0826, -0.1265]]),\n", + "│ tensor([0.0574, 0.2114, 0.1403, 0.3046])\n", + ")\n" + ] + } + ], + "source": [ + "class Net(nn.Module):\n", + " def __init__(self, dim):\n", + " super().__init__()\n", + " self.fc = nn.Linear(dim, dim)\n", + "\n", + " def forward(self, x):\n", + " return self.fc(x)\n", + "\n", + "\n", + "class InnerNet(\n", + " torchopt.nn.ImplicitMetaGradientModule,\n", + " linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", + "):\n", + " def __init__(self, meta_net, x0):\n", + " super().__init__()\n", + " # Register meta-parameter\n", + " self.meta_net = meta_net\n", + " # Declaration of the inner-parameter, register inner-parameter\n", + " self.x = nn.Parameter(x0.clone().detach_(), requires_grad=True)\n", + "\n", + " def forward(self, x):\n", + " return self.meta_net(x)\n", + "\n", + " def optimality(self):\n", + " # Fixed-point condition\n", + " return (self.x - self(self.x),)\n", + "\n", + " def solve(self):\n", + " # Solving inner-loop fixed-point iteration\n", + " # This is just an illustrating example for solving fixed-point iteration\n", + " # one can use more advanced method to solve fixed-point iteration\n", + " # such as anderson acceleration.\n", + " for _ in range(10):\n", + " self.x.copy_(self(self.x))\n", + " return self\n", + "\n", + "\n", + "# Initialize meta-network\n", + "torch.manual_seed(0)\n", + "meta_net = Net(4)\n", + "x0 = torch.randn(1, 4)\n", + "inner_net = InnerNet(meta_net, x0)\n", + "\n", + "# Solve for inner-loop\n", + "optimal_inner_net = inner_net.solve()\n", + "outer_loss = optimal_inner_net.x.mean()\n", + "\n", + "# Derive the meta-gradient\n", + "torch.autograd.grad(outer_loss, meta_net.parameters())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + }, + "vscode": { + "interpreter": { + "hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/6_Zero_Order_Differentiation.ipynb b/tutorials/6_Zero_Order_Differentiation.ipynb new file mode 100644 index 00000000..d6cb028c --- /dev/null +++ b/tutorials/6_Zero_Order_Differentiation.ipynb @@ -0,0 +1,356 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8850c832-3b54-4971-8ee0-2cd64b585ea8", + "metadata": {}, + "source": [ + "# TorchOpt for Zero-Order Differentiation" + ] + }, + { + "cell_type": "markdown", + "id": "2b547376", + "metadata": {}, + "source": [ + "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/6_Zero_Order_Differentiation.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "8d7f9865-dc02-43d4-be90-da1160c4e4dd", + "metadata": {}, + "source": [ + "When the inner-loop process is non-differentiable or one wants to eliminate the heavy computation burdens in the previous two modes (brought by Hessian), one can choose ZD. ZD typically gets gradients based on zero-order estimation, such as finite-difference, or Evolutionary Strategy.\n", + "\n", + "TorchOpt offers API for ES-based differentiation. Instead of optimizing the objective $f (\\boldsymbol{\\theta}): \\mathbb{R}^n \\to \\mathbb{R}$, ES optimizes a Gaussion smoothing objective defined as $\\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\mathbb{E}_{\\boldsymbol{z} \\sim \\mathcal{N}( 0, {I}_d )} [ f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) ]$, where $\\sigma$ denotes precision. The gradient of such objective is $\\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\frac{1}{\\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\mathcal{N}( 0, {I}_d )} [ f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) \\cdot \\boldsymbol{z} ]$. Refer to [ES-MAML](https://arxiv.org/pdf/1910.01215.pdf) for more details." + ] + }, + { + "cell_type": "markdown", + "id": "d7e4b9e1-115f-45ad-a9b3-ea338bcfe6dd", + "metadata": {}, + "source": [ + "In this tutorial, we will introduce how TorchOpt can be used to ES-based differentiation." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "8f13ae67-e328-409f-84a8-1fc425c03a66", + "metadata": {}, + "outputs": [], + "source": [ + "import functorch\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import torchopt" + ] + }, + { + "cell_type": "markdown", + "id": "0cdaac49-4b94-4900-9bb5-a39057ac8b21", + "metadata": {}, + "source": [ + "## 1. Functional API\n", + "\n", + "The basic functional API is `torchopt.diff.zero_order.zero_order`, which is used as the decorator for the forward process zero-order gradient procedures. Users are required to implement the noise sampling function, which will be used as the input of zero_order decorator. Here we show the specific meaning for each parameter used in the decorator.\n", + "\n", + "- `distribution` for noise sampling distribution. The distribution $\\lambda$ should be spherical symmetric and with a constant variance of $1$ for each element. I.e.:\n", + "\n", + " - Spherical symmetric: $\\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ \\boldsymbol{z} ] = \\boldsymbol{0}$.\n", + " - Constant variance of $1$ for each element: $\\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ {\\lvert z_i \\rvert}^2 ] = 1$.\n", + " - For example, the standard multi-dimensional normal distribution $\\mathcal{N} (\\boldsymbol{0}, \\boldsymbol{1})$.\n", + "\n", + "- `method` for different kind of algorithms, we support `'naive'` ([ES-RL](https://arxiv.org/abs/1703.03864)), `'forward'` ([Forward-FD](http://proceedings.mlr.press/v80/choromanski18a/choromanski18a.pdf)), and `'antithetic'` ([antithetic](https://arxiv.org/abs/1803.07055)).\n", + "\n", + " $$\n", + " \\begin{align*}\n", + " \\text{naive} \\qquad & \\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\frac{1}{\\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) \\cdot \\boldsymbol{z} ] \\\\\n", + " \\text{forward} \\qquad & \\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\frac{1}{\\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ ( f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) - f (\\boldsymbol{\\theta}) ) \\cdot \\boldsymbol{z} ] \\\\\n", + " \\text{antithetic} \\qquad & \\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\frac{1}{2 \\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ (f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) - f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) ) \\cdot \\boldsymbol{z} ]\n", + " \\end{align*}\n", + " $$\n", + "\n", + "- `argnums` specifies which parameter we want to trace the meta-gradient.\n", + "- `num_samples` specifies how many times we want to conduct the sampling.\n", + "- `sigma` is for precision. This is the scaling factor for the sampling distribution.\n", + "\n", + "We show the pseudo code in the following part.\n", + "\n", + "```python\n", + "# Functional API for zero-order differentiation\n", + "# 1. Customize the noise distribution via a distribution class\n", + "class Distribution:\n", + " def sample(self, sample_shape=torch.Size()):\n", + " # Sampling function for noise\n", + " # NOTE: The distribution should be spherical symmetric and with a constant variance of 1.\n", + " ...\n", + " return noise_batch\n", + "\n", + "distribution = Distribution()\n", + "\n", + "# 2. Customize the noise distribution via a sampling function\n", + "def distribution(sample_shape=torch.Size()):\n", + " # Sampling function for noise\n", + " # NOTE: The distribution should be spherical symmetric and with a constant variance of 1.\n", + " ...\n", + " return noise_batch\n", + "\n", + "# 3. Distribution can also be an instance of `torch.distributions.Distribution`, e.g., `torch.distributions.Normal(...)`\n", + "distribution = torch.distributions.Normal(loc=0, scale=1)\n", + "\n", + "# Decorator that wraps the function\n", + "@torchopt.diff.zero_order(distribution=distribution, method='naive', argnums=0, num_samples=100, sigma=0.01)\n", + "def forward(params, data):\n", + " # Forward optimization process for params\n", + " ...\n", + " return objective # the returned tensor should be a scalar tensor\n", + "\n", + "# Define params and get data\n", + "params, data = ..., ...\n", + "\n", + "# Forward pass\n", + "loss = forward(params, data)\n", + "# Backward pass using zero-order differentiation\n", + "grads = torch.autograd.grad(loss, params)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "dbef87df-2164-4f1d-8919-37a6fbdc5011", + "metadata": {}, + "source": [ + "Here we use the example of a linear layer as an example, note that this is just an example to show linear layer can work with ES." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8d623b2f-48ee-4df6-a2ce-cf306b4c9067", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "001: tensor(0.0265, grad_fn=)\n", + "002: tensor(0.0243, grad_fn=)\n", + "003: tensor(0.0222, grad_fn=)\n", + "004: tensor(0.0202, grad_fn=)\n", + "005: tensor(0.0184, grad_fn=)\n", + "006: tensor(0.0170, grad_fn=)\n", + "007: tensor(0.0157, grad_fn=)\n", + "008: tensor(0.0146, grad_fn=)\n", + "009: tensor(0.0137, grad_fn=)\n", + "010: tensor(0.0130, grad_fn=)\n", + "011: tensor(0.0123, grad_fn=)\n", + "012: tensor(0.0118, grad_fn=)\n", + "013: tensor(0.0114, grad_fn=)\n", + "014: tensor(0.0111, grad_fn=)\n", + "015: tensor(0.0111, grad_fn=)\n", + "016: tensor(0.0111, grad_fn=)\n", + "017: tensor(0.0113, grad_fn=)\n", + "018: tensor(0.0115, grad_fn=)\n", + "019: tensor(0.0118, grad_fn=)\n", + "020: tensor(0.0120, grad_fn=)\n", + "021: tensor(0.0121, grad_fn=)\n", + "022: tensor(0.0121, grad_fn=)\n", + "023: tensor(0.0122, grad_fn=)\n", + "024: tensor(0.0122, grad_fn=)\n", + "025: tensor(0.0122, grad_fn=)\n" + ] + } + ], + "source": [ + "torch.random.manual_seed(0)\n", + "\n", + "fmodel, params = functorch.make_functional(nn.Linear(32, 1))\n", + "x = torch.randn(64, 32) * 0.1\n", + "y = torch.randn(64, 1) * 0.1\n", + "distribution = torch.distributions.Normal(loc=0, scale=1)\n", + "\n", + "\n", + "@torchopt.diff.zero_order(\n", + " distribution=distribution, method='forward', argnums=0, num_samples=100, sigma=0.01\n", + ")\n", + "def forward_process(params, fn, x, y):\n", + " y_pred = fn(params, x)\n", + " loss = F.mse_loss(y_pred, y)\n", + " return loss\n", + "\n", + "\n", + "optimizer = torchopt.adam(lr=0.01)\n", + "opt_state = optimizer.init(params) # init optimizer\n", + "\n", + "for i in range(25):\n", + " loss = forward_process(params, fmodel, x, y) # compute loss\n", + "\n", + " grads = torch.autograd.grad(loss, params) # compute gradients\n", + " updates, opt_state = optimizer.update(grads, opt_state) # get updates\n", + " params = torchopt.apply_updates(params, updates) # update network parameters\n", + "\n", + " print(f'{i + 1:03d}: {loss!r}')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "db723f6b", + "metadata": {}, + "source": [ + "## 2. OOP API\n", + "\n", + "The basic OOP API is the class `ZeroOrderGradientModule`. We make the network as an `nn.Module` following a classical PyTorch style. Users need to define the forward process zero-order gradient procedures `forward()` and a noise sampling function `sample()`. Here we show the specific meaning for each parameter used in the class.\n", + "\n", + "- `method` for different kind of algorithms, we support `'naive'` ([ES-RL](https://arxiv.org/abs/1703.03864)), `'forward'` ([Forward-FD](http://proceedings.mlr.press/v80/choromanski18a/choromanski18a.pdf)), and `'antithetic'` ([antithetic](https://d1wqtxts1xzle7.cloudfront.net/75609515/coredp2011_1web-with-cover-page-v2.pdf?Expires=1670215467&Signature=RfP~mQhhhI7aGknwXbRBgSggFrKuNTPYdyUSdMmfTxOa62QoOJAm-Xhr3F1PLyjUQc2JVxmKIKGGuyYvyfCTpB31dfmMtuVQxZMWVF-SfErTN05SliC93yjA1x1g2kjhn8bkBFdQqGl~1RQSKnhj88BakgSeDNzyCxwbD5VgR89BXRs4YIK5RBIKYtgLhoyz5jar7wHS3TJhRzs3WNeTIAjAmLqJ068oGFZ0Jr7maGquTe3w~8LEEIprJ6cyCMc6b1UUJkmwjNq0RLTVbxgFjfi4Z9kyxyJB9IOS1J25OOON4jfwh5JlXS7MVskuONUyHJim1TQ8OwCraKlBsQLPQw__&Key-Pair-Id=APKAJLOHF5GGSLRBV4ZA)).\n", + "- `num_samples` specifies how many times we want to conduct the sampling.\n", + "- `sigma` is for precision. This is the scaling factor for the sampling distribution.\n", + "\n", + "We show the pseudo code in the following part.\n", + "\n", + "```python\n", + "from torchopt.nn import ZeroOrderGradientModule\n", + "\n", + "# Inherited from the class ZeroOrderGradientModule\n", + "# Optionally specify the `method` and/or `num_samples` and/or `sigma` used for sampling\n", + "class Net(ZeroOrderGradientModule, method='naive', num_samples=100, sigma=0.01):\n", + " def __init__(self, ...):\n", + " ...\n", + "\n", + " def forward(self, batch):\n", + " # Forward process\n", + " ...\n", + " return objective # the returned tensor should be a scalar tensor\n", + "\n", + " def sample(self, sample_shape=torch.Size()):\n", + " # Generate a batch of noise samples\n", + " # NOTE: The distribution should be spherical symmetric and with a constant variance of 1.\n", + " ...\n", + " return noise_batch\n", + "\n", + "# Get model and data\n", + "net = Net(...)\n", + "data = ...\n", + "\n", + "# Forward pass\n", + "loss = Net(data)\n", + "# Backward pass using zero-order differentiation\n", + "grads = torch.autograd.grad(loss, net.parameters())\n", + "```" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "b53524f5", + "metadata": {}, + "source": [ + "Here we reimplement the functional API example above with the OOP API." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ecc5730c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "001: tensor(0.0201, grad_fn=)\n", + "002: tensor(0.0181, grad_fn=)\n", + "003: tensor(0.0167, grad_fn=)\n", + "004: tensor(0.0153, grad_fn=)\n", + "005: tensor(0.0142, grad_fn=)\n", + "006: tensor(0.0133, grad_fn=)\n", + "007: tensor(0.0125, grad_fn=)\n", + "008: tensor(0.0119, grad_fn=)\n", + "009: tensor(0.0116, grad_fn=)\n", + "010: tensor(0.0114, grad_fn=)\n", + "011: tensor(0.0112, grad_fn=)\n", + "012: tensor(0.0112, grad_fn=)\n", + "013: tensor(0.0113, grad_fn=)\n", + "014: tensor(0.0116, grad_fn=)\n", + "015: tensor(0.0118, grad_fn=)\n", + "016: tensor(0.0121, grad_fn=)\n", + "017: tensor(0.0123, grad_fn=)\n", + "018: tensor(0.0125, grad_fn=)\n", + "019: tensor(0.0127, grad_fn=)\n", + "020: tensor(0.0127, grad_fn=)\n", + "021: tensor(0.0125, grad_fn=)\n", + "022: tensor(0.0123, grad_fn=)\n", + "023: tensor(0.0120, grad_fn=)\n", + "024: tensor(0.0118, grad_fn=)\n", + "025: tensor(0.0117, grad_fn=)\n" + ] + } + ], + "source": [ + "torch.random.manual_seed(0)\n", + "\n", + "\n", + "class Net(torchopt.nn.ZeroOrderGradientModule, method='forward', num_samples=100, sigma=0.01):\n", + " def __init__(self, dim):\n", + " super().__init__()\n", + " self.fc = nn.Linear(dim, 1)\n", + " self.distribution = torch.distributions.Normal(loc=0, scale=1)\n", + "\n", + " def forward(self, x, y):\n", + " y_pred = self.fc(x)\n", + " loss = F.mse_loss(y_pred, y)\n", + " return loss\n", + "\n", + " def sample(self, sample_shape=torch.Size()):\n", + " return self.distribution.sample(sample_shape)\n", + "\n", + "\n", + "x = torch.randn(64, 32) * 0.1\n", + "y = torch.randn(64, 1) * 0.1\n", + "net = Net(dim=32)\n", + "\n", + "\n", + "optimizer = torchopt.Adam(net.parameters(), lr=0.01)\n", + "\n", + "for i in range(25):\n", + " loss = net(x, y) # compute loss\n", + "\n", + " optimizer.zero_grad()\n", + " loss.backward() # backward pass\n", + " optimizer.step() # update network parameters\n", + "\n", + " print(f'{i + 1:03d}: {loss!r}')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.15 ('torchopt')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + }, + "vscode": { + "interpreter": { + "hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/requirements.txt b/tutorials/requirements.txt index 5fe3b1ad..e8a3be95 100644 --- a/tutorials/requirements.txt +++ b/tutorials/requirements.txt @@ -1,8 +1,11 @@ ---extra-index-url https://download.pytorch.org/whl/cu116 -torch >= 1.12 +--extra-index-url https://download.pytorch.org/whl/cu121 +# Sync with project.dependencies +torch >= 2.0 torchvision -functorch >= 0.2 --requirement ../requirements.txt ipykernel +jax[cpu] >= 0.4 +jaxopt +optax pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy