diff --git a/.gitattributes b/.gitattributes index a894e29e..1d0afc65 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1,8 @@ +* text eol=lf *.ipynb linguist-detectable=false + +*.png binary +*.jpg binary +*.jpeg binary +*.gif binary +*.pdf binary diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml new file mode 100644 index 00000000..4b90bb84 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -0,0 +1,119 @@ +name: 🐛 Bug Report +description: File an issue about a bug. +title: "[BUG] " +labels: [bug] +assignees: [Benjamin-eecs] +body: + - type: markdown + attributes: + value: | + Please do your best to make the issue as easy to act on as possible, and only submit here if there is clearly a problem with TorchOpt (ask in [Discussions](https://github.com/metaopt/torchopt/discussions) first if unsure). + + - type: checkboxes + id: steps + attributes: + label: Required prerequisites + description: Make sure you've completed the following steps before submitting your issue -- thank you! + options: + - label: I have read the documentation . + required: true + - label: I have searched the [Issue Tracker](https://github.com/metaopt/torchopt/issues) and [Discussions](https://github.com/metaopt/torchopt/discussions) that this hasn't already been reported. (+1 or comment there if it has.) + required: true + - label: Consider asking first in a [Discussion](https://github.com/metaopt/torchopt/discussions/new). + required: false + + - type: input + id: version + attributes: + label: | + What version of TorchOpt are you using? + value: | + python3 -m pip show torchopt + validations: + required: true + + - type: textarea + id: system-info + attributes: + label: System information + value: | + Describe the characteristic of your environment: + + - Describe how the library was installed (pip, conda, source, ...) + - Python version + - Versions of any other relevant libraries + + ```python + import sys, torch, functorch, torchopt + print(sys.version, sys.platform) + print(torchopt.__version__, torch.__version__, functorch.__version__) + ``` + validations: + required: true + + - type: textarea + id: description + attributes: + label: Problem description + placeholder: | + Provide a short description, state the expected behavior and what actually happens. Include + relevant information like what version of TorchOpt you are using, what system you are on, + and any useful commands / output. + validations: + required: true + + - type: textarea + id: code + attributes: + label: Reproducible example code + value: | + + + The Python snippets: + + ```python + + ``` + + Run the snippets with the following commands: + + ```bash + + ``` + + Extra dependencies: + + ```text + + ``` + validations: + required: true + + - type: textarea + id: traceback + attributes: + label: Traceback + placeholder: | + Put the Python traceback information here. + + Traceback (most recent call last): + File ... + render: pytb + + - type: textarea + id: expected + attributes: + label: Expected behavior + placeholder: | + Provide a clear and concise description of what you expected to happen. + + - type: textarea + id: additional-context + attributes: + label: Additional context + placeholder: | + Add any other context about the problem here. Screenshots may also be helpful. + + If you know or suspect the reason for this bug, paste the code lines and suggest modifications. diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md deleted file mode 100644 index 86dcfbcb..00000000 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ /dev/null @@ -1,64 +0,0 @@ ---- -name: Bug report -about: Create a report to help us improve -title: "[BUG]" -labels: ["bug"] -assignees: Benjamin-eecs - ---- - -## Describe the bug - -A clear and concise description of what the bug is. - -## To Reproduce - -Steps to reproduce the behavior. - -Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful. - -Please use the markdown code blocks for both code and stack traces. - -```python -import torchopt -``` - -```pytb -Traceback (most recent call last): - File ... -``` - -## Expected behavior - -A clear and concise description of what you expected to happen. - -## Screenshots - -If applicable, add screenshots to help explain your problem. - -## System info - -Describe the characteristic of your environment: - -- Describe how the library was installed (pip, source, ...) -- Python version -- Versions of any other relevant libraries - -```python -import torchopt, numpy, sys -print(torchopt.__version__, numpy.__version__, sys.version, sys.platform) -``` - -## Additional context - -Add any other context about the problem here. - -## Reason and Possible fixes - -If you know or suspect the reason for this bug, paste the code lines and suggest modifications. - -## Checklist - -- [ ] I have checked that there is no similar issue in the repo (**required**) -- [ ] I have read the [documentation](https://torchopt.readthedocs.io/) (**required**) -- [ ] I have provided a minimal working example to reproduce the bug (**required**) diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 00000000..a3b57cdc --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,5 @@ +blank_issues_enabled: false +contact_links: + - name: 💬 Start a discussion + url: https://github.com/metaopt/torchopt/discussions/new + about: Please ask and answer questions here if unsure. diff --git a/.github/ISSUE_TEMPLATE/feature-request.yml b/.github/ISSUE_TEMPLATE/feature-request.yml new file mode 100644 index 00000000..959ec909 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature-request.yml @@ -0,0 +1,48 @@ +name: ✨ Feature Request +description: Suggest an idea for this project. +title: "[Feature Request] " +labels: [enhancement] +assignees: [Benjamin-eecs] +body: + - type: checkboxes + id: steps + attributes: + label: Required prerequisites + description: Make sure you've completed the following steps before submitting your issue -- thank you! + options: + - label: I have searched the [Issue Tracker](https://github.com/metaopt/torchopt/issues) and [Discussions](https://github.com/metaopt/torchopt/discussions) that this hasn't already been reported. (+1 or comment there if it has.) + required: true + - label: Consider asking first in a [Discussion](https://github.com/metaopt/torchopt/discussions/new). + required: false + + - type: textarea + id: motivation + attributes: + label: Motivation + value: | + + validations: + required: true + + - type: textarea + id: solution + attributes: + label: Solution + placeholder: | + Provide a clear and concise description of what you want to happen. + + - type: textarea + id: alternatives + attributes: + label: Alternatives + placeholder: | + A clear and concise description of any alternative solutions or features you've considered. + + - type: textarea + id: additional-context + attributes: + label: Additional context + placeholder: | + Add any other context about the problem here. Screenshots may also be helpful. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md deleted file mode 100644 index b61aa154..00000000 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ /dev/null @@ -1,30 +0,0 @@ ---- -name: Feature request -about: Suggest an idea for this project -title: "[Feature Request]" -labels: ["enhancement"] -assignees: Benjamin-eecs - ---- - -## Motivation - -Please outline the motivation for the proposal. -Is your feature request related to a problem? e.g., "I'm always frustrated when [...]". -If this is related to another issue, please link here too. - -## Solution - -A clear and concise description of what you want to happen. - -## Alternatives - -A clear and concise description of any alternative solutions or features you've considered. - -## Additional context - -Add any other context or screenshots about the feature request here. - -## Checklist - -- [ ] I have checked that there is no similar issue in the repo (**required**) diff --git a/.github/ISSUE_TEMPLATE/questions.yml b/.github/ISSUE_TEMPLATE/questions.yml new file mode 100644 index 00000000..33968b1e --- /dev/null +++ b/.github/ISSUE_TEMPLATE/questions.yml @@ -0,0 +1,27 @@ +name: 🤔 Questions / Help / Support +description: Do you need support? +title: "[Question] " +labels: [question] +assignees: [Benjamin-eecs] +body: + - type: checkboxes + id: steps + attributes: + label: Required prerequisites + description: Make sure you've completed the following steps before submitting your issue -- thank you! + options: + - label: I have read the documentation . + required: true + - label: I have searched the [Issue Tracker](https://github.com/metaopt/torchopt/issues) and [Discussions](https://github.com/metaopt/torchopt/discussions) that this hasn't already been reported. (+1 or comment there if it has.) + required: true + - label: Consider asking first in a [Discussion](https://github.com/metaopt/torchopt/discussions/new). + required: false + + - type: textarea + id: questions + attributes: + label: Questions + placeholder: | + Describe your questions with relevant resources such as snippets, links, images, etc. + validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/questions_help_support.md b/.github/ISSUE_TEMPLATE/questions_help_support.md deleted file mode 100644 index 072d2e52..00000000 --- a/.github/ISSUE_TEMPLATE/questions_help_support.md +++ /dev/null @@ -1,17 +0,0 @@ ---- -name: Questions / Help / Support -about: Do you need support? -title: "[Question]" -labels: "question" -assignees: Benjamin-eecs - ---- - -## Questions - - - -## Checklist - -- [ ] I have checked that there is no similar issue in the repo (**required**) -- [ ] I have read the [documentation](https://torchopt.readthedocs.io/) (**required**) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 807bd4bb..2709e055 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -31,10 +31,10 @@ What types of changes does your code introduce? Put an `x` in all the boxes that Go over all the following points, and put an `x` in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help! -- [ ] I have read the [CONTRIBUTION](https://torchopt.readthedocs.io/en/latest/developer/contributing.html) guide (**required**) +- [ ] I have read the [CONTRIBUTION](https://torchopt.readthedocs.io/en/latest/developer/contributing.html) guide. (**required**) - [ ] My change requires a change to the documentation. -- [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*). +- [ ] I have updated the tests accordingly. (*required for a bug fix or a new feature*) - [ ] I have updated the documentation accordingly. -- [ ] I have reformatted the code using `make format` (**required**) -- [ ] I have checked the code using `make lint` (**required**) +- [ ] I have reformatted the code using `make format`. (**required**) +- [ ] I have checked the code using `make lint`. (**required**) - [ ] I have ensured `make test` pass. (**required**) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 72dd012a..93539731 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -37,54 +37,154 @@ concurrency: cancel-in-progress: ${{ github.event_name == 'pull_request' }} env: - CUDA_VERSION: "11.6" - TEST_TORCH_SPECS: "cpu cu113 cu116" + CUDA_VERSION: "11.7" + TEST_TORCH_SPECS: "cpu cu116" jobs: - build-sdist: + build: + name: Build sdist and pure-Python wheel runs-on: ubuntu-latest if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/')) - timeout-minutes: 10 + timeout-minutes: 60 steps: - name: Checkout uses: actions/checkout@v3 with: submodules: "recursive" - fetch-depth: 1 + fetch-depth: 0 - name: Set up Python uses: actions/setup-python@v4 with: - python-version: "3.7 - 3.10" + python-version: "3.7 - 3.10" # sync with requires-python in pyproject.toml update-environment: true + - name: Set __release__ + if: | + startsWith(github.ref, 'refs/tags/') || + (github.event_name == 'workflow_dispatch' && github.event.inputs.task == 'build-and-publish') + run: | + python .github/workflows/set_release.py + + - name: Print version + run: python setup.py --version + - name: Install dependencies run: python -m pip install --upgrade pip setuptools wheel build - - name: Build sdist - run: python -m build --sdist + - name: Build sdist and pure-Python wheel + run: python -m build + env: + TORCHOPT_NO_EXTENSIONS: "true" - name: Upload artifact uses: actions/upload-artifact@v3 with: - name: sdist - path: dist/*.tar.gz + name: build + path: dist/* + if-no-files-found: error + + - name: Install dependencies + run: | + python -m pip install -r tests/requirements.txt + + - name: Install TorchOpt + run: | + python -m pip install -vvv dist/*.whl + + - name: Test with pytest + run: | + make pytest + + build-wheels-py37: + name: Build wheels for Python ${{ matrix.python-version }} on ubuntu-latest + runs-on: ubuntu-latest + needs: [build] + if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/')) + strategy: + matrix: + python-version: ["3.7"] # sync with requires-python in pyproject.toml + fail-fast: false + timeout-minutes: 30 + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + submodules: "recursive" + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + update-environment: true + + - name: Set __release__ + if: | + startsWith(github.ref, 'refs/tags/') || + (github.event_name == 'workflow_dispatch' && github.event.inputs.task == 'build-and-publish') + run: python .github/workflows/set_release.py + + - name: Print version + run: python setup.py --version + + - name: Set CIBW_BUILD + run: python .github/workflows/set_cibw_build.py + + - name: Build wheels + uses: pypa/cibuildwheel@v2.11.2 + env: + CIBW_BUILD: ${{ env.CIBW_BUILD }} + with: + package-dir: . + output-dir: wheelhouse + config-file: "{package}/pyproject.toml" + + - uses: actions/upload-artifact@v3 + with: + name: wheels-py37 + path: wheelhouse/*.whl if-no-files-found: error build-wheels: + name: Build wheels for Python ${{ matrix.python-version }} on ubuntu-latest runs-on: ubuntu-latest - needs: [build-sdist] + needs: [build, build-wheels-py37] if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/')) - timeout-minutes: 90 + strategy: + matrix: + python-version: ["3.8", "3.9", "3.10"] # sync with requires-python in pyproject.toml + fail-fast: false + timeout-minutes: 30 steps: - name: Checkout uses: actions/checkout@v3 with: submodules: "recursive" - fetch-depth: 1 + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + update-environment: true + + - name: Set __release__ + if: | + startsWith(github.ref, 'refs/tags/') || + (github.event_name == 'workflow_dispatch' && github.event.inputs.task == 'build-and-publish') + run: python .github/workflows/set_release.py + + - name: Print version + run: python setup.py --version + + - name: Set CIBW_BUILD + run: python .github/workflows/set_cibw_build.py - name: Build wheels - uses: pypa/cibuildwheel@v2.8.1 + uses: pypa/cibuildwheel@v2.11.2 + env: + CIBW_BUILD: ${{ env.CIBW_BUILD }} with: package-dir: . output-dir: wheelhouse @@ -98,20 +198,34 @@ jobs: publish: runs-on: ubuntu-latest - needs: [build-sdist, build-wheels] + needs: [build, build-wheels-py37, build-wheels] if: | github.repository == 'metaopt/torchopt' && github.event_name != 'pull_request' && (github.event_name != 'workflow_dispatch' || github.event.inputs.task == 'build-and-publish') && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/')) timeout-minutes: 15 steps: + - name: Checkout + uses: actions/checkout@v3 + with: + submodules: "recursive" + fetch-depth: 0 + - name: Set up Python uses: actions/setup-python@v4 if: startsWith(github.ref, 'refs/tags/') with: - python-version: "3.7 - 3.10" + python-version: "3.7 - 3.11" # sync with requires-python in pyproject.toml update-environment: true + - name: Set __release__ + if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' + run: | + python .github/workflows/set_release.py + + - name: Print version + run: python setup.py --version + - name: Check consistency between the package version and release tag if: startsWith(github.ref, 'refs/tags/') run: | @@ -127,7 +241,15 @@ jobs: with: # unpacks default artifact into dist/ # if `name: artifact` is omitted, the action will create extra parent dir - name: sdist + name: build + path: dist + + - name: Download built wheels + uses: actions/download-artifact@v3 + with: + # unpacks default artifact into dist/ + # if `name: artifact` is omitted, the action will create extra parent dir + name: wheels-py37 path: dist - name: Download built wheels @@ -138,9 +260,12 @@ jobs: name: wheels path: dist + - name: List distributions + run: ls -lh dist/* + - name: Publish to TestPyPI if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' - uses: pypa/gh-action-pypi-publish@v1.5.0 + uses: pypa/gh-action-pypi-publish@release/v1 with: user: __token__ password: ${{ secrets.TESTPYPI_UPLOAD_TOKEN }} @@ -151,7 +276,7 @@ jobs: - name: Publish to PyPI if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' - uses: pypa/gh-action-pypi-publish@v1.5.0 + uses: pypa/gh-action-pypi-publish@release/v1 with: user: __token__ password: ${{ secrets.PYPI_UPLOAD_TOKEN }} diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 44ece663..92d6036f 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -26,17 +26,17 @@ jobs: submodules: "recursive" fetch-depth: 1 - - name: Set up Python 3.7 # the lowest version we support + - name: Set up Python 3.7 uses: actions/setup-python@v4 with: - python-version: "3.7" + python-version: "3.7" # the lowest version we support (sync with requires-python in pyproject.toml) update-environment: true - name: Setup CUDA Toolkit - uses: Jimver/cuda-toolkit@v0.2.7 + uses: Jimver/cuda-toolkit@v0.2.8 id: cuda-toolkit with: - cuda: "11.6.2" + cuda: "11.7.0" method: network sub-packages: '["nvcc"]' - run: | diff --git a/.github/workflows/set_cibw_build.py b/.github/workflows/set_cibw_build.py new file mode 100755 index 00000000..03838b4a --- /dev/null +++ b/.github/workflows/set_cibw_build.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 + +# pylint: disable=missing-module-docstring + +import os +import sys + + +# pylint: disable-next=consider-using-f-string +CIBW_BUILD = 'CIBW_BUILD=*cp%d%d-*manylinux*' % sys.version_info[:2] + +print(CIBW_BUILD) +with open(os.getenv('GITHUB_ENV'), mode='a', encoding='UTF-8') as file: + print(CIBW_BUILD, file=file) diff --git a/.github/workflows/set_release.py b/.github/workflows/set_release.py new file mode 100755 index 00000000..568a38e2 --- /dev/null +++ b/.github/workflows/set_release.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 + +# pylint: disable=missing-module-docstring + +import pathlib +import re + + +ROOT = pathlib.Path(__file__).absolute().parent.parent.parent + +VERSION_FILE = ROOT / 'torchopt' / 'version.py' + +VERSION_CONTENT = VERSION_FILE.read_text(encoding='UTF-8') + +VERSION_FILE.write_text( + data=re.sub( + r'__release__\s*=.*', + '__release__ = True', + string=VERSION_CONTENT, + ), + encoding='UTF-8', +) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c36e78f2..67732041 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -28,6 +28,7 @@ concurrency: jobs: test: + name: Test with CXX/CUDA extensions on ubuntu-latest runs-on: ubuntu-latest timeout-minutes: 60 steps: @@ -37,17 +38,17 @@ jobs: submodules: "recursive" fetch-depth: 1 - - name: Set up Python 3.7 # the lowest version we support + - name: Set up Python 3.7 uses: actions/setup-python@v4 with: - python-version: "3.7" + python-version: "3.7" # the lowest version we support (sync with requires-python in pyproject.toml) update-environment: true - name: Setup CUDA Toolkit - uses: Jimver/cuda-toolkit@v0.2.7 + uses: Jimver/cuda-toolkit@v0.2.8 id: cuda-toolkit with: - cuda: "11.6.2" + cuda: "11.7.0" method: network sub-packages: '["nvcc"]' - run: | @@ -81,10 +82,49 @@ jobs: make pytest - name: Upload coverage to Codecov - uses: codecov/codecov-action@v1 + uses: codecov/codecov-action@v3 with: - token: ${{ secrets.CODECOV }} + token: ${{ secrets.CODECOV_TOKEN }} file: ./tests/coverage.xml flags: unittests name: codecov-umbrella fail_ci_if_error: false + + test-pure-python: + name: Test for pure-Python on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + timeout-minutes: 60 + strategy: + matrix: + os: [ubuntu-latest, macos-latest] # jaxlib is not available on Windows + fail-fast: false + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + submodules: "recursive" + fetch-depth: 1 + + - name: Set up Python 3.7 + uses: actions/setup-python@v4 + with: + python-version: "3.7" # the lowest version we support (sync with requires-python in pyproject.toml) + update-environment: true + + - name: Upgrade pip + run: | + python -m pip install --upgrade pip setuptools wheel + + - name: Install dependencies + run: | + python -m pip install -r tests/requirements.txt + + - name: Install TorchOpt + run: | + python -m pip install -vvv -e . + env: + TORCHOPT_NO_EXTENSIONS: "true" + + - name: Test with pytest + run: | + make pytest diff --git a/.gitignore b/.gitignore index a0107f9b..62b1adbc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,5 @@ -##### Project specific ##### -!torchopt/_src/ -!torchopt/_lib/ +##### Project Specific ##### +third-party/ ##### Python.gitignore ##### # Byte-compiled / optimized / DLL files @@ -31,6 +30,7 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +*.whl # PyInstaller # Usually these files are written by a python script from a template diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 21062f0e..316271e6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ # See https://pre-commit.com/hooks.html for more hooks repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.4.0 hooks: - id: check-symlinks - id: destroyed-symlinks @@ -24,10 +24,20 @@ repos: - id: isort stages: [commit, push, manual] - repo: https://github.com/psf/black - rev: 22.8.0 + rev: 22.10.0 hooks: - - id: black + - id: black-jupyter stages: [commit, push, manual] + - repo: https://github.com/asottile/pyupgrade + rev: v3.3.0 + hooks: + - id: pyupgrade + args: [--py37-plus] # sync with requires-python + stages: [commit, push, manual] + exclude: | + (?x)( + ^examples/ + ) - repo: local hooks: - id: pylint diff --git a/.pylintrc b/.pylintrc index e55faae7..f0846434 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,4 +1,22 @@ -[MASTER] +[MAIN] + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may @@ -16,28 +34,41 @@ extension-pkg-whitelist= # specified are enabled, while categories only check already-enabled messages. fail-on= -# Specify a score threshold to be exceeded before program exits with error. -fail-under=10.0 +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= # Files or directories to be skipped. They should be base names, not paths. ignore=CVS,.vscode,.history -# Add files or directories matching the regex patterns to the ignore-list. The -# regex matches against paths and can be in Posix or Windows format. +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\' represents the directory delimiter on Windows systems, it +# can't be used as an escape character. ignore-paths=^_C/$,^examples/$,^tests/$ -# Files or directories matching the regex patterns are skipped. The regex -# matches against base names, not paths. The default value ignores emacs file -# locks +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks ignore-patterns=^\.# +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + # Python code to execute, usually for sys.path manipulation such as # pygtk.require(). #init-hook= # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the -# number of processors available to use. -jobs=1 +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=0 # Control the amount of potential inferred values when inferring a single # object. This can help the performance when dealing with large functions or @@ -53,7 +84,7 @@ persistent=yes # Minimum Python version to use for version dependent checks. Will default to # the version used to run pylint. -py-version=3.7 +py-version=3.7 # the lowest version we support (sync with requires-python in pyproject.toml) # Discover python modules and packages in the file system subtree. recursive=no @@ -66,115 +97,8 @@ suggestion-mode=yes # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no - -[MESSAGES CONTROL] - -# Only show warnings with the listed confidence levels. Leave empty to show -# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, -# UNDEFINED. -confidence= - -# Disable the message, report, category or checker with the given id(s). You -# can either give multiple identifiers separated by comma (,) or put this -# option multiple times (only on the command line, not in the configuration -# file where it should appear only once). You can also use "--disable=all" to -# disable everything first and then re-enable specific checks. For example, if -# you want to run only the similarities checker, you can use "--disable=all -# --enable=similarities". If you want to run only the classes checker, but have -# no Warning level messages displayed, use "--disable=all --enable=classes -# --disable=W". -disable=missing-module-docstring, - duplicate-code, - consider-using-from-import - -# Enable the message, report, category or checker with the given id(s). You can -# either give multiple identifier separated by comma (,) or put this option -# multiple time (only on the command line, not in the configuration file where -# it should appear only once). See also the "--disable" option for examples. -enable=c-extension-no-member - - -[REPORTS] - -# Python expression which should return a score less than or equal to 10. You -# have access to the variables 'error', 'warning', 'refactor', and 'convention' -# which contain the number of messages in each category, as well as 'statement' -# which is the total number of statements analyzed. This score is used by the -# global evaluation report (RP0004). -evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) - -# Template used to display messages. This is a python new-style format string -# used to format the message information. See doc for all details. -#msg-template= - -# Set the output format. Available formats are text, parseable, colorized, json -# and msvs (visual studio). You can also give a reporter class, e.g. -# mypackage.mymodule.MyReporterClass. -output-format=text - -# Tells whether to display a full report or only the messages. -reports=no - -# Activate the evaluation score. -score=yes - - -[REFACTORING] - -# Maximum number of nested blocks for function / method body -max-nested-blocks=5 - -# Complete name of functions that never returns. When checking for -# inconsistent-return-statements if a never returning function is called then -# it will be considered as an explicit return statement and no message will be -# printed. -never-returning-functions=sys.exit,argparse.parse_error - - -[STRING] - -# This flag controls whether inconsistent-quotes generates a warning when the -# character used as a quote delimiter is used inconsistently within a module. -check-quote-consistency=no - -# This flag controls whether the implicit-str-concat should generate a warning -# on implicit string concatenation in sequences defined over several lines. -check-str-concat-over-line-jumps=no - - -[MISCELLANEOUS] - -# List of note tags to take in consideration, separated by a comma. -notes=FIXME, - XXX, - TODO - -# Regular expression of note tags to take in consideration. -#notes-rgx= - - -[SPELLING] - -# Limits count of emitted suggestions for spelling mistakes. -max-spelling-suggestions=4 - -# Spelling dictionary name. Available dictionaries: none. To make it work, -# install the 'python-enchant' package. -spelling-dict= - -# List of comma separated words that should be considered directives if they -# appear and the beginning of a comment and should not be checked. -spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: - -# List of comma separated words that should not be checked. -spelling-ignore-words= - -# A path to a file that contains the private dictionary; one word per line. -spelling-private-dict-file= - -# Tells whether to store unknown words to the private dictionary (see the -# --spelling-private-dict-file option) instead of raising a message. -spelling-store-unknown-words=no +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= [BASIC] @@ -266,7 +190,9 @@ good-names=i, t, lr, mu, - nu + nu, + x, + y # Good variable names regexes, separated by a comma. If names match any regex, # they will always be accepted @@ -323,158 +249,6 @@ variable-naming-style=snake_case #variable-rgx= -[LOGGING] - -# The type of string formatting that logging methods do. `old` means using % -# formatting, `new` is for `{}` formatting. -logging-format-style=old - -# Logging modules to check that the string format arguments are in logging -# function parameter format. -logging-modules=logging - - -[VARIABLES] - -# List of additional names supposed to be defined in builtins. Remember that -# you should avoid defining new builtins when possible. -additional-builtins= - -# Tells whether unused global variables should be treated as a violation. -allow-global-unused-variables=yes - -# List of names allowed to shadow builtins -allowed-redefined-builtins= - -# List of strings which can identify a callback function by name. A callback -# name must start or end with one of those strings. -callbacks=cb_, - _cb - -# A regular expression matching the name of dummy variables (i.e. expected to -# not be used). -dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ - -# Argument names that match this expression will be ignored. Default to name -# with leading underscore. -ignored-argument-names=_.*|^ignored_|^unused_ - -# Tells whether we should check for unused import in __init__ files. -init-import=no - -# List of qualified module names which can have objects that can redefine -# builtins. -redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io - - -[TYPECHECK] - -# List of decorators that produce context managers, such as -# contextlib.contextmanager. Add to this list to register other decorators that -# produce valid context managers. -contextmanager-decorators=contextlib.contextmanager - -# List of members which are set dynamically and missed by pylint inference -# system, and so shouldn't trigger E1101 when accessed. Python regular -# expressions are accepted. -generated-members=numpy.*, - torch.* - -# Tells whether missing members accessed in mixin class should be ignored. A -# class is considered mixin if its name matches the mixin-class-rgx option. -ignore-mixin-members=yes - -# Tells whether to warn about missing members when the owner of the attribute -# is inferred to be None. -ignore-none=yes - -# This flag controls whether pylint should warn about no-member and similar -# checks whenever an opaque object is returned when inferring. The inference -# can return multiple potential results while evaluating a Python object, but -# some branches might not be evaluated, which results in partial inference. In -# that case, it might be useful to still emit no-member and other checks for -# the rest of the inferred objects. -ignore-on-opaque-inference=yes - -# List of class names for which member attributes should not be checked (useful -# for classes with dynamically set attributes). This supports the use of -# qualified names. -ignored-classes=optparse.Values,thread._local,_thread._local - -# List of module names for which member attributes should not be checked -# (useful for modules/projects where namespaces are manipulated during runtime -# and thus existing member attributes cannot be deduced by static analysis). It -# supports qualified module names, as well as Unix pattern matching. -ignored-modules= - -# Show a hint with possible names when a member name was not found. The aspect -# of finding the hint is based on edit distance. -missing-member-hint=yes - -# The minimum edit distance a name should have in order to be considered a -# similar match for a missing member name. -missing-member-hint-distance=1 - -# The total number of similar names that should be taken in consideration when -# showing a hint for a missing member. -missing-member-max-choices=1 - -# Regex pattern to define which classes are considered mixins ignore-mixin- -# members is set to 'yes' -mixin-class-rgx=.*[Mm]ixin - -# List of decorators that change the signature of a decorated function. -signature-mutators= - - -[FORMAT] - -# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. -expected-line-ending-format= - -# Regexp for a line that is allowed to be longer than the limit. -ignore-long-lines=^\s*(# )??$ - -# Number of spaces of indent required inside a hanging or continued line. -indent-after-paren=4 - -# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 -# tab). -indent-string=' ' - -# Maximum number of characters on a single line. -max-line-length=100 - -# Maximum number of lines in a module. -max-module-lines=1000 - -# Allow the body of a class to be on the same line as the declaration if body -# contains single statement. -single-line-class-stmt=no - -# Allow the body of an if to be on the same line as the test if there is no -# else. -single-line-if-stmt=no - - -[SIMILARITIES] - -# Comments are removed from the similarity computation -ignore-comments=yes - -# Docstrings are removed from the similarity computation -ignore-docstrings=yes - -# Imports are removed from the similarity computation -ignore-imports=no - -# Signatures are removed from the similarity computation -ignore-signatures=no - -# Minimum lines number of a similarity. -min-similarity-lines=4 - - [CLASSES] # Warn about protected attribute access inside special methods @@ -542,6 +316,43 @@ max-statements=50 min-public-methods=2 +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=BaseException, + Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=120 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + [IMPORTS] # List of modules that can be imported at any level, not just the top level @@ -551,11 +362,6 @@ allow-any-import-level= # Allow wildcard imports from modules that define __all__. allow-wildcard-with-all=no -# Analyse import fallback blocks. This can be used to support both Python 2 and -# 3 compatible code, which means that the block might have code that exists -# only in one or another interpreter, leading to false positives when analysed. -analyse-fallback-blocks=no - # Deprecated modules which should not be used, separated by a comma. deprecated-modules= @@ -583,9 +389,241 @@ known-third-party=enchant preferred-modules= -[EXCEPTIONS] +[LOGGING] -# Exceptions that will emit a warning when being caught. Defaults to -# "BaseException, Exception". -overgeneral-exceptions=BaseException, - Exception +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=duplicate-code, + consider-using-from-import + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +#output-format= + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: en_AU (hunspell), en_CA +# (hunspell), en_GB (hunspell), en_US (hunspell), en_ZA (hunspell). +spelling-dict=en_US + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file=docs/source/spelling_wordlist.txt + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members=numpy.*, + torch.* + +# Tells whether missing members accessed in mixin class should be ignored. A +# class is considered mixin if its name matches the mixin-class-rgx option. +ignore-mixin-members=yes + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io diff --git a/CHANGELOG.md b/CHANGELOG.md index 5334e26a..5d7adbb5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,39 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ------ +## [0.6.0] - 2022-12-07 + +### Added + +- Add unroll pragma for CUDA OPs by [@JieRen98](https://github.com/JieRen98) and [@XuehaiPan](https://github.com/XuehaiPan) in [#112](https://github.com/metaopt/torchopt/pull/112). +- Add Python implementation of accelerated OP and pure-Python wheels by [@XuehaiPan](https://github.com/XuehaiPan) in [#67](https://github.com/metaopt/torchopt/pull/67). +- Add `nan_to_num` hook and gradient transformation by [@XuehaiPan](https://github.com/XuehaiPan) in [#119](https://github.com/metaopt/torchopt/pull/119). +- Add matrix inversion linear solver with neumann series approximation by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#98](https://github.com/metaopt/torchopt/pull/98). +- Add if condition of number of threads for CPU OPs by [@JieRen98](https://github.com/JieRen98) in [#105](https://github.com/metaopt/torchopt/pull/105). +- Add implicit MAML omniglot few-shot classification example with OOP APIs by [@XuehaiPan](https://github.com/XuehaiPan) in [#107](https://github.com/metaopt/torchopt/pull/107). +- Add implicit MAML omniglot few-shot classification example by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#48](https://github.com/metaopt/torchopt/pull/48). +- Add object-oriented modules support for implicit meta-gradient by [@XuehaiPan](https://github.com/XuehaiPan) in [#101](https://github.com/metaopt/torchopt/pull/101). +- Bump PyTorch version to 1.13.0 by [@XuehaiPan](https://github.com/XuehaiPan) in [#104](https://github.com/metaopt/torchopt/pull/104). +- Add zero-order gradient estimation by [@JieRen98](https://github.com/JieRen98) in [#93](https://github.com/metaopt/torchopt/pull/93). +- Add RPC-based distributed training support and add distributed MAML example by [@XuehaiPan](https://github.com/XuehaiPan) in [#83](https://github.com/metaopt/torchopt/pull/83). +- Add full type hints by [@XuehaiPan](https://github.com/XuehaiPan) in [#92](https://github.com/metaopt/torchopt/pull/92). +- Add API documentation and tutorial for implicit gradients by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@JieRen98](https://github.com/JieRen98) and [@XuehaiPan](https://github.com/XuehaiPan) in [#73](https://github.com/metaopt/torchopt/pull/73). +- Add wrapper class for functional optimizers and examples of `functorch` integration by [@vmoens](https://github.com/vmoens) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#6](https://github.com/metaopt/torchopt/pull/6). +- Implicit differentiation support by [@JieRen98](https://github.com/JieRen98) and [@waterhorse1](https://github.com/waterhorse1) and [@XuehaiPan](https://github.com/XuehaiPan) in [#41](https://github.com/metaopt/torchopt/pull/41). + +### Changed + +- Refactor code organization by [@XuehaiPan](https://github.com/XuehaiPan) in [#92](https://github.com/metaopt/torchopt/pull/92) and [#100](https://github/metaopt/torchopt/pull/100). + +### Fixed + +- Fix implicit MAML omniglot few-shot classification example by [@XuehaiPan](https://github.com/XuehaiPan) in [#108](https://github.com/metaopt/torchopt/pull/108). +- Align results of distributed examples by [@XuehaiPan](https://github.com/XuehaiPan) in [#95](https://github.com/metaopt/torchopt/pull/95). +- Fix `None` in module containers by [@XuehaiPan](https://github.com/XuehaiPan). +- Fix backward errors when using inplace `sqrt_` and `add_` by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@JieRen98](https://github.com/JieRen98) and [@XuehaiPan](https://github.com/XuehaiPan). +- Fix LR scheduling by [@XuehaiPan](https://github.com/XuehaiPan) in [#76](https://github.com/metaopt/torchopt/pull/76). +- Fix the step count tensor (`shape=(1,)`) can change the shape of the scalar updates (`shape=()`) by [@XuehaiPan](https://github.com/XuehaiPan) in [#71](https://github.com/metaopt/torchopt/pull/71). + ## [0.5.0] - 2022-09-05 ### Added @@ -114,7 +147,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ------ -[Unreleased]: https://github.com/olivierlacan/keep-a-changelog/compare/v0.5.0...HEAD +[Unreleased]: https://github.com/olivierlacan/keep-a-changelog/compare/v0.6.0...HEAD +[0.6.0]: https://github.com/olivierlacan/keep-a-changelog/compare/v0.5.0...v0.6.0 [0.5.0]: https://github.com/olivierlacan/keep-a-changelog/compare/v0.4.3...v0.5.0 [0.4.3]: https://github.com/olivierlacan/keep-a-changelog/compare/v0.4.2...v0.4.3 [0.4.2]: https://github.com/olivierlacan/keep-a-changelog/compare/v0.4.1...v0.4.2 diff --git a/CITATION.cff b/CITATION.cff index b738a26c..aa997b82 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -20,6 +20,10 @@ authors: family-names: Pan email: xuehaipan@pku.edu.cn affiliation: Peking University + - given-names: Yao + family-names: Fu + email: f.yu@ed.ac.uk + affiliation: University of Edinburgh - given-names: Luo family-names: Mai email: luo.mai@ed.ac.uk @@ -28,7 +32,7 @@ authors: family-names: Yang affiliation: Peking University email: yaodong.yang@pku.edu.cn -version: 0.5.0 -date-released: "2022-09-05" +version: 0.6.0 +date-released: "2022-12-07" license: Apache-2.0 repository-code: "https://github.com/metaopt/torchopt" diff --git a/CMakeLists.txt b/CMakeLists.txt index 26786756..50f6144f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,9 +13,12 @@ # limitations under the License. # ============================================================================== -cmake_minimum_required(VERSION 3.8) +cmake_minimum_required(VERSION 3.11) # for FetchContent project(torchopt LANGUAGES CXX) +include(FetchContent) +set(PYBIND11_VERSION v2.10.1) + if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) endif() @@ -26,6 +29,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) find_package(Threads REQUIRED) # -pthread find_package(OpenMP REQUIRED) # -Xpreprocessor -fopenmp set(CMAKE_POSITION_INDEPENDENT_CODE ON) # -fPIC +set(CMAKE_CXX_VISIBILITY_PRESET hidden) # -fvisibility=hidden if(MSVC) string(APPEND CMAKE_CXX_FLAGS " /Wall") @@ -178,7 +182,7 @@ if(NOT DEFINED PYTHON_INCLUDE_DIR) message(STATUS "Auto detecting Python include directory...") system( STRIP OUTPUT_VARIABLE PYTHON_INCLUDE_DIR - COMMAND "${PYTHON_EXECUTABLE}" -c "print(__import__('sysconfig').get_path('include'))" + COMMAND "${PYTHON_EXECUTABLE}" -c "print(__import__('sysconfig').get_path('platinclude'))" ) endif() @@ -186,15 +190,16 @@ if("${PYTHON_INCLUDE_DIR}" STREQUAL "") message(FATAL_ERROR "Python include directory not found") else() message(STATUS "Detected Python include directory: \"${PYTHON_INCLUDE_DIR}\"") - include_directories(${PYTHON_INCLUDE_DIR}) + include_directories("${PYTHON_INCLUDE_DIR}") endif() system( STRIP OUTPUT_VARIABLE PYTHON_SITE_PACKAGES - COMMAND "${PYTHON_EXECUTABLE}" -c "print(__import__('sysconfig') .get_path('purelib'))" + COMMAND "${PYTHON_EXECUTABLE}" -c "print(__import__('sysconfig').get_path('purelib'))" ) message(STATUS "Detected Python site packages: \"${PYTHON_SITE_PACKAGES}\"") +# Include pybind11 set(PYBIND11_PYTHON_VERSION "${PYTHON_VERSION}") if(NOT DEFINED PYBIND11_CMAKE_DIR) @@ -206,14 +211,27 @@ if(NOT DEFINED PYBIND11_CMAKE_DIR) endif() if("${PYBIND11_CMAKE_DIR}" STREQUAL "") - message(FATAL_ERROR "Pybind11 CMake directory not found") + FetchContent_Declare( + pybind11 + GIT_REPOSITORY https://github.com/pybind/pybind11.git + GIT_TAG "${PYBIND11_VERSION}" + GIT_SHALLOW TRUE + SOURCE_DIR "${CMAKE_SOURCE_DIR}/third-party/pybind11" + BINARY_DIR "${CMAKE_SOURCE_DIR}/third-party/.cmake/pybind11/build" + STAMP_DIR "${CMAKE_SOURCE_DIR}/third-party/.cmake/pybind11/stamp" + ) + FetchContent_GetProperties(pybind11) + if(NOT pybind11_POPULATED) + message(STATUS "Populating Git repository pybind11@${PYBIND11_VERSION} to third-party/pybind11...") + FetchContent_MakeAvailable(pybind11) + endif() else() message(STATUS "Detected Pybind11 CMake directory: \"${PYBIND11_CMAKE_DIR}\"") find_package(pybind11 CONFIG PATHS "${PYBIND11_CMAKE_DIR}") endif() if(NOT DEFINED TORCH_INCLUDE_PATH) - message(STATUS "Auto detecting PyTorch include directory...") + message(STATUS "Auto detecting Torch include directory...") system( STRIP OUTPUT_VARIABLE TORCH_INCLUDE_PATH COMMAND "${PYTHON_EXECUTABLE}" -c "print('\\\;'.join(__import__('torch.utils.cpp_extension', fromlist=[None]).include_paths()))" @@ -232,7 +250,7 @@ else() endif() if(NOT DEFINED TORCH_LIBRARY_PATH) - message(STATUS "Auto detecting PyTorch library directory...") + message(STATUS "Auto detecting Torch library directory...") system( STRIP OUTPUT_VARIABLE TORCH_LIBRARY_PATH COMMAND "${PYTHON_EXECUTABLE}" -c "print('\\\;'.join(__import__('torch.utils.cpp_extension', fromlist=[None]).library_paths()))" @@ -251,19 +269,23 @@ endif() unset(TORCH_LIBRARIES) +foreach(VAR_PATH ${TORCH_LIBRARY_PATH}) + file(GLOB TORCH_LIBRARY "${VAR_PATH}/*") + message(STATUS "Detected Torch libraries: \"${TORCH_LIBRARY}\"") +endforeach() + foreach(VAR_PATH ${TORCH_LIBRARY_PATH}) if(WIN32) file(GLOB TORCH_LIBRARY "${VAR_PATH}/*.lib") else() file(GLOB TORCH_LIBRARY "${VAR_PATH}/libtorch_python.*") endif() - list(APPEND TORCH_LIBRARIES "${TORCH_LIBRARY}") endforeach() -message(STATUS "Detected Torch libraries: \"${TORCH_LIBRARIES}\"") +message(STATUS "Detected Torch Python libraries: \"${TORCH_LIBRARIES}\"") add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) -include_directories(${CMAKE_SOURCE_DIR}) +include_directories("${CMAKE_SOURCE_DIR}") add_subdirectory(src) diff --git a/Dockerfile b/Dockerfile index 82434eed..d34eda03 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,7 @@ # $ docker build --target devel --tag torchopt-devel:latest . # -ARG cuda_docker_tag="11.6.2-cudnn8-devel-ubuntu20.04" +ARG cuda_docker_tag="11.7.1-cudnn8-devel-ubuntu22.04" FROM nvidia/cuda:"${cuda_docker_tag}" AS builder ENV DEBIAN_FRONTEND=noninteractive diff --git a/MANIFEST.in b/MANIFEST.in index 08cf6257..09403999 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,5 @@ recursive-include torchopt *.pyi +recursive-include torchopt *.typed include LICENSE # Include source files in sdist diff --git a/Makefile b/Makefile index ac67d4b8..5d99fce4 100644 --- a/Makefile +++ b/Makefile @@ -9,12 +9,14 @@ CXX_FILES = $(shell find $(SOURCE_FOLDERS) -type f -name "*.h" -o -name "*. COMMIT_HASH = $(shell git log -1 --format=%h) PATH := $(HOME)/go/bin:$(PATH) PYTHON ?= $(shell command -v python3 || command -v python) +CLANG_FORMAT ?= $(shell command -v clang-format-14 || command -v clang-format) +PYTESTOPTS ?= .PHONY: default default: install install: - $(PYTHON) -m pip install . + $(PYTHON) -m pip install -vvv . install-editable: $(PYTHON) -m pip install --upgrade pip @@ -24,6 +26,9 @@ install-editable: install-e: install-editable # alias +uninstall: + $(PYTHON) -m pip uninstall -y $(PROJECT_NAME) + build: $(PYTHON) -m pip install --upgrade pip $(PYTHON) -m pip install --upgrade setuptools wheel build @@ -35,18 +40,19 @@ check_pip_install = $(PYTHON) -m pip show $(1) &>/dev/null || (cd && $(PYTHON) - check_pip_install_extra = $(PYTHON) -m pip show $(1) &>/dev/null || (cd && $(PYTHON) -m pip install $(2) --upgrade) pylint-install: - $(call check_pip_install,pylint) + $(call check_pip_install_extra,pylint,pylint[spelling]) flake8-install: $(call check_pip_install,flake8) - $(call check_pip_install_extra,bugbear,flake8_bugbear) + $(call check_pip_install_extra,flake8-bugbear,flake8-bugbear) py-format-install: $(call check_pip_install,isort) - $(call check_pip_install,black) + $(call check_pip_install_extra,black,black[jupyter]) mypy-install: $(call check_pip_install,mypy) + $(call check_pip_install,types-setuptools) pre-commit-install: $(call check_pip_install,pre-commit) @@ -54,7 +60,11 @@ pre-commit-install: docs-install: $(call check_pip_install,pydocstyle) - $(call check_pip_install,doc8) + $(call check_pip_install_extra,doc8,"doc8<1.0.0a0") + if ! $(PYTHON) -c "import sys; exit(sys.version_info < (3, 8))"; then \ + $(PYTHON) -m pip uninstall --yes importlib-metadata; \ + $(call check_pip_install_extra,importlib-metadata,"importlib-metadata<5.0.0a0"); \ + fi $(call check_pip_install,sphinx) $(call check_pip_install,sphinx-rtd-theme) $(call check_pip_install,sphinx-autoapi) @@ -75,7 +85,9 @@ cpplint-install: $(call check_pip_install,cpplint) clang-format-install: - command -v clang-format || sudo apt-get install -y clang-format + command -v clang-format-14 || command -v clang-format || \ + sudo apt-get install -y clang-format-14 || \ + sudo apt-get install -y clang-format clang-tidy-install: command -v clang-tidy || sudo apt-get install -y clang-tidy @@ -93,7 +105,7 @@ pytest: pytest-install cd tests && \ $(PYTHON) -m pytest --verbose --color=yes --durations=0 \ --cov="$(PROJECT_NAME)" --cov-report=xml --cov-report=term-missing \ - . + $(PYTESTOPTS) . test: pytest @@ -106,8 +118,8 @@ flake8: flake8-install $(PYTHON) -m flake8 $(PYTHON_FILES) --count --select=E9,F63,F7,F82,E225,E251 --show-source --statistics py-format: py-format-install - $(PYTHON) -m isort --project torchopt --check $(PYTHON_FILES) && \ - $(PYTHON) -m black --check $(PYTHON_FILES) + $(PYTHON) -m isort --project $(PROJECT_NAME) --check $(PYTHON_FILES) && \ + $(PYTHON) -m black --check $(PYTHON_FILES) tutorials mypy: mypy-install $(PYTHON) -m mypy $(PROJECT_PATH) @@ -121,7 +133,7 @@ cpplint: cpplint-install $(PYTHON) -m cpplint $(CXX_FILES) clang-format: clang-format-install - clang-format --style=file -i $(CXX_FILES) -n --Werror + $(CLANG_FORMAT) --style=file -i $(CXX_FILES) -n --Werror # Documentation @@ -129,12 +141,14 @@ addlicense: addlicense-install addlicense -c $(COPYRIGHT) -l apache -y 2022 -check $(SOURCE_FOLDERS) docstyle: docs-install + make -C docs clean $(PYTHON) -m pydocstyle $(PROJECT_PATH) && doc8 docs && make -C docs html SPHINXOPTS="-W" docs: docs-install $(PYTHON) -m sphinx_autobuild --watch $(PROJECT_PATH) --open-browser docs/source docs/build spelling: docs-install + make -C docs clean make -C docs spelling SPHINXOPTS="-W" clean-docs: @@ -142,12 +156,12 @@ clean-docs: # Utility functions -lint: flake8 py-format mypy clang-format cpplint docstyle spelling +lint: flake8 py-format mypy pylint clang-format cpplint docstyle spelling format: py-format-install clang-format-install addlicense-install - $(PYTHON) -m isort --project torchopt $(PYTHON_FILES) - $(PYTHON) -m black $(PYTHON_FILES) - clang-format -style=file -i $(CXX_FILES) + $(PYTHON) -m isort --project $(PROJECT_NAME) $(PYTHON_FILES) + $(PYTHON) -m black $(PYTHON_FILES) tutorials + $(CLANG_FORMAT) -style=file -i $(CXX_FILES) addlicense -c $(COPYRIGHT) -l apache -y 2022 $(SOURCE_FOLDERS) clean-py: diff --git a/README.md b/README.md index 13d005f5..3dc1155f 100644 --- a/README.md +++ b/README.md @@ -1,23 +1,41 @@ +
-![Python 3.7+](https://img.shields.io/badge/Python-3.7%2B-brightgreen.svg) -[![PyPI](https://img.shields.io/pypi/v/torchopt?label=PyPI)](https://pypi.org/project/torchopt) -![Status](https://img.shields.io/pypi/status/torchopt?label=Status) -![GitHub Workflow Status](https://img.shields.io/github/workflow/status/metaopt/torchopt/Tests?label=tests&logo=github) -[![Documentation Status](https://readthedocs.org/projects/torchopt/badge/?version=latest)](https://torchopt.readthedocs.io/en/latest/?badge=latest) -[![Downloads](https://static.pepy.tech/personalized-badge/torchopt?period=month&left_color=grey&right_color=blue&left_text=Downloads/month)](https://pepy.tech/project/torchopt) -[![GitHub Repo Stars](https://img.shields.io/github/stars/metaopt/torchopt?label=Stars&logo=github&color=brightgreen)](https://github.com/metaopt/torchopt/stargazers) -[![License](https://img.shields.io/github/license/metaopt/torchopt?label=License)](#license) +
+ + ![Python 3.7+](https://img.shields.io/badge/Python-3.7%2B-brightgreen.svg) + ![PyPI](https://img.shields.io/pypi/v/torchopt?logo=pypi) + ![GitHub Workflow Status](https://img.shields.io/github/workflow/status/metaopt/torchopt/Tests?label=tests&logo=github) + ![Documentation Status](https://img.shields.io/readthedocs/torchopt?logo=readthedocs) + ![Downloads](https://static.pepy.tech/personalized-badge/torchopt?period=total&left_color=grey&right_color=blue&left_text=downloads) + ![GitHub Repo Stars](https://img.shields.io/github/stars/metaopt/torchopt?color=brightgreen&logo=github) + ![License](https://img.shields.io/github/license/metaopt/torchopt?label=license&logo=) + +
+ +

+ Installation | + Documentation | + Tutorials | + Examples | + Paper | + Citation +

-**TorchOpt** is a high-performance optimizer library built upon [PyTorch](https://pytorch.org/) for easy implementation of functional optimization and gradient-based meta-learning. It consists of two main features: +**TorchOpt** is an efficient library for differentiable optimization built upon [PyTorch](https://pytorch.org). +TorchOpt is: -- TorchOpt provides functional optimizer which enables [JAX-like](https://github.com/google/jax) composable functional optimizer for PyTorch. With TorchOpt, one can easily conduct neural network optimization in PyTorch with functional style optimizer, similar to [Optax](https://github.com/deepmind/optax) in JAX. -- With the design of functional programing, TorchOpt provides efficient, flexible, and easy-to-implement differentiable optimizer for gradient-based meta-learning research. It largely reduces the efforts required to implement sophisticated meta-learning algorithms. +- **Comprehensive**: TorchOpt provides three differentiation mode - explicit differentiation, implicit differentiation and zero-order differentiation for handling different differentiable optimization situations. +- **Flexible**: TorchOpt provides both functional and objective-oriented API for user different preferences. Users can implement differentiable optimization in JAX-like or PyTorch-like style. +- **Efficient**: TorchOpt provides (1) CPU/GPU acceleration differentiable optimizer (2) RPC-based distributed training framework (3) Fast Tree Operations, to largely increase the training efficiency for bi-level optimization problem. + +Beyond differentiable optimization, TorchOpt can also be regarded as a functional optimizer which enables [JAX-like](https://github.com/google/jax) composable functional optimizer for PyTorch. +With TorchOpt, users can easily conduct neural network optimization in PyTorch with functional style optimizer, similar to [Optax](https://github.com/deepmind/optax) in JAX. -------------------------------------------------------------------------------- @@ -27,36 +45,37 @@ The README is organized as follows: - [Optax-Like API](#optax-like-api) - [PyTorch-Like API](#pytorch-like-api) - [Differentiable](#differentiable) -- [TorchOpt as Differentiable Optimizer for Meta-Learning](#torchopt-as-differentiable-optimizer-for-meta-learning) - - [Meta-Learning API](#meta-learning-api) -- [Examples](#examples) -- [High-Performance](#high-performance) +- [TorchOpt for Differentiable Optimization](#torchopt-for-differentiable-optimization) + - [Explicit Gradient (EG)](#explicit-gradient-eg) + - [Implicit Gradient (IG)](#implicit-gradient-ig) + - [Zero-order Differentiation (ZD)](#zero-order-differentiation-zd) +- [High-Performance and Distributed Training](#high-performance-and-distributed-training) + - [CPU/GPU accelerated differentiable optimizer](#cpugpu-accelerated-differentiable-optimizer) + - [Distributed Training](#distributed-training) + - [OpTree](#optree) - [Visualization](#visualization) +- [Examples](#examples) - [Installation](#installation) -- [Future Plan](#future-plan) - [Changelog](#changelog) -- [The Team](#the-team) - [Citing TorchOpt](#citing-torchopt) +- [The Team](#the-team) +- [License](#license) -------------------------------------------------------------------------------- ## TorchOpt as Functional Optimizer -The design of TorchOpt follows the philosophy of functional programming. Aligned with [`functorch`](https://github.com/pytorch/functorch), users can conduct functional style programing with models, optimizers and training in PyTorch. We use the Adam optimizer as an example in the following illustration. You can also check out the tutorial notebook [Functional Optimizer](tutorials/1_Functional_Optimizer.ipynb) for more details. +The design of TorchOpt follows the philosophy of functional programming. +Aligned with [`functorch`](https://github.com/pytorch/functorch), users can conduct functional style programing with models, optimizers and training in PyTorch. +We use the Adam optimizer as an example in the following illustration. +You can also check out the tutorial notebook [Functional Optimizer](tutorials/1_Functional_Optimizer.ipynb) for more details. ### Optax-Like API -For those users who prefer fully functional programing, we offer Optax-Like API by passing gradients and optimizers states to the optimizer function. We design base class `torchopt.Optimizer` that has the same interface as `torch.optim.Optimizer`. Here is an example coupled with `functorch`: +For those users who prefer fully functional programing, we offer Optax-Like API by passing gradients and optimizers states to the optimizer function. +Here is an example coupled with `functorch`: ```python -import functorch -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.utils.data import DataLoader - -import torchopt - class Net(nn.Module): ... class Loader(DataLoader): ... @@ -77,9 +96,26 @@ updates, opt_state = optimizer.update(grads, opt_state) # get updates params = torchopt.apply_updates(params, updates) # update network parameters ``` +We also provide a wrapper `torchopt.FuncOptimizer` to make maintaining the optimizer state easier: + +```python +net = Net() # init +loader = Loader() +optimizer = torchopt.FuncOptimizer(torchopt.adam()) # wrap with `torchopt.FuncOptimizer` + +model, params = functorch.make_functional(net) # use functorch extract network parameters + +for xs, ys in loader: # get data + pred = model(params, xs) # forward + loss = F.cross_entropy(pred, ys) # compute loss + + params = optimizer.step(loss, params) # update network parameters +``` + ### PyTorch-Like API -We also offer origin PyTorch APIs (e.g. `zero_grad()` or `step()`) by wrapping our Optax-Like API for traditional PyTorch user: +We also design base class `torchopt.Optimizer` that has the same interface as `torch.optim.Optimizer`. +We offer origin PyTorch APIs (e.g. `zero_grad()` or `step()`) by wrapping our Optax-Like API for traditional PyTorch users. ```python net = Net() # init @@ -97,137 +133,261 @@ optimizer.step() # step updates ### Differentiable -On top of the same optimization function as `torch.optim`, an important benefit of functional optimizer is that one can implement differentiable optimization easily. This is particularly helpful when the algorithm requires to differentiate through optimization update (such as meta learning practices). We take as the inputs the gradients and optimizer states, use non-in-place operators to compute and output the updates. The processes can be automatically implemented, with the only need from users being to pass the argument `inplace=False` to the functions: - -```python -# Get updates -updates, opt_state = optimizer.update(grad, opt_state, inplace=False) -# Update network parameters -params = torchopt.apply_updates(params, updates, inplace=False) -``` +On top of the same optimization function as `torch.optim`, an important benefit of functional optimizer is that one can implement differentiable optimization easily. +This is particularly helpful when the algorithm requires to differentiate through optimization update (such as meta-learning practices). +We take as the inputs the gradients and optimizer states, use non-in-place operators to compute and output the updates. +The processes can be automatically implemented, with the only need from users being to pass the argument `inplace=False` to the functions. +Check out section [Explicit Gradient (EG)](#explicit-gradient-eg) functional API for example. -------------------------------------------------------------------------------- -## TorchOpt as Differentiable Optimizer for Meta-Learning +## TorchOpt for Differentiable Optimization -Meta-Learning has gained enormous attention in both Supervised Learning and Reinforcement Learning. Meta-Learning algorithms often contain a bi-level optimization process with *inner loop* updating the network parameters and *outer loop* updating meta parameters. The figure below illustrates the basic formulation for meta-optimization in Meta-Learning. The main feature is that the gradients of *outer loss* will back-propagate through all `inner.step` operations. +We design a bilevel-optimization updating scheme, which can be easily extended to realize various differentiable optimization processes.
- +
-Since network parameters become a node of computation graph, a flexible Meta-Learning library should enable users manually control the gradient graph connection which means that users should have access to the network parameters and optimizer states for manually detaching or connecting the computation graph. In PyTorch designing, the network parameters or optimizer states are members of network (a.k.a. `torch.nn.Module`) or optimizer (a.k.a. `torch.optim.Optimizer`), this design significantly introducing difficulty for user control network parameters or optimizer states. Previous differentiable optimizer Repo [`higher`](https://github.com/facebookresearch/higher), [`learn2learn`](https://github.com/learnables/learn2learn) follows the PyTorch designing which leads to inflexible API. +As shown above, the scheme contains an outer level that has parameters $\phi$ that can be learned end-to-end through the inner level parameters solution $\theta^{\prime}(\phi)$ by using the best-response derivatives $\partial \theta^{\prime}(\phi) / \partial \phi$. +TorchOpt supports three differentiation modes. +It can be seen that the key component of this algorithm is to calculate the best-response (BR) Jacobian. +From the BR-based perspective, existing gradient methods can be categorized into three groups: explicit gradient over unrolled optimization, implicit differentiation, and zero-order gradient differentiation. -In contrast to them, TorchOpt realizes differentiable optimizer with functional programing, where Meta-Learning researchers could control the network parameters or optimizer states as normal variables (a.k.a. `torch.Tensor`). This functional optimizer design of TorchOpt is beneficial for implementing complex gradient flow Meta-Learning algorithms and allow us to improve computational efficiency by using techniques like operator fusion. +### Explicit Gradient (EG) -### Meta-Learning API +The idea of explicit gradient is to treat the gradient step as a differentiable function and try to backpropagate through the unrolled optimization path. +This differentiation mode is suitable for algorithms when the inner-level optimization solution is obtained by a few gradient steps, such as [MAML](https://arxiv.org/abs/1703.03400) and [MGRL](https://arxiv.org/abs/1805.09801). +TorchOpt offers both functional and object-oriented API for EG to fit different user applications. -- We design a base class `torchopt.MetaOptimizer` for managing network updates in Meta-Learning. The constructor of `MetaOptimizer` takes as input the network rather than network parameters. `MetaOptimizer` exposed interface `step(loss)` takes as input the loss for step the network parameter. Refer to the tutorial notebook [Meta Optimizer](tutorials/3_Meta_Optimizer.ipynb) for more details. -- We offer `torchopt.chain` which can apply a list of chainable update transformations. Combined with `MetaOptimizer`, it can help you conduct gradient transformation such as gradient clip before the Meta optimizer steps. Refer to the tutorial notebook [Meta Optimizer](tutorials/3_Meta_Optimizer.ipynb) for more details. -- We observe that different Meta-Learning algorithms vary in inner-loop parameter recovery. TorchOpt provides basic functions for users to extract or recover network parameters and optimizer states anytime anywhere they want. -- Some algorithms such as MGRL ([arXiv:1805.09801](https://arxiv.org/abs/1805.09801)) initialize the inner-loop parameters inherited from previous inner-loop process when conducting a new bi-level process. TorchOpt also provides a finer function `stop_gradient` for manipulating the gradient graph, which is helpful for this kind of algorithms. Refer to the notebook [Stop Gradient](tutorials/4_Stop_Gradient.ipynb) for more details. +#### Functional API -We give an example of MAML ([arXiv:1703.03400](https://arxiv.org/abs/1703.03400)) with inner-loop Adam optimizer to illustrate TorchOpt APIs: +The functional API is to conduct optimization in a functional programming style. +Note that we pass the argument `inplace=False` to the functions to make the optimization differentiable. +Refer to the tutorial notebook [Functional Optimizer](tutorials/1_Functional_Optimizer.ipynb) for more guidances. ```python -net = Net() # init +# Define functional optimizer +optimizer = torchopt.adam() +# Define meta and inner parameters +meta_params = ... +fmodel, params = make_functional(model) +# Initial state +state = optimizer.init(params) + +for iter in range(iter_times): + loss = inner_loss(fmodel, params, meta_params) + grads = torch.autograd.grad(loss, params) + # Apply non-inplace parameter update + updates, state = optimizer.update(grads, state, inplace=False) + params = torchopt.apply_updates(params, updates) + +loss = outer_loss(fmodel, params, meta_params) +meta_grads = torch.autograd.grad(loss, meta_params) +``` + +#### OOP API -# The constructor `MetaOptimizer` takes as input the network -inner_optim = torchopt.MetaAdam(net) -outer_optim = torchopt.Adam(net.parameters()) - -for train_iter in range(train_iters): - outer_loss = 0 - for task in range(tasks): - loader = Loader(tasks) - - # Store states at the initial points - net_state = torchopt.extract_state_dict(net) # extract state - optim_state = torchopt.extract_state_dict(inner_optim) - for inner_iter in range(inner_iters): - # Compute inner loss and perform inner update - xs, ys = next(loader) - pred = net(xs) - inner_loss = F.cross_entropy(pred, ys) - inner_optim.step(inner_loss) - - # Compute outer loss and back-propagate - xs, ys = next(loader) - pred = net(xs) - outer_loss = outer_loss + F.cross_entropy(pred, ys) - - # Recover network and optimizer states at the initial point for the next task - torchopt.recover_state_dict(inner_optim, optim_state) - torchopt.recover_state_dict(net, net_state) - - outer_loss = outer_loss / len(tasks) # task average - outer_optim.zero_grad() - outer_loss.backward() - outer_optim.step() - - # Stop gradient if necessary - torchopt.stop_gradient(net) - torchopt.stop_gradient(inner_optim) +TorchOpt also provides OOP API compatible with PyTorch programming style. +Refer to the example and the tutorial notebook [Meta-Optimizer](tutorials/3_Meta_Optimizer.ipynb), [Stop Gradient](tutorials/4_Stop_Gradient.ipynb) for more guidances. + +```python +# Define meta and inner parameters +meta_params = ... +model = ... +# Define differentiable optimizer +optimizer = torchopt.MetaAdam(model) # a model instance as argument instead of model.parameters() + +for iter in range(iter_times): + # Perform inner update + loss = inner_loss(model, meta_params) + optimizer.step(loss) + +loss = outer_loss(model, meta_params) +loss.backward() ``` --------------------------------------------------------------------------------- +### Implicit Gradient (IG) -## Examples +By treating the solution $\theta^{\prime}$ as an implicit function of $\phi$, the idea of IG is to directly get analytical best-response derivatives $\partial \theta^{\prime} (\phi) / \partial \phi$ by [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem). +This is suitable for algorithms when the inner-level optimal solution is achieved ${\left. \frac{\partial F (\theta, \phi)}{\partial \theta} \right\rvert}_{\theta=\theta^{\prime}} = 0$ or reaches some stationary conditions $F (\theta^{\prime}, \phi) = 0$, such as [iMAML](https://arxiv.org/abs/1909.04630) and [DEQ](https://arxiv.org/abs/1909.01377). +TorchOpt offers both functional and OOP APIs for supporting both [conjugate gradient-based](https://arxiv.org/abs/1909.04630) and [Neumann series-based](https://arxiv.org/abs/1911.02590) IG methods. +Refer to the example [iMAML](https://github.com/waterhorse1/torchopt/tree/readme/examples/iMAML) and the notebook [Implicit Gradient](tutorials/5_Implicit_Differentiation.ipynb) for more guidances. -In [`examples`](examples), we offer several examples of functional optimizer and 5 light-weight meta-learning examples with TorchOpt. The meta-learning examples covers 2 Supervised Learning and 3 Reinforcement Learning algorithms. +#### Functional API -- [Model Agnostic Meta Learning (MAML) - Supervised Learning](https://arxiv.org/abs/1703.03400) (ICML 2017) -- [Learning to Reweight Examples for Robust Deep Learning](https://arxiv.org/abs/1803.09050) (ICML 2018) -- [Model Agnostic Meta Learning (MAML) - Reinforcement Learning](https://arxiv.org/abs/1703.03400) (ICML 2017) -- [Meta Gradient Reinforcement Learning (MGRL)](https://arxiv.org/abs/1805.09801) (NeurIPS 2018) -- [Learning through opponent learning process (LOLA)](https://arxiv.org/abs/1709.04326) (AAMAS 2018) +For implicit gradient, users need to define the stationary condition and TorchOpt provides the decorator to wrap the solve function for enabling implicit gradient computation. + +```python +# The stationary condition for the inner-loop +def stationary(params, meta_params, data): + # Stationary condition construction + return stationary condition + +# Decorator for wrapping the function +# Optionally specify the linear solver (conjugate gradient or Neumann series) +@torchopt.diff.implicit.custom_root(stationary, solve=linear_solver) +def solve(params, meta_params, data): + # Forward optimization process for params + return output + +# Define params, meta_params and get data +params, meta_prams, data = ..., ..., ... +optimal_params = solve(params, meta_params, data) +loss = outer_loss(optimal_params) + +meta_grads = torch.autograd.grad(loss, meta_params) +``` + +#### OOP API + +TorchOpt also offer an OOP API, users need to inherit from the class `torchopt.nn.ImplicitMetaGradientModule` to construct the inner-loop network. +Users need to define the stationary condition/objective function and the inner-loop solve function to enable implicit gradient computation. + +```python +# Inherited from the class ImplicitMetaGradientModule +# Optionally specify the linear solver (conjugate gradient or Neumann series) +class InnerNet(ImplicitMetaGradientModule, linear_solver): + def __init__(self, meta_param): + super().__init__() + self.meta_param = meta_param + ... + + def forward(self, batch): + # Forward process + ... + + def optimality(self, batch, labels): + # Stationary condition construction for calculating implicit gradient + # NOTE: If this method is not implemented, it will be automatically + # derived from the gradient of the `objective` function. + ... + + def objective(self, batch, labels): + # Define the inner-loop optimization objective + ... + + def solve(self, batch, labels): + # Conduct the inner-loop optimization + ... + +# Get meta_params and data +meta_params, data = ..., ... +inner_net = InnerNet(meta_params) + +# Solve for inner-loop process related with the meta-parameters +optimal_inner_net = inner_net.solve(data) + +# Get outer loss and solve for meta-gradient +loss = outer_loss(optimal_inner_net) +meta_grads = torch.autograd.grad(loss, meta_params) +``` + +### Zero-order Differentiation (ZD) + +When the inner-loop process is non-differentiable or one wants to eliminate the heavy computation burdens in the previous two modes (brought by Hessian), one can choose Zero-order Differentiation (ZD). +ZD typically gets gradients based on zero-order estimation, such as finite-difference, or [Evolutionary Strategy](https://arxiv.org/abs/1703.03864). +Instead of optimizing the objective $F$, ES optimizes a smoothed objective. +TorchOpt provides both functional and OOP APIs for the ES method. +Refer to the tutorial notebook [Zero-order Differentiation](tutorials/6_Zero_Order_Differentiation.ipynb) for more guidances. + +#### Functional API + +```python +# Customize the noise sampling function in ES +def sample(sample_shape): + ... + return sample_noise + +# Specify method and hyper-parameter of ES +@torchopt.diff.zero_order(sample, method) +def forward(params, batch, labels): + # forward process + return output +``` -------------------------------------------------------------------------------- -## High-Performance +## High-Performance and Distributed Training -One can think of the scale procedures on gradients of optimizer algorithms as a combination of several operations. For example, the implementation of the Adam algorithm often includes addition, multiplication, power and square operations, one can fuse these operations into several compound functions. The operator fusion could greatly simplify the computation graph and reduce the GPU function launching stall. In addition, one can also implement the optimizer backward function and manually reuse some intermediate tensors to improve the backward performance. Users can pass argument `use_accelerated_op=True` to `adam`, `Adam` and `MetaAdam` to enable the fused accelerated operator. The arguments are the same between the two kinds of implementations. +### CPU/GPU accelerated differentiable optimizer -Here we evaluate the performance using the MAML-Omniglot code with the inner-loop Adam optimizer on GPU. We comparable the run time of the overall algorithm and the meta-optimization (outer-loop optimization) under different network architecture/inner-step numbers. We choose [`higher`](https://github.com/facebookresearch/higher) as our baseline. The figure below illustrate that our accelerated Adam can achieve at least $1/3$ efficiency improvement over the baseline. +We take the optimizer as a whole instead of separating it into several basic operators (e.g., `sqrt` and `div`). +Therefore, by manually writing the forward and backward functions, we can perform the symbolic reduction. +In addition, we can store some intermediate data that can be reused during the backpropagation. +We write the accelerated functions in C++ OpenMP and CUDA, bind them by [`pybind11`](https://github.com/pybind/pybind11) to allow they can be called by Python, and then we define the forward and backward behavior using `torch.autograd.Function`. +Users can use by simply setting the `use_accelerated_op` flag as `True`. +Refer to the corresponding sections in tutorials [Functional Optimizer](tutorials/1_Functional_Optimizer.ipynb) and [Meta-Optimizer](tutorials/3_Meta_Optimizer.ipynb) -
- -
+```python +optimizer = torchopt.MetaAdam(model, lr, use_accelerated_op=True) +``` -Notably, the operator fusion not only increases performance but also help simplify the computation graph, which will be discussed in the next section. +### Distributed Training + +`TorchOpt` provides distributed training features based on the PyTorch RPC module for better training speed and multi-node multi-GPU support. +Different from the MPI-like parallelization paradigm, which uses multiple homogenous workers and requires carefully designed communication hooks, the RPC APIs allow users to build their optimization pipeline more flexibly. +Experimental results show that we achieve approximately linear relationship between the speed-up ratio and the number of workers. +Check out the [distributed MAML example](https://github.com/metaopt/torchopt/tree/main/examples/distributed/few-shot) for more specific guidance. + +### OpTree + +We implement the *PyTree* to enable fast nested structure flatten using C++. +The tree operations (e.g., flatten and unflatten) are very important in enabling functional and Just-In-Time (JIT) features of deep learning frameworks. +By implementing it in C++, we can use some cache/memory friendly structures (e.g., `absl::InlinedVector`) to improve the performance. +For more guidance and comparison results, please refer to our open source project [`OpTree`](https://github.com/metaopt/optree). -------------------------------------------------------------------------------- ## Visualization -Complex gradient flow in meta-learning brings in a great challenge for managing the gradient flow and verifying the correctness of it. TorchOpt provides a visualization tool that draw variable (e.g. network parameters or meta parameters) names on the gradient graph for better analyzing. The visualization tool is modified from [`torchviz`](https://github.com/szagoruyko/pytorchviz). We provide an example using the [visualization code](examples/visualize.py). Also refer to the notebook [Visualization](tutorials/2_Visualization.ipynb) for more details. +Complex gradient flow in meta-learning brings in a great challenge for managing the gradient flow and verifying the correctness of it. +TorchOpt provides a visualization tool that draw variable (e.g., network parameters or meta-parameters) names on the gradient graph for better analyzing. +The visualization tool is modified from [`torchviz`](https://github.com/szagoruyko/pytorchviz). +Refer to the example [visualization code](examples/visualize.py) and the tutorial notebook [Visualization](tutorials/2_Visualization.ipynb) for more details. -The figure below show the visualization result. Compared with [`torchviz`](https://github.com/szagoruyko/pytorchviz), TorchOpt fuses the operations within the `Adam` together (orange) to reduce the complexity and provide simpler visualization. +The figure below show the visualization result. +Compared with [`torchviz`](https://github.com/szagoruyko/pytorchviz), TorchOpt fuses the operations within the `Adam` together (orange) to reduce the complexity and provide simpler visualization.
- +
-------------------------------------------------------------------------------- +## Examples + +In the [`examples`](examples) directory, we offer several examples of functional optimizer and light-weight meta-learning examples with TorchOpt. + +- [Model-Agnostic Meta-Learning (MAML) - Supervised Learning](https://arxiv.org/abs/1703.03400) (ICML 2017) +- [Learning to Reweight Examples for Robust Deep Learning](https://arxiv.org/abs/1803.09050) (ICML 2018) +- [Model-Agnostic Meta-Learning (MAML) - Reinforcement Learning](https://arxiv.org/abs/1703.03400) (ICML 2017) +- [Meta-Gradient Reinforcement Learning (MGRL)](https://arxiv.org/abs/1805.09801) (NeurIPS 2018) +- [Learning through opponent learning process (LOLA)](https://arxiv.org/abs/1709.04326) (AAMAS 2018) +- [Meta-Learning with Implicit Gradients](https://arxiv.org/abs/1909.04630) (NeurIPS 2019) + +Also check [`examples`](examples) for more distributed/visualization/functorch-compatible examples. + +-------------------------------------------------------------------------------- + ## Installation Requirements - PyTorch - (Optional) For visualizing computation graphs - - [Graphviz](https://graphviz.org/download/) (for Linux users use `apt/yum install graphviz` or `conda install -c anaconda python-graphviz`) + - [Graphviz](https://graphviz.org/download) (for Linux users use `apt/yum install graphviz` or `conda install -c anaconda python-graphviz`) -**Please follow the instructions at to install PyTorch in your Python environment first.** Then run the following command to install TorchOpt from PyPI ([![PyPI](https://img.shields.io/pypi/v/torchopt?label=PyPI)](https://pypi.org/project/torchopt) / ![Status](https://img.shields.io/pypi/status/torchopt?label=Status)): +**Please follow the instructions at to install PyTorch in your Python environment first.** Then run the following command to install TorchOpt from PyPI ([![PyPI](https://img.shields.io/pypi/v/torchopt?label=pypi&logo=pypi)](https://pypi.org/project/torchopt) / ![Status](https://img.shields.io/pypi/status/torchopt?label=status)): ```bash pip3 install torchopt ``` -If the minimum version of PyTorch is not satisfied, `pip` will install/upgrade it for you. Please be careful about the `torch` build for CPU / CUDA support (e.g. `cpu`, `cu102`, `cu113`). You may need to specify the extra index URL for the `torch` package: +If the minimum version of PyTorch is not satisfied, `pip` will install/upgrade it for you. Please be careful about the `torch` build for CPU / CUDA support (e.g. `cpu`, `cu116`, `cu117`). You may need to specify the extra index URL for the `torch` package: ```bash -pip3 install torchopt --extra-index-url https://download.pytorch.org/whl/cu116 +pip3 install torchopt --extra-index-url https://download.pytorch.org/whl/cu117 ``` See for more information about installing PyTorch. @@ -247,7 +407,7 @@ git clone https://github.com/metaopt/torchopt.git cd torchopt # You may need `CONDA_OVERRIDE_CUDA` if conda fails to detect the NVIDIA driver (e.g. in docker or WSL2) -CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe.yaml +CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe-minimal.yaml conda activate torchopt make install-editable # or run `pip3 install --no-build-isolation --editable .` @@ -255,36 +415,29 @@ make install-editable # or run `pip3 install --no-build-isolation --editable .` -------------------------------------------------------------------------------- -## Future Plan - -- [x] CPU-accelerated optimizer -- [ ] Support general implicit differentiation with functional programing -- [X] Support more optimizers such as AdamW, RMSProp -- [ ] Zero order optimization -- [ ] Distributed optimizers -- [ ] Support `complex` data type - ## Changelog See [CHANGELOG.md](CHANGELOG.md). -------------------------------------------------------------------------------- -## The Team - -TorchOpt is a work by Jie Ren, Xidong Feng, [Bo Liu](https://github.com/Benjamin-eecs), [Xuehai Pan](https://github.com/XuehaiPan), [Luo Mai](https://luomai.github.io/) and [Yaodong Yang](https://www.yangyaodong.com/). - ## Citing TorchOpt If you find TorchOpt useful, please cite it in your publications. ```bibtex -@software{TorchOpt, - author = {Jie Ren and Xidong Feng and Bo Liu and Xuehai Pan and Luo Mai and Yaodong Yang}, - title = {TorchOpt}, - year = {2022}, - publisher = {GitHub}, - journal = {GitHub repository}, - howpublished = {\url{https://github.com/metaopt/torchopt}}, +@article{torchopt, + title = {TorchOpt: An Efficient Library for Differentiable Optimization}, + author = {Ren, Jie and Feng, Xidong and Liu, Bo and Pan, Xuehai and Fu, Yao and Mai, Luo and Yang, Yaodong}, + journal = {arXiv preprint arXiv:2211.06934}, + year = {2022} } ``` + +## The Team + +TorchOpt is a work by [Jie Ren](https://github.com/JieRen98), [Xidong Feng](https://github.com/waterhorse1), [Bo Liu](https://github.com/Benjamin-eecs), [Xuehai Pan](https://github.com/XuehaiPan), [Luo Mai](https://luomai.github.io), and [Yaodong Yang](https://www.yangyaodong.com). + +## License + +TorchOpt is released under the Apache License, Version 2.0. diff --git a/conda-recipe-minimal.yaml b/conda-recipe-minimal.yaml new file mode 100644 index 00000000..4ae91303 --- /dev/null +++ b/conda-recipe-minimal.yaml @@ -0,0 +1,56 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Create virtual environment with command: +# +# $ CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe-minimal.yaml +# + +name: torchopt + +channels: + - pytorch + - nvidia/label/cuda-11.7.1 + - defaults + - conda-forge + +dependencies: + - python = 3.9 + - pip + + # Learning + - pytorch::pytorch >= 1.13 # sync with project.dependencies + - pytorch::torchvision + - pytorch::pytorch-mutex = *=*cuda* + - pip: + - torchviz + + # Device select + - nvidia/label/cuda-11.7.1::cuda-toolkit = 11.7 + + # Build toolchain + - cmake >= 3.11 + - make + - cxx-compiler + - gxx = 10 + - nvidia/label/cuda-11.7.1::cuda-nvcc + - nvidia/label/cuda-11.7.1::cuda-cudart-dev + - pybind11 >= 2.10.1 + + # Misc + - optree >= 0.4.1 + - typing-extensions >= 4.0.0 + - numpy + - python-graphviz diff --git a/conda-recipe.yaml b/conda-recipe.yaml index 19229136..9eacbfaa 100644 --- a/conda-recipe.yaml +++ b/conda-recipe.yaml @@ -1,3 +1,18 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# # Create virtual environment with command: # # $ CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe.yaml @@ -7,78 +22,79 @@ name: torchopt channels: - pytorch + - nvidia/label/cuda-11.7.1 - defaults - - nvidia/label/cuda-11.6.2 - - nvidia - conda-forge dependencies: - - python = 3.8 + - python = 3.9 - pip # Learning - - pytorch::pytorch >= 1.12 + - pytorch::pytorch >= 1.13 # sync with project.dependencies - pytorch::torchvision - pytorch::pytorch-mutex = *=*cuda* - pip: - - functorch >= 0.2 - torchviz - sphinxcontrib-katex # for documentation - jax # for tutorials - jaxlib >= 0.3=*cuda* # for tutorials - optax # for tutorials + - jaxopt # for tests - tensorboard # for examples - - wandb # Device select - - nvidia::cudatoolkit = 11.6 - - cudnn + - nvidia/label/cuda-11.7.1::cuda-toolkit = 11.7 # Build toolchain - - cmake >= 3.4 + - cmake >= 3.11 - make - cxx-compiler - gxx = 10 - - nvidia/label/cuda-11.6.2::cuda-nvcc - - nvidia/label/cuda-11.6.2::cuda-cudart-dev - - patchelf >= 0.9 - - pybind11 + - nvidia/label/cuda-11.7.1::cuda-nvcc + - nvidia/label/cuda-11.7.1::cuda-cudart-dev + - patchelf >= 0.14 + - pybind11 >= 2.10.1 # Misc - - typing-extensions + - optree >= 0.4.1 + - typing-extensions >= 4.0.0 - numpy - matplotlib-base - seaborn - python-graphviz - pillow + - setproctitle # Documentation - - sphinx + - sphinx >= 5.2.1 - sphinx_rtd_theme - sphinx-autobuild - sphinx-copybutton - sphinxcontrib-spelling - sphinxcontrib-bibtex - - sphinx-autodoc-typehints + - sphinx-autodoc-typehints >= 1.19.2 - pyenchant + - hunspell-en - myst-nb - ipykernel - pandoc - - docutils = 0.16 + - docutils # Testing - pytest - pytest-cov - pytest-xdist - isort - - conda-forge::black >= 22.6.0 - - pylint - - mypy + - conda-forge::black-jupyter >= 22.6.0 + - pylint >= 2.15.0 + - mypy >= 0.990 + - types-setuptools - flake8 - flake8-bugbear - doc8 < 1.0.0a0 - pydocstyle - - clang-format + - clang-format >= 14 - clang-tools # clang-tidy - cpplint - pre-commit diff --git a/docs/conda-recipe.yaml b/docs/conda-recipe.yaml index 7ba50adb..a26b613b 100644 --- a/docs/conda-recipe.yaml +++ b/docs/conda-recipe.yaml @@ -22,34 +22,34 @@ name: torchopt-docs channels: - pytorch + - nvidia/label/cuda-11.7.1 - defaults - conda-forge dependencies: - - python = 3.8 + - python = 3.9 - pip # Learning - - pytorch::pytorch >= 1.12 + - pytorch::pytorch >= 1.13 # sync with project.dependencies + - pytorch::cpuonly - pytorch::pytorch-mutex = *=*cpu* - pip: - - functorch >= 0.2 - torchviz - sphinxcontrib-katex # for documentation - - tensorboard - - wandb # Build toolchain - - cmake >= 3.4 + - cmake >= 3.11 - make - cxx-compiler - gxx = 10 - - nvidia/label/cuda-11.6.2::cuda-nvcc - - nvidia/label/cuda-11.6.2::cuda-cudart-dev - - pybind11 + - nvidia/label/cuda-11.7.1::cuda-nvcc + - nvidia/label/cuda-11.7.1::cuda-cudart-dev + - pybind11 >= 2.10.1 # Misc - - typing-extensions + - optree >= 0.4.1 + - typing-extensions >= 4.0.0 - numpy - matplotlib-base - seaborn @@ -57,15 +57,16 @@ dependencies: - pillow # Documentation - - sphinx + - sphinx >= 5.2.1 - sphinx_rtd_theme - sphinx-autobuild - sphinx-copybutton - sphinxcontrib-spelling - sphinxcontrib-bibtex - - sphinx-autodoc-typehints + - sphinx-autodoc-typehints >= 1.19.2 - pyenchant + - hunspell-en - myst-nb - ipykernel - pandoc - - docutils = 0.16 + - docutils diff --git a/docs/requirements.txt b/docs/requirements.txt index cdfc5b18..9ac98898 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,20 +1,20 @@ --extra-index-url https://download.pytorch.org/whl/cpu -torch >= 1.12 -functorch >= 0.2 +# Sync with project.dependencies +torch >= 1.13 --requirement ../requirements.txt -sphinx >= 5.0 +sphinx >= 5.2.1 sphinx-autoapi sphinx-autobuild sphinx-copybutton sphinx-rtd-theme sphinxcontrib-katex sphinxcontrib-bibtex -sphinx-autodoc-typehints +sphinx-autodoc-typehints >= 1.19.2 IPython ipykernel pandoc myst_nb -docutils == 0.16 +docutils matplotlib diff --git a/docs/source/api/api.rst b/docs/source/api/api.rst index 545a8d54..27d16a64 100644 --- a/docs/source/api/api.rst +++ b/docs/source/api/api.rst @@ -29,11 +29,18 @@ Functional Optimizers .. autosummary:: + FuncOptimizer adam sgd rmsprop adamw +Wrapper for Function Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: FuncOptimizer + :members: + Functional Adam Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -124,40 +131,92 @@ Differentiable Meta-RMSProp Optimizer ------ +Implicit differentiation +======================== + +.. currentmodule:: torchopt.diff.implicit + +.. autosummary:: + + custom_root + nn.ImplicitMetaGradientModule + +Custom solvers +~~~~~~~~~~~~~~ + +.. autofunction:: custom_root + + +Implicit Meta-Gradient Module +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. currentmodule:: torchopt.diff.implicit.nn + +.. autoclass:: ImplicitMetaGradientModule + :members: + +------ + +Linear system solvers +===================== + +.. currentmodule:: torchopt.linear_solve + +.. autosummary:: + + solve_cg + solve_normal_cg + solve_inv + +Indirect solvers +~~~~~~~~~~~~~~~~ + +.. autofunction:: solve_cg +.. autofunction:: solve_normal_cg +.. autofunction:: solve_inv + +------ + Optimizer Hooks =============== -.. currentmodule:: torchopt._src.hook +.. currentmodule:: torchopt.hook .. autosummary:: register_hook zero_nan_hook + nan_to_num_hook Hook ~~~~ .. autofunction:: register_hook .. autofunction:: zero_nan_hook +.. autofunction:: nan_to_num_hook + +------ Gradient Transformation ======================= -.. currentmodule:: torchopt._src.clip +.. currentmodule:: torchopt .. autosummary:: clip_grad_norm + nan_to_num Transforms ~~~~~~~~~~ .. autofunction:: clip_grad_norm +.. autofunction:: nan_to_num Optimizer Schedules =================== -.. currentmodule:: torchopt._src.schedule +.. currentmodule:: torchopt.schedule .. autosummary:: @@ -188,7 +247,7 @@ Apply Updates Combining Optimizers ==================== -.. currentmodule:: torchopt._src.combine +.. currentmodule:: torchopt.combine .. autosummary:: @@ -230,7 +289,7 @@ Stop Gradient Visualizing Gradient Flow ========================= -.. currentmodule:: torchopt._src.visual +.. currentmodule:: torchopt.visual .. autosummary:: diff --git a/docs/source/conf.py b/docs/source/conf.py index 694086fe..96736ebb 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -25,6 +25,7 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # +import logging import os import pathlib import sys @@ -43,6 +44,24 @@ def get_version() -> str: return version.__version__ +try: + import sphinx_autodoc_typehints +except ImportError: + pass +else: + + class RecursiveForwardRefFilter(logging.Filter): + def filter(self, record): + if ( + "name 'TensorTree' is not defined" in record.getMessage() + or "name 'OptionalTensorTree' is not defined" in record.getMessage() + ): + return False + return super().filter(record) + + sphinx_autodoc_typehints._LOGGER.logger.addFilter(RecursiveForwardRefFilter()) + + # -- Project information ----------------------------------------------------- project = 'TorchOpt' @@ -75,7 +94,7 @@ def get_version() -> str: 'sphinxcontrib.bibtex', 'sphinxcontrib.katex', 'sphinx_autodoc_typehints', - 'myst_nb', # This is used for the .ipynb notebooks + 'myst_nb', # this is used for the .ipynb notebooks ] if not os.getenv('READTHEDOCS', None): @@ -120,6 +139,7 @@ def get_version() -> str: 'exclude-members': '__module__, __dict__, __repr__, __str__, __weakref__', } autoclass_content = 'both' +simplify_optional_unions = False # -- Options for bibtex ----------------------------------------------------- @@ -134,7 +154,7 @@ def get_version() -> str: # See: https://sphinxcontrib-katex.readthedocs.io/en/0.4.1/macros.html latex_macros = r""" - \def \d #1{\operatorname{#1}} + \def \d #1{\operatorname{#1}} """ # Translate LaTeX macros to KaTeX and add to options for HTML builder diff --git a/docs/source/developer/contributing.rst b/docs/source/developer/contributing.rst index 93d0cc50..b4c4c825 100644 --- a/docs/source/developer/contributing.rst +++ b/docs/source/developer/contributing.rst @@ -43,7 +43,7 @@ in the main directory. This installation is removable by: .. code-block:: bash - pip3 uninstall torchopt + make uninstall Lint Check @@ -91,8 +91,8 @@ To build compatible **manylinux2014** (:pep:`599`) wheels for distribution, you pip3 install --upgrade cibuildwheel - export TEST_TORCH_SPECS="cpu cu113 cu116" # `torch` builds for testing - export CUDA_VERSION="11.6" # version of `nvcc` for compilation + export TEST_TORCH_SPECS="cpu cu116" # `torch` builds for testing + export CUDA_VERSION="11.7" # version of `nvcc` for compilation python3 -m cibuildwheel --platform=linux --output-dir=wheelhouse --config-file=pyproject.toml It will install the CUDA compiler with ``CUDA_VERSION`` in the build container. Then build wheel binaries for all supported CPython versions. The outputs will be placed in the ``wheelhouse`` directory. diff --git a/docs/source/examples/MAML.rst b/docs/source/examples/MAML.rst index bba6c35a..ee5a638c 100644 --- a/docs/source/examples/MAML.rst +++ b/docs/source/examples/MAML.rst @@ -1,7 +1,7 @@ Model-Agnostic Meta-Learning ============================ -Meta reinforcement learning has achieved significant successes in various applications. +Meta-reinforcement learning has achieved significant successes in various applications. **Model-Agnostic Meta-Learning** (MAML) :cite:`MAML` is the pioneer one. In this tutorial, we will show how to train MAML on few-shot Omniglot classification with TorchOpt step by step. The full script is at :gitcode:`examples/few-shot/maml_omniglot.py`. @@ -63,16 +63,17 @@ TorchOpt supports any user-defined PyTorch networks. Here is an example: net = nn.Sequential( nn.Conv2d(1, 64, 3), - nn.BatchNorm2d(64, momentum=1., affine=True), + nn.BatchNorm2d(64, momentum=1.0, affine=True), nn.ReLU(inplace=False), nn.MaxPool2d(2, 2), nn.Conv2d(64, 64, 3), - nn.BatchNorm2d(64, momentum=1., affine=True), + nn.BatchNorm2d(64, momentum=1.0, affine=True), nn.ReLU(inplace=False), nn.MaxPool2d(2, 2), nn.Conv2d(64, 64, 3), - nn.BatchNorm2d(64, momentum=1., affine=True), - nn.ReLU(inplace=False), nn.MaxPool2d(2, 2), + nn.BatchNorm2d(64, momentum=1.0, affine=True), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), nn.Flatten(), nn.Linear(64, args.n_way), ).to(device) @@ -98,8 +99,7 @@ Define the ``train`` function: # Sample a batch of support and query images and labels. x_spt, y_spt, x_qry, y_qry = db.next() - task_num, setsz, c_, h, w = x_spt.size() - querysz = x_qry.size(1) + task_num = x_spt.size(0) # TODO: Maybe pull this out into a separate module so it # doesn't have to be duplicated between `train` and `test`? @@ -128,28 +128,24 @@ Define the ``train`` function: # These will be used to update the model's meta-parameters. qry_logits = net(x_qry[i]) qry_loss = F.cross_entropy(qry_logits, y_qry[i]) - qry_losses.append(qry_loss.detach()) - qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz - qry_accs.append(qry_acc) - - # Update the model's meta-parameters to optimize the query - # losses across all of the tasks sampled in this batch. - # This unrolls through the gradient steps. - qry_loss.backward() + qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean() + qry_losses.append(qry_loss) + qry_accs.append(qry_acc.item()) torchopt.recover_state_dict(net, net_state_dict) torchopt.recover_state_dict(inner_opt, optim_state_dict) + qry_losses = torch.mean(torch.stack(qry_losses)) + qry_losses.backward() meta_opt.step() - qry_losses = sum(qry_losses) / task_num - qry_accs = 100. * sum(qry_accs) / task_num + qry_losses = qry_losses.item() + qry_accs = 100.0 * np.mean(qry_accs) i = epoch + float(batch_idx) / n_train_iter iter_time = time.time() - start_time print( f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}' ) - log.append( { 'epoch': i, @@ -183,8 +179,7 @@ Define the ``test`` function: for batch_idx in range(n_test_iter): x_spt, y_spt, x_qry, y_qry = db.next('test') - task_num, setsz, c_, h, w = x_spt.size() - querysz = x_qry.size(1) + task_num = x_spt.size(0) # TODO: Maybe pull this out into a separate module so it # doesn't have to be duplicated between `train` and `test`? @@ -203,15 +198,17 @@ Define the ``test`` function: # The query loss and acc induced by these parameters. qry_logits = net(x_qry[i]).detach() - qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none') - qry_losses.append(qry_loss.detach()) - qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach()) + qry_loss = F.cross_entropy(qry_logits, y_qry[i]) + qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean() + qry_losses.append(qry_loss.item()) + qry_accs.append(qry_acc.item()) torchopt.recover_state_dict(net, net_state_dict) torchopt.recover_state_dict(inner_opt, optim_state_dict) - qry_losses = torch.cat(qry_losses).mean().item() - qry_accs = 100. * torch.cat(qry_accs).float().mean().item() + qry_losses = np.mean(qry_losses) + qry_accs = 100.0 * np.mean(qry_accs) + print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}') log.append( { diff --git a/docs/source/index.rst b/docs/source/index.rst index fd488b6e..a4c20e22 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -3,7 +3,7 @@ TorchOpt -------- -**TorchOpt** is a high-performance optimizer library built upon `PyTorch `_ for easy implementation of functional optimization and gradient-based meta-learning. It consists of two main features: +**TorchOpt** is a high-performance optimizer library built upon `PyTorch `_ for easy implementation of functional optimization and gradient-based meta-learning. It consists of two main features: * TorchOpt provides functional optimizer which enables `JAX-like `_ composable functional optimizer for PyTorch. With TorchOpt, one can easily conduct neural network optimization in PyTorch with functional style optimizer, similar to `Optax `_ in JAX. * With the design of functional programming, TorchOpt provides efficient, flexible, and easy-to-implement differentiable optimizer for gradient-based meta-learning research. It largely reduces the efforts required to implement sophisticated meta-learning algorithms. @@ -13,8 +13,8 @@ Installation Requirements: -* `PyTorch `_ -* (Optional) `Graphviz `_ +* `PyTorch `_ +* (Optional) `Graphviz `_ Please follow the instructions at https://pytorch.org to install PyTorch in your Python environment first. Then run the following command to install TorchOpt from PyPI: @@ -38,37 +38,37 @@ We provide a `conda `_ environment recipe to ins cd torchopt # You may need `CONDA_OVERRIDE_CUDA` if conda fails to detect the NVIDIA driver (e.g. in docker or WSL2) - CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe.yaml + CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe-minimal.yaml conda activate torchopt .. toctree:: - :caption: Getting Started - :maxdepth: 1 + :caption: Getting Started + :maxdepth: 1 - torchopt101/torchopt-101.rst + torchopt101/torchopt-101.rst .. toctree:: - :caption: Examples - :maxdepth: 1 + :caption: Examples + :maxdepth: 1 - examples/MAML.rst + examples/MAML.rst .. toctree:: - :caption: Developer Documentation - :maxdepth: 1 + :caption: Developer Documentation + :maxdepth: 1 - developer/contributing.rst - developer/contributor.rst + developer/contributing.rst + developer/contributor.rst .. toctree:: - :caption: API Documentation - :maxdepth: 2 + :caption: API Documentation + :maxdepth: 2 - api/api.rst + api/api.rst The Team -------- @@ -97,3 +97,23 @@ License ------- TorchOpt is licensed under the Apache 2.0 License. + +Citing +------ + +If you find TorchOpt useful, please cite it in your publications. + +.. code-block:: bibtex + + @article{torchopt, + title = {TorchOpt: An Efficient Library for Differentiable Optimization}, + author = {Ren, Jie and Feng, Xidong and Liu, Bo and Pan, Xuehai and Fu, Yao and Mai, Luo and Yang, Yaodong}, + journal = {arXiv preprint arXiv:2211.06934}, + year = {2022} + } + + +Indices and tables +================== + +* :ref:`genindex` diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt index ca34dd05..e76966ef 100644 --- a/docs/source/spelling_wordlist.txt +++ b/docs/source/spelling_wordlist.txt @@ -26,7 +26,7 @@ Pan Yao Fu Jupyter -Colaboratory +Colab Omniglot differentiable Dataset @@ -56,10 +56,12 @@ iterable nan param Graphviz +Autograd autograd attrs GradientTransformations args +kwargs chainable adam Adam @@ -78,3 +80,67 @@ Loshchilov pytree booleans subtrees +optimality +argnums +matvec +hermitian +deepcopy +deepclone +RRef +rref +ints +Karush +Kuhn +Tucker +Neumann +KKT +num +posinf +neginf +backpropagated +backpropagating +backpropagation +backprop +fmt +pragma +broadcasted +keepdim +ndim +partitioner +RPC +maxiter +str +bool +algo +const +attr +sys +recurse +boldsymbol +optim +optimizer's +stateful +preload +submodules +prepend +jit +compilable +RMS +LLC +ns +th +treespec +namespace +atol +rtol +pre +numerics +parallelize +parallelizing +Optax +func +subfn +vjp +jvp +ATen +samplable diff --git a/docs/source/torchopt101/torchopt-101.rst b/docs/source/torchopt101/torchopt-101.rst index 87bffd4c..89809691 100644 --- a/docs/source/torchopt101/torchopt-101.rst +++ b/docs/source/torchopt101/torchopt-101.rst @@ -1,9 +1,11 @@ Get Started with Jupyter Notebook ================================= -In this tutorial, we will use Google Colaboratory to show you the most basic usages of TorchOpt. +In this tutorial, we will use Google Colab notebooks to show you the most basic usages of TorchOpt. -- 1: `Functional Optimizer `_ -- 2: `Visualization `_ -- 3: `Meta Optimizer `_ -- 4: `Stop Gradient `_ +- 1: `Functional Optimizer `_ +- 2: `Visualization `_ +- 3: `Meta-Optimizer `_ +- 4: `Stop Gradient `_ +- 5: `Implicit Differentiation `_ +- 6: `Zero-order Differentiation `_ diff --git a/examples/FuncTorch/maml_omniglot_vmap.py b/examples/FuncTorch/maml_omniglot_vmap.py index 9bbb30ce..41c17db8 100644 --- a/examples/FuncTorch/maml_omniglot_vmap.py +++ b/examples/FuncTorch/maml_omniglot_vmap.py @@ -39,16 +39,10 @@ https://github.com/bamos/HowToTrainYourMAMLPytorch """ - -import os -import sys - - -cur = os.path.abspath(os.path.dirname(__file__)) -root = os.path.split(cur)[0] -sys.path.append(root + '/few-shot') import argparse import functools +import pathlib +import sys import time import functorch @@ -59,12 +53,17 @@ import torch import torch.nn.functional as F import torch.optim as optim -from support.omniglot_loaders import OmniglotNShot from torch import nn import torchopt +CWD = pathlib(__file__).absolute().parent +sys.path.append(str(CWD.parent / 'few-shot')) + +from helpers.omniglot_loaders import OmniglotNShot + + mpl.use('Agg') plt.style.use('bmh') @@ -148,8 +147,6 @@ def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry): opt = torchopt.sgd(lr=1e-1) opt_state = opt.init(params) - querysz = x_qry.size(0) - def compute_loss(new_params, buffers, x, y): logits = fnet(new_params, buffers, x) loss = F.cross_entropy(logits, y) @@ -167,7 +164,7 @@ def compute_loss(new_params, buffers, x, y): # These will be used to update the model's meta-parameters. qry_logits = fnet(new_params, buffers, x_qry) qry_loss = F.cross_entropy(qry_logits, y_qry) - qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum() / querysz + qry_acc = (qry_logits.argmax(dim=1) == y_qry).float().mean() return qry_loss, qry_acc @@ -192,18 +189,19 @@ def train(db, net, device, meta_opt, epoch, log): qry_losses, qry_accs = functorch.vmap(compute_loss_for_task)(x_spt, y_spt, x_qry, y_qry) # Compute the maml loss by summing together the returned losses. - qry_losses.sum().backward() - + qry_losses = torch.mean(torch.stack(qry_losses)) + qry_losses.backward() meta_opt.step() - qry_losses = qry_losses.detach().sum() / task_num - qry_accs = 100.0 * qry_accs.sum() / task_num + qry_losses = qry_losses.item() + qry_accs = 100.0 * torch.mean(torch.stack(qry_accs)).item() i = epoch + float(batch_idx) / n_train_iter iter_time = time.time() - start_time + torch.cuda.empty_cache() + if batch_idx % 4 == 0: print( f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}' ) - log.append( { 'epoch': i, @@ -249,8 +247,10 @@ def test(db, net, device, epoch, log): qry_losses.append(qry_loss.detach()) qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach()) - qry_losses = torch.cat(qry_losses).mean().item() - qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item() + qry_losses = torch.mean(torch.stack(qry_losses)).item() + qry_accs = 100.0 * torch.mean(torch.stack(qry_accs)).item() + torch.cuda.empty_cache() + print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}') log.append( { diff --git a/examples/L2R/helper/argument.py b/examples/L2R/helpers/argument.py similarity index 100% rename from examples/L2R/helper/argument.py rename to examples/L2R/helpers/argument.py diff --git a/examples/L2R/helper/model.py b/examples/L2R/helpers/model.py similarity index 100% rename from examples/L2R/helper/model.py rename to examples/L2R/helpers/model.py diff --git a/examples/L2R/helper/utils.py b/examples/L2R/helpers/utils.py similarity index 100% rename from examples/L2R/helper/utils.py rename to examples/L2R/helpers/utils.py diff --git a/examples/L2R/l2r.py b/examples/L2R/l2r.py index cd093313..e77faa14 100644 --- a/examples/L2R/l2r.py +++ b/examples/L2R/l2r.py @@ -39,9 +39,9 @@ # isort: off -from helper.argument import parse_args -from helper.model import LeNet5 -from helper.utils import get_imbalance_dataset, plot, set_seed +from helpers.argument import parse_args +from helpers.model import LeNet5 +from helpers.utils import get_imbalance_dataset, plot, set_seed def run_baseline(args, mnist_train, mnist_test): @@ -74,7 +74,7 @@ def run_baseline(args, mnist_train, mnist_test): test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=True, num_workers=1) model = LeNet5(args).to(args.device) - model_optimiser = torch.optim.Adam(model.parameters(), lr=args.lr) + model_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) step = 0 running_train_loss = [] @@ -85,9 +85,9 @@ def run_baseline(args, mnist_train, mnist_test): train_x, train_label = train_x.to(args.device), train_label.to(args.device) outer_loss = model.outer_loss(train_x, train_label) - model_optimiser.zero_grad() + model_optimizer.zero_grad() outer_loss.backward() - model_optimiser.step() + model_optimizer.step() running_train_loss.append(outer_loss.item()) writer.add_scalar('train_loss', outer_loss.item(), step) @@ -142,8 +142,8 @@ def run_L2R(args, mnist_train, mnist_test): valid_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=True, num_workers=1) test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=True, num_workers=1) model = LeNet5(args).to(args.device) - model_optimiser = torchopt.MetaSGD(model, lr=args.lr) - real_model_optimiser = torch.optim.Adam(model.parameters(), lr=args.lr) + model_optimizer = torchopt.MetaSGD(model, lr=args.lr) + real_model_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) step = 0 time_bp = 0 @@ -170,11 +170,11 @@ def run_L2R(args, mnist_train, mnist_test): model.reset_meta(size=train_x.size(0)) net_state_dict = torchopt.extract_state_dict(model) - optim_state_dict = torchopt.extract_state_dict(model_optimiser) + optim_state_dict = torchopt.extract_state_dict(model_optimizer) for _ in range(1): inner_loss = model.inner_loss(train_x, train_label) - model_optimiser.step(inner_loss) + model_optimizer.step(inner_loss) # caclulate outer_loss, deirve meta-gradient and normalise outer_loss = model.outer_loss(valid_x, valid_label) @@ -186,17 +186,17 @@ def run_L2R(args, mnist_train, mnist_test): running_valid_loss.append(outer_loss.item()) writer.add_scalar('validation_loss', outer_loss.item(), step) - # reset the model and model optimiser + # reset the model and model optimizer torchopt.recover_state_dict(model, net_state_dict) - torchopt.recover_state_dict(model_optimiser, optim_state_dict) + torchopt.recover_state_dict(model_optimizer, optim_state_dict) # reuse inner_adapt to conduct real update based on learned meta weights inner_loss = model.inner_loss(train_x, train_label) for _ in range(1): inner_loss = model.inner_loss(train_x, train_label) - real_model_optimiser.zero_grad() + real_model_optimizer.zero_grad() inner_loss.backward() - real_model_optimiser.step() + real_model_optimizer.step() running_train_loss.append(inner_loss.item()) writer.add_scalar('weighted_train_loss', inner_loss.item(), step) diff --git a/examples/LOLA/helper/agent.py b/examples/LOLA/helpers/agent.py similarity index 96% rename from examples/LOLA/helper/agent.py rename to examples/LOLA/helpers/agent.py index 8b30a983..3b37daf2 100644 --- a/examples/LOLA/helper/agent.py +++ b/examples/LOLA/helpers/agent.py @@ -44,7 +44,7 @@ def __init__(self, args): def set_virtual(self): self.virtual_theta = theta_model(self.theta) - self.virtual_optimiser = torchopt.MetaSGD(self.virtual_theta, lr=self.args.lr_in) + self.virtual_optimizer = torchopt.MetaSGD(self.virtual_theta, lr=self.args.lr_in) def value_update(self, loss): self.value_optimizer.zero_grad() diff --git a/examples/LOLA/helper/argument.py b/examples/LOLA/helpers/argument.py similarity index 100% rename from examples/LOLA/helper/argument.py rename to examples/LOLA/helpers/argument.py diff --git a/examples/LOLA/helper/env.py b/examples/LOLA/helpers/env.py similarity index 100% rename from examples/LOLA/helper/env.py rename to examples/LOLA/helpers/env.py diff --git a/examples/LOLA/helper/utils.py b/examples/LOLA/helpers/utils.py similarity index 100% rename from examples/LOLA/helper/utils.py rename to examples/LOLA/helpers/utils.py diff --git a/examples/LOLA/lola_dice.py b/examples/LOLA/lola_dice.py index 61d2e22c..4b6b2567 100644 --- a/examples/LOLA/lola_dice.py +++ b/examples/LOLA/lola_dice.py @@ -21,10 +21,10 @@ # isort: off -from helper.agent import Agent -from helper.argument import parse_args -from helper.env import IPD -from helper.utils import sample, step +from helpers.agent import Agent +from helpers.argument import parse_args +from helpers.env import IPD +from helpers.utils import sample, step def main(args): @@ -49,7 +49,7 @@ def main(args): args, ) inner_loss = memory1.dice_objective(use_baseline=args.use_baseline) - agent1.virtual_optimiser.step(inner_loss) + agent1.virtual_optimizer.step(inner_loss) # agent 1 assumes that agent 2 conducts n-step lookahead for _ in range(n_lookaheads): @@ -60,7 +60,7 @@ def main(args): args, ) inner_loss = memory2.dice_objective(use_baseline=args.use_baseline) - agent2.virtual_optimiser.step(inner_loss) + agent2.virtual_optimizer.step(inner_loss) # update agent 1 memory1, memory2 = sample( diff --git a/examples/MAML-RL/func_maml.py b/examples/MAML-RL/func_maml.py new file mode 100644 index 00000000..6413cc71 --- /dev/null +++ b/examples/MAML-RL/func_maml.py @@ -0,0 +1,196 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import argparse +from typing import NamedTuple + +import functorch +import gym +import numpy as np +import torch +import torch.optim as optim + +import torchopt +from helpers.policy import CategoricalMLPPolicy + + +TASK_NUM = 40 +TRAJ_NUM = 20 +TRAJ_LEN = 10 + +STATE_DIM = 10 +ACTION_DIM = 5 + +GAMMA = 0.99 +LAMBDA = 0.95 + +outer_iters = 500 +inner_iters = 1 + + +class Traj(NamedTuple): + obs: np.ndarray + acs: np.ndarray + next_obs: np.ndarray + rews: np.ndarray + gammas: np.ndarray + + +def sample_traj(env, task, fpolicy, params): + env.reset_task(task) + obs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM, STATE_DIM), dtype=np.float32) + next_obs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM, STATE_DIM), dtype=np.float32) + acs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.int8) + rews_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.float32) + gammas_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.float32) + with torch.no_grad(): + for batch in range(TRAJ_NUM): + ob = env.reset() + for step in range(TRAJ_LEN): + ob_tensor = torch.from_numpy(ob) + pi, _ = fpolicy(params, ob_tensor) + ac_tensor = pi.sample() + ac = ac_tensor.cpu().numpy() + next_ob, rew, done, info = env.step(ac) + + obs_buf[step][batch] = ob + next_obs_buf[step][batch] = next_ob + acs_buf[step][batch] = ac + rews_buf[step][batch] = rew + gammas_buf[step][batch] = done * GAMMA + ob = next_ob + return Traj(obs=obs_buf, acs=acs_buf, next_obs=next_obs_buf, rews=rews_buf, gammas=gammas_buf) + + +def a2c_loss(traj, fpolicy, params, value_coef): + lambdas = np.ones_like(traj.gammas) * LAMBDA + _, next_values = fpolicy(params, torch.from_numpy(traj.next_obs)) + next_values = torch.squeeze(next_values, -1).detach().numpy() + # Work backwards to compute `G_{T-1}`, ..., `G_0`. + returns = [] + g = next_values[-1, :] + for i in reversed(range(next_values.shape[0])): + g = traj.rews[i, :] + traj.gammas[i, :] * ( + (1 - lambdas[i, :]) * next_values[i, :] + lambdas[i, :] * g + ) + returns.insert(0, g) + lambda_returns = torch.from_numpy(np.array(returns)) + pi, values = fpolicy(params, torch.from_numpy(traj.obs)) + log_probs = pi.log_prob(torch.from_numpy(traj.acs)) + advs = lambda_returns - torch.squeeze(values, -1) + action_loss = -(advs.detach() * log_probs).mean() + value_loss = advs.pow(2).mean() + + loss = action_loss + value_coef * value_loss + return loss + + +def evaluate(env, seed, task_num, fpolicy, params): + pre_reward_ls = [] + post_reward_ls = [] + inner_opt = torchopt.MetaSGD(lr=0.5) + env = gym.make( + 'TabularMDP-v0', + **dict( + num_states=STATE_DIM, num_actions=ACTION_DIM, max_episode_steps=TRAJ_LEN, seed=args.seed + ), + ) + tasks = env.sample_tasks(num_tasks=task_num) + + for idx in range(task_num): + for _ in range(inner_iters): + pre_trajs = sample_traj(env, tasks[idx], fpolicy, params) + + inner_loss = a2c_loss(pre_trajs, fpolicy, params, value_coef=0.5) + params = inner_opt.step(inner_loss, params) + post_trajs = sample_traj(env, tasks[idx], fpolicy, params) + + # Logging + pre_reward_ls.append(np.sum(pre_trajs.rews, axis=0).mean()) + post_reward_ls.append(np.sum(post_trajs.rews, axis=0).mean()) + + return pre_reward_ls, post_reward_ls + + +def main(args): + # init training + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + # Env + env = gym.make( + 'TabularMDP-v0', + **dict( + num_states=STATE_DIM, num_actions=ACTION_DIM, max_episode_steps=TRAJ_LEN, seed=args.seed + ), + ) + # Policy + policy = CategoricalMLPPolicy(input_size=STATE_DIM, output_size=ACTION_DIM) + fpolicy, params = functorch.make_functional(policy) + + inner_opt = torchopt.MetaSGD(lr=0.5) + outer_opt = optim.Adam(params, lr=1e-3) + train_pre_reward = [] + train_post_reward = [] + test_pre_reward = [] + test_post_reward = [] + + for i in range(outer_iters): + tasks = env.sample_tasks(num_tasks=TASK_NUM) + train_pre_reward_ls = [] + train_post_reward_ls = [] + + outer_opt.zero_grad() + + param_orig = [p.detach().clone().requires_grad_() for p in params] + _params = list(params) + for idx in range(TASK_NUM): + + for _ in range(inner_iters): + pre_trajs = sample_traj(env, tasks[idx], fpolicy, _params) + inner_loss = a2c_loss(pre_trajs, fpolicy, _params, value_coef=0.5) + _params = inner_opt.step(inner_loss, _params) + post_trajs = sample_traj(env, tasks[idx], fpolicy, _params) + outer_loss = a2c_loss(post_trajs, fpolicy, _params, value_coef=0.5) + outer_loss.backward() + _params = [p.detach().clone().requires_grad_() for p in param_orig] + + # Logging + train_pre_reward_ls.append(np.sum(pre_trajs.rews, axis=0).mean()) + train_post_reward_ls.append(np.sum(post_trajs.rews, axis=0).mean()) + outer_opt.step() + + test_pre_reward_ls, test_post_reward_ls = evaluate( + env, args.seed, TASK_NUM, fpolicy, params + ) + + train_pre_reward.append(sum(train_pre_reward_ls) / TASK_NUM) + train_post_reward.append(sum(train_post_reward_ls) / TASK_NUM) + test_pre_reward.append(sum(test_pre_reward_ls) / TASK_NUM) + test_post_reward.append(sum(test_post_reward_ls) / TASK_NUM) + + print('Train_iters', i) + print('train_pre_reward', sum(train_pre_reward_ls) / TASK_NUM) + print('train_post_reward', sum(train_post_reward_ls) / TASK_NUM) + print('test_pre_reward', sum(test_pre_reward_ls) / TASK_NUM) + print('test_post_reward', sum(test_post_reward_ls) / TASK_NUM) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Reinforcement learning with Model-Agnostic Meta-Learning (MAML) - Train' + ) + parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)') + args = parser.parse_args() + main(args) diff --git a/examples/MAML-RL/maml.py b/examples/MAML-RL/maml.py index f2bb38e9..447f540e 100644 --- a/examples/MAML-RL/maml.py +++ b/examples/MAML-RL/maml.py @@ -99,8 +99,9 @@ def a2c_loss(traj, policy, value_coef): advs = lambda_returns - torch.squeeze(values, -1) action_loss = -(advs.detach() * log_probs).mean() value_loss = advs.pow(2).mean() - a2c_loss = action_loss + value_coef * value_loss - return a2c_loss + + loss = action_loss + value_coef * value_loss + return loss def evaluate(env, seed, task_num, policy): diff --git a/examples/distributed/few-shot/README.md b/examples/distributed/few-shot/README.md new file mode 100644 index 00000000..a0a758fa --- /dev/null +++ b/examples/distributed/few-shot/README.md @@ -0,0 +1,18 @@ +# MAML few-shot Omniglot classification-examples + +Code on MAML few-shot Omniglot classification in paper [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks](https://arxiv.org/abs/1703.03400) using TorchOpt. We use `MetaSGD` as the inner-loop optimizer. + +## Usage + +```bash +### Run +torchrun --nnode 1 --nproc_per_node 8 maml_omniglot.py +``` + +## Results + +The figure illustrate the experimental result. + +
+ +
diff --git a/examples/few-shot/support/omniglot_loaders.py b/examples/distributed/few-shot/helpers/omniglot_loaders.py similarity index 100% rename from examples/few-shot/support/omniglot_loaders.py rename to examples/distributed/few-shot/helpers/omniglot_loaders.py diff --git a/examples/distributed/few-shot/maml-accs.png b/examples/distributed/few-shot/maml-accs.png new file mode 100644 index 00000000..8d70607c Binary files /dev/null and b/examples/distributed/few-shot/maml-accs.png differ diff --git a/examples/distributed/few-shot/maml_omniglot.py b/examples/distributed/few-shot/maml_omniglot.py new file mode 100644 index 00000000..879792ff --- /dev/null +++ b/examples/distributed/few-shot/maml_omniglot.py @@ -0,0 +1,315 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/facebookresearch/higher/blob/main/examples/maml-omniglot.py +# ============================================================================== +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This example shows how to use TorchOpt to do Model Agnostic Meta Learning (MAML) +for few-shot Omniglot classification. +For more details see the original MAML paper: +https://arxiv.org/abs/1703.03400 +This code has been modified from Jackie Loong's PyTorch MAML implementation: +https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py +Our MAML++ fork and experiments are available at: +https://github.com/bamos/HowToTrainYourMAMLPytorch +""" + +import argparse +import os +import random +import time + +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from setproctitle import getproctitle, setproctitle + +import torchopt +import torchopt.distributed as todist + + +from helpers.omniglot_loaders import OmniglotNShot # isort: skip + + +mpl.use('Agg') +plt.style.use('bmh') + + +def worker_init(): + world_info = todist.get_world_info() + + proctitle = f'{world_info.worker_name}: {getproctitle().strip()}' + print(f'Worker init:=> {proctitle}') + setproctitle(proctitle) + + seed = world_info.local_rank + + os.environ['PYTHONHASHSEED'] = str(seed) + + random.seed(seed) + np.random.seed(seed) + + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + if world_info.local_rank < torch.cuda.device_count(): + torch.cuda.set_device(world_info.local_rank) + + +def build_model(args, device): + return nn.Sequential( + nn.Conv2d(1, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Flatten(), + nn.Linear(64, args.n_way), + ).to(device) + + +@todist.rank_zero_only +def get_data_loader(args, device): + rng = np.random.default_rng(args.seed) + + return OmniglotNShot( + '/tmp/omniglot-data', + batchsz=args.task_num, + n_way=args.n_way, + k_shot=args.k_spt, + k_query=args.k_qry, + imgsz=28, + rng=rng, + device=device, + ) + + +@todist.auto_init_rpc(worker_init) +def main(): + argparser = argparse.ArgumentParser() + argparser.add_argument('--n_way', type=int, help='n way', default=5) + argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=5) + argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15) + argparser.add_argument( + '--task_num', type=int, help='meta batch size, namely task num', default=32 + ) + argparser.add_argument('--seed', type=int, help='random seed', default=1) + args = argparser.parse_args() + + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + np.random.seed(args.seed) + + # Set up the Omniglot loader. + db = get_data_loader(args, device=torch.device('cpu')) + + # Create a vanilla PyTorch neural network. + net = build_model(args, device=torch.device('cpu')) + + # We will use Adam to (meta-)optimize the initial parameters + # to be adapted. + meta_opt = optim.Adam(net.parameters(), lr=1e-3) + + log = [] + test(db, net, epoch=-1, log=log) + for epoch in range(10): + train(db, net, meta_opt, epoch=epoch, log=log) + test(db, net, epoch=epoch, log=log) + plot(log) + + +def transpose_mean_reducer(results): + qry_losses, qry_accs = tuple(zip(*results)) + qry_loss = torch.mean(torch.stack(qry_losses)) + qry_acc = np.mean(qry_accs) + return qry_loss, qry_acc + + +@todist.parallelize( + partitioner=todist.dim_partitioner(dim=0, exclusive=True, keepdim=False), + reducer=transpose_mean_reducer, +) +def inner_loop(net_rref, x_spt, y_spt, x_qry, y_qry, n_inner_iter): + if torch.cuda.is_available(): + device = torch.device(f'cuda:{todist.get_local_rank() % torch.cuda.device_count()}') + torch.cuda.set_device(device) + else: + device = None + + original_net = net_rref.to_here() + # The local net can be shared across multiple RPC calls on the current worker + # We need to detach the buffers to avoid sharing the same buffers across + net = torchopt.module_clone(original_net, by='reference', detach_buffers=True, device=device) + if device is not None: + x_spt = x_spt.to(device) + y_spt = y_spt.to(device) + x_qry = x_qry.to(device) + y_qry = y_qry.to(device) + + inner_opt = torchopt.MetaSGD(net, lr=1e-1) + + for _ in range(n_inner_iter): + spt_logits = net(x_spt) + spt_loss = F.cross_entropy(spt_logits, y_spt) + inner_opt.step(spt_loss) + + qry_logits = net(x_qry) + qry_loss = F.cross_entropy(qry_logits, y_qry).cpu() + qry_acc = (qry_logits.argmax(dim=1) == y_qry).float().mean().item() + + return qry_loss, qry_acc + + +@todist.rank_zero_only +def train(db: OmniglotNShot, net: nn.Module, meta_opt: optim.Adam, epoch: int, log: list): + net.train() + n_train_iter = db.x_train.shape[0] // db.batchsz + + for batch_idx in range(n_train_iter): + start_time = time.time() + # Sample a batch of support and query images and labels. + x_spt, y_spt, x_qry, y_qry = db.next() + + # TODO: Maybe pull this out into a separate module so it + # doesn't have to be duplicated between `train` and `test`? + + # Initialize the inner optimizer to adapt the parameters to + # the support set. + n_inner_iter = 5 + + meta_opt.zero_grad() + # Sending modules contains nn.Parameter will detach from the current computation graph + # Here we explicitly convert the parameters to tensors with `CloneBackward` + net_rref = todist.rpc.RRef(torchopt.module_clone(net, by='copy')) + with todist.autograd.context() as context_id: + qry_loss, qry_acc = inner_loop(net_rref, x_spt, y_spt, x_qry, y_qry, n_inner_iter) + todist.autograd.backward(context_id, qry_loss) + meta_opt.step() + + qry_loss = qry_loss.item() + qry_acc = 100.0 * qry_acc + i = epoch + float(batch_idx) / n_train_iter + iter_time = time.time() - start_time + torch.cuda.empty_cache() + + print( + f'[Epoch {i:.2f}] Train Loss: {qry_loss:.2f} | Acc: {qry_acc:.2f} | Time: {iter_time:.2f}' + ) + log.append( + { + 'epoch': i, + 'loss': qry_loss, + 'acc': qry_acc, + 'mode': 'train', + 'time': time.time(), + } + ) + + +@todist.rank_zero_only +def test(db, net, epoch, log): + # Crucially in our testing procedure here, we do *not* fine-tune + # the model during testing for simplicity. + # Most research papers using MAML for this task do an extra + # stage of fine-tuning here that should be added if you are + # adapting this code for research. + net.train() + n_test_iter = db.x_test.shape[0] // db.batchsz + + qry_losses = [] + qry_accs = [] + + net_rref = todist.rpc.RRef(net) + for _ in range(n_test_iter): + x_spt, y_spt, x_qry, y_qry = db.next('test') + + # TODO: Maybe pull this out into a separate module so it + # doesn't have to be duplicated between `train` and `test`? + n_inner_iter = 5 + + qry_loss, qry_acc = inner_loop(net_rref, x_spt, y_spt, x_qry, y_qry, n_inner_iter) + qry_losses.append(qry_loss.item()) + qry_accs.append(qry_acc) + + qry_losses = np.mean(qry_losses) + qry_accs = 100.0 * np.mean(qry_accs) + torch.cuda.empty_cache() + + print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}') + log.append( + { + 'epoch': epoch + 1, + 'loss': qry_losses, + 'acc': qry_accs, + 'mode': 'test', + 'time': time.time(), + } + ) + + +@todist.rank_zero_only +def plot(log): + # Generally you should pull your plotting code out of your training + # script but we are doing it here for brevity. + df = pd.DataFrame(log) + + fig, ax = plt.subplots(figsize=(8, 4), dpi=250) + train_df = df[df['mode'] == 'train'] + test_df = df[df['mode'] == 'test'] + ax.plot(train_df['epoch'], train_df['acc'], label='Train') + ax.plot(test_df['epoch'], test_df['acc'], label='Test') + ax.set_xlabel('Epoch') + ax.set_ylabel('Accuracy') + ax.set_ylim(85, 100) + ax.set_title('Distributed MAML Omniglot') + ax.legend(ncol=2, loc='lower right') + fig.tight_layout() + fname = 'maml-accs.png' + print(f'--- Plotting accuracy to {fname}') + fig.savefig(fname) + plt.close(fig) + + +if __name__ == '__main__': + main() diff --git a/examples/distributed/few-shot/maml_omniglot_local_loader.py b/examples/distributed/few-shot/maml_omniglot_local_loader.py new file mode 100644 index 00000000..f7f9e4f0 --- /dev/null +++ b/examples/distributed/few-shot/maml_omniglot_local_loader.py @@ -0,0 +1,359 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/facebookresearch/higher/blob/main/examples/maml-omniglot.py +# ============================================================================== +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This example shows how to use TorchOpt to do Model Agnostic Meta Learning (MAML) +for few-shot Omniglot classification. +For more details see the original MAML paper: +https://arxiv.org/abs/1703.03400 +This code has been modified from Jackie Loong's PyTorch MAML implementation: +https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py +Our MAML++ fork and experiments are available at: +https://github.com/bamos/HowToTrainYourMAMLPytorch +""" + +import argparse +import copy +import os +import random +import threading +import time + +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from setproctitle import getproctitle, setproctitle + +import torchopt +import torchopt.distributed as todist + + +from helpers.omniglot_loaders import OmniglotNShot # isort: skip + + +mpl.use('Agg') +plt.style.use('bmh') + + +LOCK = threading.Lock() +LOCAL_DATA_LOADER = None +TASK_DATA_LOADERS = {} +LOCAL_DEVICE = None + + +def worker_init(): + global LOCAL_DEVICE + + world_info = todist.get_world_info() + + proctitle = f'{world_info.worker_name}: {getproctitle().strip()}' + print(f'Worker init:=> {proctitle}') + setproctitle(proctitle) + + seed = world_info.world_rank + local_rank = world_info.local_rank + + os.environ['PYTHONHASHSEED'] = str(seed) + + random.seed(seed) + np.random.seed(seed) + + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + if world_info.local_rank < torch.cuda.device_count(): + torch.cuda.set_device(world_info.local_rank) + + if torch.cuda.is_available(): + device = torch.device(f'cuda:{local_rank % torch.cuda.device_count()}') + torch.cuda.set_device(device) + else: + device = None + LOCAL_DEVICE = device + + +def build_model(args, device): + return nn.Sequential( + nn.Conv2d(1, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Flatten(), + nn.Linear(64, args.n_way), + ).to(device) + + +def set_local_data_loader(args, device): + global LOCAL_DATA_LOADER + + if LOCAL_DATA_LOADER is None: + rng = np.random.default_rng(args.seed) + + with LOCK: + LOCAL_DATA_LOADER = OmniglotNShot( + '/tmp/omniglot-data', + batchsz=args.task_num, + n_way=args.n_way, + k_shot=args.k_spt, + k_query=args.k_qry, + imgsz=28, + rng=rng, + device=device, + ) + + return LOCAL_DATA_LOADER + + +def get_next_batch(task_id, mode): + assert LOCAL_DATA_LOADER is not None + + if task_id not in TASK_DATA_LOADERS: + with LOCK: + TASK_DATA_LOADERS[task_id] = copy.deepcopy(LOCAL_DATA_LOADER) + + db = TASK_DATA_LOADERS[task_id] + x_spt, y_spt, x_qry, y_qry = db.next(mode) + x_spt, y_spt, x_qry, y_qry = x_spt[task_id], y_spt[task_id], x_qry[task_id], y_qry[task_id] + return x_qry, y_qry, x_spt, y_spt + + +@todist.auto_init_rpc(worker_init) +def main(): + argparser = argparse.ArgumentParser() + argparser.add_argument('--n_way', type=int, help='n way', default=5) + argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=5) + argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15) + argparser.add_argument( + '--task_num', type=int, help='meta batch size, namely task num', default=32 + ) + argparser.add_argument('--seed', type=int, help='random seed', default=1) + args = argparser.parse_args() + + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + np.random.seed(args.seed) + + # Set up the Omniglot loader. + set_local_data_loader(args, device=LOCAL_DEVICE) + todist.barrier() # ensure that all workers have loaded the data + + # Create a vanilla PyTorch neural network. + net = build_model(args, device=torch.device('cpu')) + + # We will use Adam to (meta-)optimize the initial parameters + # to be adapted. + meta_opt = optim.Adam(net.parameters(), lr=1e-3) + + log = [] + test(net, epoch=-1, log=log) + for epoch in range(10): + train(net, meta_opt, epoch=epoch, log=log) + test(net, epoch=epoch, log=log) + plot(log) + + +def args_replicator(net_rref, n_inner_iter, task_id, task_num, mode): + del task_id + num_workers = todist.get_world_size() + return [ + (task_id % num_workers, (net_rref, n_inner_iter, task_id, task_num, mode), None) + for task_id in range(task_num) + ] + + +def transpose_mean_reducer(results): + qry_losses, qry_accs = tuple(zip(*results)) + qry_loss = torch.mean(torch.stack(qry_losses)) + qry_acc = np.mean(qry_accs) + return qry_loss, qry_acc + + +@todist.parallelize(partitioner=args_replicator, reducer=transpose_mean_reducer) +def inner_loop(net_rref, n_inner_iter, task_id, task_num, mode): + device = LOCAL_DEVICE + + original_net = net_rref.to_here() + # The local net can be shared across multiple RPC calls on the current worker + # We need to detach the buffers to avoid sharing the same buffers across + net = torchopt.module_clone(original_net, by='reference', detach_buffers=True, device=device) + + x_spt, y_spt, x_qry, y_qry = get_next_batch(task_id, mode) + if device is not None: + x_spt = x_spt.to(device) + y_spt = y_spt.to(device) + x_qry = x_qry.to(device) + y_qry = y_qry.to(device) + + inner_opt = torchopt.MetaSGD(net, lr=1e-1) + + for _ in range(n_inner_iter): + spt_logits = net(x_spt) + spt_loss = F.cross_entropy(spt_logits, y_spt) + inner_opt.step(spt_loss) + + qry_logits = net(x_qry) + qry_loss = F.cross_entropy(qry_logits, y_qry).cpu() + qry_acc = (qry_logits.argmax(dim=1) == y_qry).float().mean().item() + + return qry_loss, qry_acc + + +@todist.rank_zero_only +def train(net: nn.Module, meta_opt: optim.Adam, epoch: int, log: list): + net.train() + + db = LOCAL_DATA_LOADER + n_train_iter = db.x_train.shape[0] // db.batchsz + task_num = db.x_train.shape[1] + + net_rref = todist.rpc.RRef(net) + for batch_idx in range(n_train_iter): + start_time = time.time() + + # TODO: Maybe pull this out into a separate module so it + # doesn't have to be duplicated between `train` and `test`? + + # Initialize the inner optimizer to adapt the parameters to + # the support set. + n_inner_iter = 5 + + meta_opt.zero_grad() + # Sending modules contains nn.Parameter will detach from the current computation graph + # Here we explicitly convert the parameters to tensors with `CloneBackward` + net_rref = todist.rpc.RRef(torchopt.module_clone(net, by='copy')) + with todist.autograd.context() as context_id: + qry_loss, qry_acc = inner_loop(net_rref, n_inner_iter, None, task_num, 'train') + todist.autograd.backward(context_id, qry_loss) + meta_opt.step() + + qry_loss = qry_loss.item() + qry_acc = 100.0 * qry_acc + i = epoch + float(batch_idx) / n_train_iter + iter_time = time.time() - start_time + torch.cuda.empty_cache() + + print( + f'[Epoch {i:.2f}] Train Loss: {qry_loss:.2f} | Acc: {qry_acc:.2f} | Time: {iter_time:.2f}' + ) + log.append( + { + 'epoch': i, + 'loss': qry_loss, + 'acc': qry_acc, + 'mode': 'train', + 'time': time.time(), + } + ) + + +@todist.rank_zero_only +def test(net, epoch, log): + # Crucially in our testing procedure here, we do *not* fine-tune + # the model during testing for simplicity. + # Most research papers using MAML for this task do an extra + # stage of fine-tuning here that should be added if you are + # adapting this code for research. + net.train() + + db = LOCAL_DATA_LOADER + n_test_iter = db.x_test.shape[0] // db.batchsz + task_num = db.x_train.shape[1] + + qry_losses = [] + qry_accs = [] + + net_rref = todist.rpc.RRef(net) + for _ in range(n_test_iter): + # TODO: Maybe pull this out into a separate module so it + # doesn't have to be duplicated between `train` and `test`? + n_inner_iter = 5 + + qry_loss, qry_acc = inner_loop(net_rref, n_inner_iter, None, task_num, 'test') + qry_losses.append(qry_loss.item()) + qry_accs.append(qry_acc) + + qry_losses = np.mean(qry_losses) + qry_accs = 100.0 * np.mean(qry_accs) + torch.cuda.empty_cache() + + print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}') + log.append( + { + 'epoch': epoch + 1, + 'loss': qry_losses, + 'acc': qry_accs, + 'mode': 'test', + 'time': time.time(), + } + ) + + +@todist.rank_zero_only +def plot(log): + # Generally you should pull your plotting code out of your training + # script but we are doing it here for brevity. + df = pd.DataFrame(log) + + fig, ax = plt.subplots(figsize=(8, 4), dpi=250) + train_df = df[df['mode'] == 'train'] + test_df = df[df['mode'] == 'test'] + ax.plot(train_df['epoch'], train_df['acc'], label='Train') + ax.plot(test_df['epoch'], test_df['acc'], label='Test') + ax.set_xlabel('Epoch') + ax.set_ylabel('Accuracy') + ax.set_ylim(85, 100) + ax.set_title('Distributed MAML Omniglot') + ax.legend(ncol=2, loc='lower right') + fig.tight_layout() + fname = 'maml-accs.png' + print(f'--- Plotting accuracy to {fname}') + fig.savefig(fname) + plt.close(fig) + + +if __name__ == '__main__': + main() diff --git a/examples/few-shot/README.md b/examples/few-shot/README.md index d25eafc4..df6578f3 100644 --- a/examples/few-shot/README.md +++ b/examples/few-shot/README.md @@ -14,5 +14,5 @@ python3 maml_omniglot.py The figure illustrate the experimental result.
- +
diff --git a/examples/few-shot/helpers/omniglot_loaders.py b/examples/few-shot/helpers/omniglot_loaders.py new file mode 100644 index 00000000..d857d386 --- /dev/null +++ b/examples/few-shot/helpers/omniglot_loaders.py @@ -0,0 +1,327 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# These Omniglot loaders are from Jackie Loong's PyTorch MAML implementation: +# https://github.com/dragen1860/MAML-Pytorch +# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot.py +# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglotNShot.py +# ============================================================================== + +import errno +import os + +import numpy as np +import torch +import torch.utils.data as data +import torchvision.transforms as transforms +from PIL import Image + + +class Omniglot(data.Dataset): + """ + The items are ``(filename, category)``. The index of all the categories can be found in + :attr:`idx_classes`. + + Args: + root: the directory where the dataset will be stored + transform: how to transform the input + target_transform: how to transform the target + download: need to download the dataset + """ + + urls = [ + 'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip', + 'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip', + ] + raw_folder = 'raw' + processed_folder = 'processed' + training_file = 'training.pt' + test_file = 'test.pt' + + def __init__(self, root, transform=None, target_transform=None, download=False): + self.root = root + self.transform = transform + self.target_transform = target_transform + + if not self._check_exists(): + if download: + self.download() + else: + raise RuntimeError('Dataset not found. You can use download=True to download it') + + self.all_items = find_classes(os.path.join(self.root, self.processed_folder)) + self.idx_classes = index_classes(self.all_items) + + def __getitem__(self, index): + filename = self.all_items[index][0] + img = str.join('/', [self.all_items[index][2], filename]) + + target = self.idx_classes[self.all_items[index][1]] + if self.transform is not None: + img = self.transform(img) + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self): + return len(self.all_items) + + def _check_exists(self): + return os.path.exists( + os.path.join(self.root, self.processed_folder, 'images_evaluation') + ) and os.path.exists(os.path.join(self.root, self.processed_folder, 'images_background')) + + def download(self): + import zipfile + + from six.moves import urllib + + if self._check_exists(): + return + + # download files + try: + os.makedirs(os.path.join(self.root, self.raw_folder)) + os.makedirs(os.path.join(self.root, self.processed_folder)) + except OSError as e: + if e.errno == errno.EEXIST: + pass + else: + raise + + for url in self.urls: + print('== Downloading ' + url) + data = urllib.request.urlopen(url) + filename = url.rpartition('/')[2] + file_path = os.path.join(self.root, self.raw_folder, filename) + with open(file_path, 'wb') as f: + f.write(data.read()) + file_processed = os.path.join(self.root, self.processed_folder) + print('== Unzip from ' + file_path + ' to ' + file_processed) + zip_ref = zipfile.ZipFile(file_path, 'r') + zip_ref.extractall(file_processed) + zip_ref.close() + print('Download finished.') + + +def find_classes(root_dir): + retour = [] + for (root, dirs, files) in os.walk(root_dir): + for f in files: + if f.endswith('png'): + r = root.split('/') + lr = len(r) + retour.append((f, r[lr - 2] + '/' + r[lr - 1], root)) + print('== Found %d items ' % len(retour)) + return retour + + +def index_classes(items): + idx = {} + for i in items: + if i[1] not in idx: + idx[i[1]] = len(idx) + print('== Found %d classes' % len(idx)) + return idx + + +class OmniglotNShot: + def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=None): + """ + Different from mnistNShot, the + :param root: + :param batchsz: task num + :param n_way: + :param k_shot: + :param k_qry: + :param imgsz: + """ + + self.resize = imgsz + self.rng = rng + self.device = device + if not os.path.isfile(os.path.join(root, 'omniglot.npy')): + # if root/data.npy does not exist, just download it + self.x = Omniglot( + root, + download=True, + transform=transforms.Compose( + [ + lambda x: Image.open(x).convert('L'), + lambda x: x.resize((imgsz, imgsz)), + lambda x: np.reshape(x, (imgsz, imgsz, 1)), + lambda x: np.transpose(x, [2, 0, 1]), + lambda x: x / 255.0, + ] + ), + ) + + # {label: [img1, img2..., img20], label2: [img1, img2, ...], ... 1623 labels in total} + temp = {} + for (img, label) in self.x: + if label in temp.keys(): + temp[label].append(img) + else: + temp[label] = [img] + + self.x = [] + for ( + label, + imgs, + ) in temp.items(): # labels info deserted , each label contains 20imgs + self.x.append(np.array(imgs)) + + # as different class may have different number of imgs + self.x = np.array(self.x).astype(np.float) # [[20 imgs],..., 1623 classes in total] + # each character contains 20 imgs + print('data shape:', self.x.shape) # [1623, 20, 84, 84, 1] + temp = [] # Free memory + # save all dataset into npy file. + np.save(os.path.join(root, 'omniglot.npy'), self.x) + print('write into omniglot.npy.') + else: + # if data.npy exists, just load it. + self.x = np.load(os.path.join(root, 'omniglot.npy')) + print('load from omniglot.npy.') + + # [1623, 20, 84, 84, 1] + # TODO: can not shuffle here, we must keep training and test set distinct! + self.x_train, self.x_test = self.x[:1200], self.x[1200:] + + # self.normalization() + + self.batchsz = batchsz + self.n_cls = self.x.shape[0] # 1623 + self.n_way = n_way # n way + self.k_shot = k_shot # k shot + self.k_query = k_query # k query + assert (k_shot + k_query) <= 20 + + # save pointer of current read batch in total cache + self.indexes = {'train': 0, 'test': 0} + self.datasets = { + 'train': self.x_train, + 'test': self.x_test, + } # original data cached + print('DB: train', self.x_train.shape, 'test', self.x_test.shape) + + self.datasets_cache = { + 'train': self.load_data_cache(self.datasets['train']), # current epoch data cached + 'test': self.load_data_cache(self.datasets['test']), + } + + def normalization(self): + """ + Normalizes our data, to have a mean of 0 and sdt of 1 + """ + self.mean = np.mean(self.x_train) + self.std = np.std(self.x_train) + self.max = np.max(self.x_train) + self.min = np.min(self.x_train) + # print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std) + self.x_train = (self.x_train - self.mean) / self.std + self.x_test = (self.x_test - self.mean) / self.std + + self.mean = np.mean(self.x_train) + self.std = np.std(self.x_train) + self.max = np.max(self.x_train) + self.min = np.min(self.x_train) + + # print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std) + + def load_data_cache(self, data_pack): + """ + Collects several batches data for N-shot learning + :param data_pack: [cls_num, 20, 84, 84, 1] + :return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks + """ + + # take 5 way 1 shot as example: 5 * 1 + setsz = self.k_shot * self.n_way + querysz = self.k_query * self.n_way + data_cache = [] + + # print('preload next 50 caches of batchsz of batch.') + for sample in range(10): # num of episodes + + x_spts, y_spts, x_qrys, y_qrys = [], [], [], [] + for i in range(self.batchsz): # one batch means one set + + x_spt, y_spt, x_qry, y_qry = [], [], [], [] + selected_cls = self.rng.choice(data_pack.shape[0], self.n_way, False) + + for j, cur_class in enumerate(selected_cls): + + selected_img = self.rng.choice(20, self.k_shot + self.k_query, False) + + # meta-training and meta-test + x_spt.append(data_pack[cur_class][selected_img[: self.k_shot]]) + x_qry.append(data_pack[cur_class][selected_img[self.k_shot :]]) + y_spt.append([j for _ in range(self.k_shot)]) + y_qry.append([j for _ in range(self.k_query)]) + + # shuffle inside a batch + perm = self.rng.permutation(self.n_way * self.k_shot) + x_spt = np.array(x_spt).reshape( + self.n_way * self.k_shot, 1, self.resize, self.resize + )[perm] + y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm] + perm = self.rng.permutation(self.n_way * self.k_query) + x_qry = np.array(x_qry).reshape( + self.n_way * self.k_query, 1, self.resize, self.resize + )[perm] + y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm] + + # append [sptsz, 1, 84, 84] => [b, setsz, 1, 84, 84] + x_spts.append(x_spt) + y_spts.append(y_spt) + x_qrys.append(x_qry) + y_qrys.append(y_qry) + + # [b, setsz, 1, 84, 84] + x_spts = np.array(x_spts, dtype=np.float32).reshape( + self.batchsz, setsz, 1, self.resize, self.resize + ) + y_spts = np.array(y_spts, dtype=np.int).reshape(self.batchsz, setsz) + # [b, qrysz, 1, 84, 84] + x_qrys = np.array(x_qrys, dtype=np.float32).reshape( + self.batchsz, querysz, 1, self.resize, self.resize + ) + y_qrys = np.array(y_qrys, dtype=np.int).reshape(self.batchsz, querysz) + + x_spts, y_spts, x_qrys, y_qrys = [ + torch.from_numpy(z).to(self.device) for z in [x_spts, y_spts, x_qrys, y_qrys] + ] + + data_cache.append([x_spts, y_spts, x_qrys, y_qrys]) + + return data_cache + + def next(self, mode='train'): + """ + Gets next batch from the dataset with name. + :param mode: The name of the splitting (one of "train", "val", "test") + :return: + """ + + # update cache if indexes is larger cached num + if self.indexes[mode] >= len(self.datasets_cache[mode]): + self.indexes[mode] = 0 + self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode]) + + next_batch = self.datasets_cache[mode][self.indexes[mode]] + self.indexes[mode] += 1 + + return next_batch diff --git a/examples/few-shot/maml-accs.png b/examples/few-shot/maml-accs.png index a3a0f4ce..df0b37db 100644 Binary files a/examples/few-shot/maml-accs.png and b/examples/few-shot/maml-accs.png differ diff --git a/examples/few-shot/maml_omniglot.py b/examples/few-shot/maml_omniglot.py index 30b10559..879a235a 100644 --- a/examples/few-shot/maml_omniglot.py +++ b/examples/few-shot/maml_omniglot.py @@ -54,7 +54,7 @@ import torchopt -from support.omniglot_loaders import OmniglotNShot # isort: skip +from helpers.omniglot_loaders import OmniglotNShot # isort: skip mpl.use('Agg') @@ -75,11 +75,13 @@ def main(): torch.manual_seed(args.seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(args.seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True np.random.seed(args.seed) rng = np.random.default_rng(args.seed) # Set up the Omniglot loader. - device = torch.device('cuda:0') + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') db = OmniglotNShot( '/tmp/omniglot-data', batchsz=args.task_num, @@ -114,9 +116,10 @@ def main(): meta_opt = optim.Adam(net.parameters(), lr=1e-3) log = [] + test(db, net, epoch=-1, log=log) for epoch in range(10): - train(db, net, meta_opt, epoch, log) - test(db, net, epoch, log) + train(db, net, meta_opt, epoch=epoch, log=log) + test(db, net, epoch=epoch, log=log) plot(log) @@ -130,8 +133,7 @@ def train(db, net, meta_opt, epoch, log): # Sample a batch of support and query images and labels. x_spt, y_spt, x_qry, y_qry = db.next() - task_num, setsz, c_, h, w = x_spt.size() - querysz = x_qry.size(1) + task_num = x_spt.size(0) # TODO: Maybe pull this out into a separate module so it # doesn't have to be duplicated between `train` and `test`? @@ -144,8 +146,8 @@ def train(db, net, meta_opt, epoch, log): qry_accs = [] meta_opt.zero_grad() - net_state_dict = torchopt.extract_state_dict(net) - optim_state_dict = torchopt.extract_state_dict(inner_opt) + net_state_dict = torchopt.extract_state_dict(net, by='reference', detach_buffers=True) + optim_state_dict = torchopt.extract_state_dict(inner_opt, by='reference') for i in range(task_num): # Optimize the likelihood of the support set by taking # gradient steps w.r.t. the model's parameters. @@ -162,28 +164,25 @@ def train(db, net, meta_opt, epoch, log): # These will be used to update the model's meta-parameters. qry_logits = net(x_qry[i]) qry_loss = F.cross_entropy(qry_logits, y_qry[i]) - qry_losses.append(qry_loss.detach()) - qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz - qry_accs.append(qry_acc) - - # Update the model's meta-parameters to optimize the query - # losses across all of the tasks sampled in this batch. - # This unrolls through the gradient steps. - qry_loss.backward() + qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean() + qry_losses.append(qry_loss) + qry_accs.append(qry_acc.item()) torchopt.recover_state_dict(net, net_state_dict) torchopt.recover_state_dict(inner_opt, optim_state_dict) + qry_losses = torch.mean(torch.stack(qry_losses)) + qry_losses.backward() meta_opt.step() - qry_losses = sum(qry_losses) / task_num - qry_accs = 100.0 * sum(qry_accs) / task_num + qry_losses = qry_losses.item() + qry_accs = 100.0 * np.mean(qry_accs) i = epoch + float(batch_idx) / n_train_iter iter_time = time.time() - start_time + torch.cuda.empty_cache() print( f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}' ) - log.append( { 'epoch': i, @@ -211,15 +210,14 @@ def test(db, net, epoch, log): for batch_idx in range(n_test_iter): x_spt, y_spt, x_qry, y_qry = db.next('test') - task_num, setsz, c_, h, w = x_spt.size() - querysz = x_qry.size(1) + task_num = x_spt.size(0) # TODO: Maybe pull this out into a separate module so it # doesn't have to be duplicated between `train` and `test`? n_inner_iter = 5 - net_state_dict = torchopt.extract_state_dict(net) - optim_state_dict = torchopt.extract_state_dict(inner_opt) + net_state_dict = torchopt.extract_state_dict(net, by='reference', detach_buffers=True) + optim_state_dict = torchopt.extract_state_dict(inner_opt, by='reference') for i in range(task_num): # Optimize the likelihood of the support set by taking # gradient steps w.r.t. the model's parameters. @@ -231,15 +229,18 @@ def test(db, net, epoch, log): # The query loss and acc induced by these parameters. qry_logits = net(x_qry[i]).detach() - qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none') - qry_losses.append(qry_loss.detach()) - qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach()) + qry_loss = F.cross_entropy(qry_logits, y_qry[i]) + qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean() + qry_losses.append(qry_loss.item()) + qry_accs.append(qry_acc.item()) torchopt.recover_state_dict(net, net_state_dict) torchopt.recover_state_dict(inner_opt, optim_state_dict) - qry_losses = torch.cat(qry_losses).mean().item() - qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item() + qry_losses = np.mean(qry_losses) + qry_accs = 100.0 * np.mean(qry_accs) + torch.cuda.empty_cache() + print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}') log.append( { @@ -257,15 +258,16 @@ def plot(log): # script but we are doing it here for brevity. df = pd.DataFrame(log) - fig, ax = plt.subplots(figsize=(6, 4)) + fig, ax = plt.subplots(figsize=(8, 4), dpi=250) train_df = df[df['mode'] == 'train'] test_df = df[df['mode'] == 'test'] ax.plot(train_df['epoch'], train_df['acc'], label='Train') ax.plot(test_df['epoch'], test_df['acc'], label='Test') ax.set_xlabel('Epoch') ax.set_ylabel('Accuracy') - ax.set_ylim(70, 100) - fig.legend(ncol=2, loc='lower right') + ax.set_ylim(85, 100) + ax.set_title('MAML Omniglot') + ax.legend(ncol=2, loc='lower right') fig.tight_layout() fname = 'maml-accs.png' print(f'--- Plotting accuracy to {fname}') diff --git a/examples/iMAML/README.md b/examples/iMAML/README.md new file mode 100644 index 00000000..6208bc81 --- /dev/null +++ b/examples/iMAML/README.md @@ -0,0 +1,23 @@ +# implicit MAML few-shot Omniglot classification-examples + +Code on implicit MAML few-shot Omniglot classification in paper [Meta-Learning with Implicit Gradients](https://arxiv.org/abs/1909.04630) using TorchOpt. We use `torchopt.sgd` as the inner-loop optimizer. + +## Usage + +```bash +### Run +python3 imaml_omniglot.py --inner_steps 5 # use OOP APIs +python3 imaml_omniglot_functional.py --inner_steps 5 # use functional APIs +``` + +## Results + +The figure illustrate the experimental result. + +
+ +
+ +
+ +
diff --git a/examples/iMAML/helpers/omniglot_loaders.py b/examples/iMAML/helpers/omniglot_loaders.py new file mode 100644 index 00000000..d857d386 --- /dev/null +++ b/examples/iMAML/helpers/omniglot_loaders.py @@ -0,0 +1,327 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# These Omniglot loaders are from Jackie Loong's PyTorch MAML implementation: +# https://github.com/dragen1860/MAML-Pytorch +# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot.py +# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglotNShot.py +# ============================================================================== + +import errno +import os + +import numpy as np +import torch +import torch.utils.data as data +import torchvision.transforms as transforms +from PIL import Image + + +class Omniglot(data.Dataset): + """ + The items are ``(filename, category)``. The index of all the categories can be found in + :attr:`idx_classes`. + + Args: + root: the directory where the dataset will be stored + transform: how to transform the input + target_transform: how to transform the target + download: need to download the dataset + """ + + urls = [ + 'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip', + 'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip', + ] + raw_folder = 'raw' + processed_folder = 'processed' + training_file = 'training.pt' + test_file = 'test.pt' + + def __init__(self, root, transform=None, target_transform=None, download=False): + self.root = root + self.transform = transform + self.target_transform = target_transform + + if not self._check_exists(): + if download: + self.download() + else: + raise RuntimeError('Dataset not found. You can use download=True to download it') + + self.all_items = find_classes(os.path.join(self.root, self.processed_folder)) + self.idx_classes = index_classes(self.all_items) + + def __getitem__(self, index): + filename = self.all_items[index][0] + img = str.join('/', [self.all_items[index][2], filename]) + + target = self.idx_classes[self.all_items[index][1]] + if self.transform is not None: + img = self.transform(img) + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self): + return len(self.all_items) + + def _check_exists(self): + return os.path.exists( + os.path.join(self.root, self.processed_folder, 'images_evaluation') + ) and os.path.exists(os.path.join(self.root, self.processed_folder, 'images_background')) + + def download(self): + import zipfile + + from six.moves import urllib + + if self._check_exists(): + return + + # download files + try: + os.makedirs(os.path.join(self.root, self.raw_folder)) + os.makedirs(os.path.join(self.root, self.processed_folder)) + except OSError as e: + if e.errno == errno.EEXIST: + pass + else: + raise + + for url in self.urls: + print('== Downloading ' + url) + data = urllib.request.urlopen(url) + filename = url.rpartition('/')[2] + file_path = os.path.join(self.root, self.raw_folder, filename) + with open(file_path, 'wb') as f: + f.write(data.read()) + file_processed = os.path.join(self.root, self.processed_folder) + print('== Unzip from ' + file_path + ' to ' + file_processed) + zip_ref = zipfile.ZipFile(file_path, 'r') + zip_ref.extractall(file_processed) + zip_ref.close() + print('Download finished.') + + +def find_classes(root_dir): + retour = [] + for (root, dirs, files) in os.walk(root_dir): + for f in files: + if f.endswith('png'): + r = root.split('/') + lr = len(r) + retour.append((f, r[lr - 2] + '/' + r[lr - 1], root)) + print('== Found %d items ' % len(retour)) + return retour + + +def index_classes(items): + idx = {} + for i in items: + if i[1] not in idx: + idx[i[1]] = len(idx) + print('== Found %d classes' % len(idx)) + return idx + + +class OmniglotNShot: + def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=None): + """ + Different from mnistNShot, the + :param root: + :param batchsz: task num + :param n_way: + :param k_shot: + :param k_qry: + :param imgsz: + """ + + self.resize = imgsz + self.rng = rng + self.device = device + if not os.path.isfile(os.path.join(root, 'omniglot.npy')): + # if root/data.npy does not exist, just download it + self.x = Omniglot( + root, + download=True, + transform=transforms.Compose( + [ + lambda x: Image.open(x).convert('L'), + lambda x: x.resize((imgsz, imgsz)), + lambda x: np.reshape(x, (imgsz, imgsz, 1)), + lambda x: np.transpose(x, [2, 0, 1]), + lambda x: x / 255.0, + ] + ), + ) + + # {label: [img1, img2..., img20], label2: [img1, img2, ...], ... 1623 labels in total} + temp = {} + for (img, label) in self.x: + if label in temp.keys(): + temp[label].append(img) + else: + temp[label] = [img] + + self.x = [] + for ( + label, + imgs, + ) in temp.items(): # labels info deserted , each label contains 20imgs + self.x.append(np.array(imgs)) + + # as different class may have different number of imgs + self.x = np.array(self.x).astype(np.float) # [[20 imgs],..., 1623 classes in total] + # each character contains 20 imgs + print('data shape:', self.x.shape) # [1623, 20, 84, 84, 1] + temp = [] # Free memory + # save all dataset into npy file. + np.save(os.path.join(root, 'omniglot.npy'), self.x) + print('write into omniglot.npy.') + else: + # if data.npy exists, just load it. + self.x = np.load(os.path.join(root, 'omniglot.npy')) + print('load from omniglot.npy.') + + # [1623, 20, 84, 84, 1] + # TODO: can not shuffle here, we must keep training and test set distinct! + self.x_train, self.x_test = self.x[:1200], self.x[1200:] + + # self.normalization() + + self.batchsz = batchsz + self.n_cls = self.x.shape[0] # 1623 + self.n_way = n_way # n way + self.k_shot = k_shot # k shot + self.k_query = k_query # k query + assert (k_shot + k_query) <= 20 + + # save pointer of current read batch in total cache + self.indexes = {'train': 0, 'test': 0} + self.datasets = { + 'train': self.x_train, + 'test': self.x_test, + } # original data cached + print('DB: train', self.x_train.shape, 'test', self.x_test.shape) + + self.datasets_cache = { + 'train': self.load_data_cache(self.datasets['train']), # current epoch data cached + 'test': self.load_data_cache(self.datasets['test']), + } + + def normalization(self): + """ + Normalizes our data, to have a mean of 0 and sdt of 1 + """ + self.mean = np.mean(self.x_train) + self.std = np.std(self.x_train) + self.max = np.max(self.x_train) + self.min = np.min(self.x_train) + # print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std) + self.x_train = (self.x_train - self.mean) / self.std + self.x_test = (self.x_test - self.mean) / self.std + + self.mean = np.mean(self.x_train) + self.std = np.std(self.x_train) + self.max = np.max(self.x_train) + self.min = np.min(self.x_train) + + # print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std) + + def load_data_cache(self, data_pack): + """ + Collects several batches data for N-shot learning + :param data_pack: [cls_num, 20, 84, 84, 1] + :return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks + """ + + # take 5 way 1 shot as example: 5 * 1 + setsz = self.k_shot * self.n_way + querysz = self.k_query * self.n_way + data_cache = [] + + # print('preload next 50 caches of batchsz of batch.') + for sample in range(10): # num of episodes + + x_spts, y_spts, x_qrys, y_qrys = [], [], [], [] + for i in range(self.batchsz): # one batch means one set + + x_spt, y_spt, x_qry, y_qry = [], [], [], [] + selected_cls = self.rng.choice(data_pack.shape[0], self.n_way, False) + + for j, cur_class in enumerate(selected_cls): + + selected_img = self.rng.choice(20, self.k_shot + self.k_query, False) + + # meta-training and meta-test + x_spt.append(data_pack[cur_class][selected_img[: self.k_shot]]) + x_qry.append(data_pack[cur_class][selected_img[self.k_shot :]]) + y_spt.append([j for _ in range(self.k_shot)]) + y_qry.append([j for _ in range(self.k_query)]) + + # shuffle inside a batch + perm = self.rng.permutation(self.n_way * self.k_shot) + x_spt = np.array(x_spt).reshape( + self.n_way * self.k_shot, 1, self.resize, self.resize + )[perm] + y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm] + perm = self.rng.permutation(self.n_way * self.k_query) + x_qry = np.array(x_qry).reshape( + self.n_way * self.k_query, 1, self.resize, self.resize + )[perm] + y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm] + + # append [sptsz, 1, 84, 84] => [b, setsz, 1, 84, 84] + x_spts.append(x_spt) + y_spts.append(y_spt) + x_qrys.append(x_qry) + y_qrys.append(y_qry) + + # [b, setsz, 1, 84, 84] + x_spts = np.array(x_spts, dtype=np.float32).reshape( + self.batchsz, setsz, 1, self.resize, self.resize + ) + y_spts = np.array(y_spts, dtype=np.int).reshape(self.batchsz, setsz) + # [b, qrysz, 1, 84, 84] + x_qrys = np.array(x_qrys, dtype=np.float32).reshape( + self.batchsz, querysz, 1, self.resize, self.resize + ) + y_qrys = np.array(y_qrys, dtype=np.int).reshape(self.batchsz, querysz) + + x_spts, y_spts, x_qrys, y_qrys = [ + torch.from_numpy(z).to(self.device) for z in [x_spts, y_spts, x_qrys, y_qrys] + ] + + data_cache.append([x_spts, y_spts, x_qrys, y_qrys]) + + return data_cache + + def next(self, mode='train'): + """ + Gets next batch from the dataset with name. + :param mode: The name of the splitting (one of "train", "val", "test") + :return: + """ + + # update cache if indexes is larger cached num + if self.indexes[mode] >= len(self.datasets_cache[mode]): + self.indexes[mode] = 0 + self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode]) + + next_batch = self.datasets_cache[mode][self.indexes[mode]] + self.indexes[mode] += 1 + + return next_batch diff --git a/examples/iMAML/imaml-accs-functional.png b/examples/iMAML/imaml-accs-functional.png new file mode 100644 index 00000000..34922bc0 Binary files /dev/null and b/examples/iMAML/imaml-accs-functional.png differ diff --git a/examples/iMAML/imaml-accs.png b/examples/iMAML/imaml-accs.png new file mode 100644 index 00000000..1a6a5636 Binary files /dev/null and b/examples/iMAML/imaml-accs.png differ diff --git a/examples/iMAML/imaml_omniglot.py b/examples/iMAML/imaml_omniglot.py new file mode 100644 index 00000000..2b0c9738 --- /dev/null +++ b/examples/iMAML/imaml_omniglot.py @@ -0,0 +1,285 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +This example shows how to use TorchOpt to do iMAML-GD (see [1] for more details) +for few-shot Omniglot classification. + +[1] Rajeswaran, A., Finn, C., Kakade, S. M., & Levine, S. (2019). + Meta-learning with implicit gradients. In Advances in Neural Information Processing Systems (pp. 113-124). + https://arxiv.org/abs/1909.04630 +""" + +import argparse +import time + +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F + +import torchopt +from torchopt.diff.implicit import ImplicitMetaGradientModule + + +from helpers.omniglot_loaders import OmniglotNShot # isort: skip + + +mpl.use('Agg') +plt.style.use('bmh') + + +class InnerNet( + ImplicitMetaGradientModule, + linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0), +): + def __init__(self, meta_net, n_inner_iter, reg_param): + super().__init__() + self.meta_net = meta_net + self.net = torchopt.module_clone(meta_net, by='deepcopy', detach_buffers=True) + self.n_inner_iter = n_inner_iter + self.reg_param = reg_param + + def forward(self, x): + return self.net(x) + + def objective(self, x, y): + y_pred = self(x) + loss = F.cross_entropy(y_pred, y) + regularization_loss = 0 + for p1, p2 in zip(self.parameters(), self.meta_parameters()): + regularization_loss += 0.5 * self.reg_param * torch.sum(torch.square(p1 - p2)) + return loss + regularization_loss + + def solve(self, x, y): + params = tuple(self.parameters()) + inner_optim = torchopt.SGD(params, lr=1e-1) + with torch.enable_grad(): + # Temporarily enable gradient computation for conducting the optimization + for _ in range(self.n_inner_iter): + loss = self.objective(x, y) + inner_optim.zero_grad() + loss.backward(inputs=params) + inner_optim.step() + return self + + +def main(): + argparser = argparse.ArgumentParser() + argparser.add_argument('--n_way', type=int, help='n way', default=5) + argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=5) + argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=5) + argparser.add_argument('--inner_steps', type=int, help='number of inner steps', default=5) + argparser.add_argument( + '--reg_params', type=float, help='regularization parameters', default=2.0 + ) + argparser.add_argument( + '--task_num', type=int, help='meta batch size, namely task num', default=16 + ) + argparser.add_argument('--seed', type=int, help='random seed', default=1) + args = argparser.parse_args() + + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + np.random.seed(args.seed) + rng = np.random.default_rng(args.seed) + + # Set up the Omniglot loader. + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + db = OmniglotNShot( + '/tmp/omniglot-data', + batchsz=args.task_num, + n_way=args.n_way, + k_shot=args.k_spt, + k_query=args.k_qry, + imgsz=28, + rng=rng, + device=device, + ) + + # Create a vanilla PyTorch neural network. + net = nn.Sequential( + nn.Conv2d(1, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True, track_running_stats=False), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True, track_running_stats=False), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True, track_running_stats=False), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Flatten(), + nn.Linear(64, args.n_way), + ).to(device) + + # We will use Adam to (meta-)optimize the initial parameters + # to be adapted. + net.train() + meta_opt = torchopt.Adam(net.parameters(), lr=1e-3) + + log = [] + test(db, net, epoch=-1, log=log, args=args) + for epoch in range(10): + train(db, net, meta_opt, epoch, log, args) + test(db, net, epoch, log, args) + plot(log) + + +def train(db, net, meta_opt, epoch, log, args): + n_train_iter = db.x_train.shape[0] // db.batchsz + # Given this module we've created, rip out the parameters and buffers + # and return a functional version of the module. `fnet` is stateless + # and can be called with `fnet(params, buffers, args, kwargs)` + # fnet, params, buffers = functorch.make_functional_with_buffers(net) + + for batch_idx in range(n_train_iter): + start_time = time.time() + # Sample a batch of support and query images and labels. + x_spt, y_spt, x_qry, y_qry = db.next() + + task_num = x_spt.size(0) + + n_inner_iter = args.inner_steps + reg_param = args.reg_params + + qry_losses = [] + qry_accs = [] + meta_opt.zero_grad() + + for i in range(task_num): + # Optimize the likelihood of the support set by taking + # gradient steps w.r.t. the model's parameters. + # This adapts the model's meta-parameters to the task. + + inner_net = InnerNet(net, n_inner_iter, reg_param) + optimal_inner_net = inner_net.solve(x_spt[i], y_spt[i]) + + # The final set of adapted parameters will induce some + # final loss and accuracy on the query dataset. + # These will be used to update the model's meta-parameters. + qry_logits = optimal_inner_net(x_qry[i]) + qry_loss = F.cross_entropy(qry_logits, y_qry[i]) + qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean() + qry_losses.append(qry_loss) + qry_accs.append(qry_acc.item()) + + qry_losses = torch.mean(torch.stack(qry_losses)) + qry_losses.backward() + meta_opt.step() + qry_losses = qry_losses.item() + qry_accs = 100.0 * np.mean(qry_accs) + i = epoch + float(batch_idx) / n_train_iter + iter_time = time.time() - start_time + torch.cuda.empty_cache() + + print( + f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}' + ) + log.append( + { + 'epoch': i, + 'loss': qry_losses, + 'acc': qry_accs, + 'mode': 'train', + 'time': time.time(), + } + ) + + +def test(db, net, epoch, log, args): + # Crucially in our testing procedure here, we do *not* fine-tune + # the model during testing for simplicity. + # Most research papers using MAML for this task do an extra + # stage of fine-tuning here that should be added if you are + # adapting this code for research. + n_test_iter = db.x_test.shape[0] // db.batchsz + + qry_losses = [] + qry_accs = [] + + # TODO: Maybe pull this out into a separate module so it + # doesn't have to be duplicated between `train` and `test`? + n_inner_iter = args.inner_steps + reg_param = args.reg_params + + for batch_idx in range(n_test_iter): + x_spt, y_spt, x_qry, y_qry = db.next('test') + + task_num = x_spt.size(0) + + for i in range(task_num): + # Optimize the likelihood of the support set by taking + # gradient steps w.r.t. the model's parameters. + # This adapts the model's meta-parameters to the task. + + inner_net = InnerNet(net, n_inner_iter, reg_param) + with torch.no_grad(): + optimal_inner_net = inner_net.solve(x_spt[i], y_spt[i]) + + # The query loss and acc induced by these parameters. + qry_logits = optimal_inner_net(x_qry[i]) + qry_loss = F.cross_entropy(qry_logits, y_qry[i]) + qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean() + qry_losses.append(qry_loss.item()) + qry_accs.append(qry_acc.item()) + + qry_losses = np.mean(qry_losses) + qry_accs = 100.0 * np.mean(qry_accs) + torch.cuda.empty_cache() + + print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}') + log.append( + { + 'epoch': epoch + 1, + 'loss': qry_losses, + 'acc': qry_accs, + 'mode': 'test', + 'time': time.time(), + } + ) + + +def plot(log): + # Generally you should pull your plotting code out of your training + # script but we are doing it here for brevity. + df = pd.DataFrame(log) + + fig, ax = plt.subplots(figsize=(8, 4), dpi=250) + train_df = df[df['mode'] == 'train'] + test_df = df[df['mode'] == 'test'] + ax.plot(train_df['epoch'], train_df['acc'], label='Train') + ax.plot(test_df['epoch'], test_df['acc'], label='Test') + ax.set_xlabel('Epoch') + ax.set_ylabel('Accuracy') + ax.set_ylim(80, 100) + ax.set_title('iMAML Omniglot') + ax.legend(ncol=2, loc='lower right') + fig.tight_layout() + fname = 'imaml-accs.png' + print(f'--- Plotting accuracy to {fname}') + fig.savefig(fname) + plt.close(fig) + + +if __name__ == '__main__': + main() diff --git a/examples/iMAML/imaml_omniglot_functional.py b/examples/iMAML/imaml_omniglot_functional.py new file mode 100644 index 00000000..88314366 --- /dev/null +++ b/examples/iMAML/imaml_omniglot_functional.py @@ -0,0 +1,334 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +This example shows how to use TorchOpt to do iMAML-GD (see [1] for more details) +for few-shot Omniglot classification. + +[1] Rajeswaran, A., Finn, C., Kakade, S. M., & Levine, S. (2019). + Meta-learning with implicit gradients. In Advances in Neural Information Processing Systems (pp. 113-124). + https://arxiv.org/abs/1909.04630 +""" + +import argparse +import time + +import functorch +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F + +import torchopt +from torchopt import pytree + + +from helpers.omniglot_loaders import OmniglotNShot # isort: skip + + +mpl.use('Agg') +plt.style.use('bmh') + + +def main(): + argparser = argparse.ArgumentParser() + argparser.add_argument('--n_way', type=int, help='n way', default=5) + argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=5) + argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=5) + argparser.add_argument('--inner_steps', type=int, help='number of inner steps', default=5) + argparser.add_argument( + '--reg_params', type=float, help='regularization parameters', default=2.0 + ) + argparser.add_argument( + '--task_num', type=int, help='meta batch size, namely task num', default=16 + ) + argparser.add_argument('--seed', type=int, help='random seed', default=1) + args = argparser.parse_args() + + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + np.random.seed(args.seed) + rng = np.random.default_rng(args.seed) + + # Set up the Omniglot loader. + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + db = OmniglotNShot( + '/tmp/omniglot-data', + batchsz=args.task_num, + n_way=args.n_way, + k_shot=args.k_spt, + k_query=args.k_qry, + imgsz=28, + rng=rng, + device=device, + ) + + # Create a vanilla PyTorch neural network. + net = nn.Sequential( + nn.Conv2d(1, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True, track_running_stats=False), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True, track_running_stats=False), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, momentum=1.0, affine=True, track_running_stats=False), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2), + nn.Flatten(), + nn.Linear(64, args.n_way), + ).to(device) + + # We will use Adam to (meta-)optimize the initial parameters + # to be adapted. + net.train() + fnet, meta_params = model = functorch.make_functional(net) + meta_opt = torchopt.adam(lr=1e-3) + meta_opt_state = meta_opt.init(meta_params) + + log = [] + test(db, model, epoch=-1, log=log, args=args) + for epoch in range(10): + meta_opt, meta_opt_state = train(db, model, (meta_opt, meta_opt_state), epoch, log, args) + test(db, model, epoch, log, args) + plot(log) + + +def train(db, model, meta_opt_and_state, epoch, log, args): + n_train_iter = db.x_train.shape[0] // db.batchsz + fnet, meta_params = model + meta_opt, meta_opt_state = meta_opt_and_state + # Given this module we've created, rip out the parameters and buffers + # and return a functional version of the module. `fnet` is stateless + # and can be called with `fnet(params, buffers, args, kwargs)` + # fnet, params, buffers = functorch.make_functional_with_buffers(net) + + for batch_idx in range(n_train_iter): + start_time = time.time() + # Sample a batch of support and query images and labels. + x_spt, y_spt, x_qry, y_qry = db.next() + + task_num = x_spt.size(0) + + n_inner_iter = args.inner_steps + reg_param = args.reg_params + + qry_losses = [] + qry_accs = [] + + for i in range(task_num): + # Optimize the likelihood of the support set by taking + # gradient steps w.r.t. the model's parameters. + # This adapts the model's meta-parameters to the task. + + init_params = pytree.tree_map( + lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), + meta_params, + ) + optimal_params = train_imaml_inner_solver( + init_params, + meta_params, + (x_spt[i], y_spt[i]), + (fnet, n_inner_iter, reg_param), + ) + # The final set of adapted parameters will induce some + # final loss and accuracy on the query dataset. + # These will be used to update the model's meta-parameters. + qry_logits = fnet(optimal_params, x_qry[i]) + qry_loss = F.cross_entropy(qry_logits, y_qry[i]) + qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean() + qry_losses.append(qry_loss) + qry_accs.append(qry_acc.item()) + + qry_losses = torch.mean(torch.stack(qry_losses)) + meta_grads = torch.autograd.grad(qry_losses, meta_params) + meta_updates, meta_opt_state = meta_opt.update(meta_grads, meta_opt_state) + meta_params = torchopt.apply_updates(meta_params, meta_updates) + qry_losses = qry_losses.item() + qry_accs = 100.0 * np.mean(qry_accs) + i = epoch + float(batch_idx) / n_train_iter + iter_time = time.time() - start_time + torch.cuda.empty_cache() + + print( + f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}' + ) + log.append( + { + 'epoch': i, + 'loss': qry_losses, + 'acc': qry_accs, + 'mode': 'train', + 'time': time.time(), + } + ) + + return (meta_opt, meta_opt_state) + + +def test(db, model, epoch, log, args): + # Crucially in our testing procedure here, we do *not* fine-tune + # the model during testing for simplicity. + # Most research papers using MAML for this task do an extra + # stage of fine-tuning here that should be added if you are + # adapting this code for research. + fnet, meta_params = model + n_test_iter = db.x_test.shape[0] // db.batchsz + + n_inner_iter = args.inner_steps + reg_param = args.reg_params + qry_losses = [] + qry_accs = [] + + for batch_idx in range(n_test_iter): + x_spt, y_spt, x_qry, y_qry = db.next('test') + + task_num = x_spt.size(0) + + for i in range(task_num): + # Optimize the likelihood of the support set by taking + # gradient steps w.r.t. the model's parameters. + # This adapts the model's meta-parameters to the task. + + init_params = pytree.tree_map( + lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), + meta_params, + ) + optimal_params = test_imaml_inner_solver( + init_params, + meta_params, + (x_spt[i], y_spt[i]), + (fnet, n_inner_iter, reg_param), + ) + + # The query loss and acc induced by these parameters. + qry_logits = fnet(optimal_params, x_qry[i]) + qry_loss = F.cross_entropy(qry_logits, y_qry[i]) + qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean() + qry_losses.append(qry_loss.item()) + qry_accs.append(qry_acc.item()) + + qry_losses = np.mean(qry_losses) + qry_accs = 100.0 * np.mean(qry_accs) + torch.cuda.empty_cache() + + print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}') + log.append( + { + 'epoch': epoch + 1, + 'loss': qry_losses, + 'acc': qry_accs, + 'mode': 'test', + 'time': time.time(), + } + ) + + +def imaml_objective(params, meta_params, data, aux): + x_spt, y_spt = data + fnet, n_inner_iter, reg_param = aux + y_pred = fnet(params, x_spt) + regularization_loss = 0 + for p1, p2 in zip(params, meta_params): + regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2)) + loss = F.cross_entropy(y_pred, y_spt) + regularization_loss + return loss + + +@torchopt.diff.implicit.custom_root( + functorch.grad(imaml_objective, argnums=0), + argnums=1, + has_aux=False, + solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0), +) +def train_imaml_inner_solver(params, meta_params, data, aux): + x_spt, y_spt = data + fnet, n_inner_iter, reg_param = aux + # Initial functional optimizer based on TorchOpt + inner_opt = torchopt.sgd(lr=1e-1) + inner_opt_state = inner_opt.init(params) + with torch.enable_grad(): + # Temporarily enable gradient computation for conducting the optimization + for _ in range(n_inner_iter): + pred = fnet(params, x_spt) + loss = F.cross_entropy(pred, y_spt) # compute loss + # Compute regularization loss + regularization_loss = 0 + for p1, p2 in zip(params, meta_params): + regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2)) + final_loss = loss + regularization_loss + grads = torch.autograd.grad(final_loss, params) # compute gradients + updates, inner_opt_state = inner_opt.update( + grads, inner_opt_state, inplace=True + ) # get updates + params = torchopt.apply_updates(params, updates, inplace=True) + return params + + +def test_imaml_inner_solver(params, meta_params, data, aux): + x_spt, y_spt = data + fnet, n_inner_iter, reg_param = aux + # Initial functional optimizer based on TorchOpt + inner_opt = torchopt.sgd(lr=1e-1) + inner_opt_state = inner_opt.init(params) + with torch.enable_grad(): + # Temporarily enable gradient computation for conducting the optimization + for _ in range(n_inner_iter): + pred = fnet(params, x_spt) + loss = F.cross_entropy(pred, y_spt) # compute loss + # Compute regularization loss + regularization_loss = 0 + for p1, p2 in zip(params, meta_params): + regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2)) + final_loss = loss + regularization_loss + grads = torch.autograd.grad(final_loss, params) # compute gradients + updates, inner_opt_state = inner_opt.update( + grads, inner_opt_state, inplace=True + ) # get updates + params = torchopt.apply_updates(params, updates, inplace=True) + return params + + +def plot(log): + # Generally you should pull your plotting code out of your training + # script but we are doing it here for brevity. + df = pd.DataFrame(log) + + fig, ax = plt.subplots(figsize=(8, 4), dpi=250) + train_df = df[df['mode'] == 'train'] + test_df = df[df['mode'] == 'test'] + ax.plot(train_df['epoch'], train_df['acc'], label='Train') + ax.plot(test_df['epoch'], test_df['acc'], label='Test') + ax.set_xlabel('Epoch') + ax.set_ylabel('Accuracy') + ax.set_ylim(80, 100) + ax.set_title('iMAML Omniglot (Functional)') + ax.legend(ncol=2, loc='lower right') + fig.tight_layout() + fname = 'imaml-accs-functional.png' + print(f'--- Plotting accuracy to {fname}') + fig.savefig(fname) + plt.close(fig) + + +if __name__ == '__main__': + main() diff --git a/examples/requirements.txt b/examples/requirements.txt index 66636aad..76bed365 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -1,7 +1,6 @@ ---extra-index-url https://download.pytorch.org/whl/cu116 -torch >= 1.12 +--extra-index-url https://download.pytorch.org/whl/cu117 +torch >= 1.13 torchvision -functorch >= 0.2 --requirement ../requirements.txt @@ -12,3 +11,4 @@ seaborn torchviz torchrl pillow +setproctitle diff --git a/image/TorchOpt.png b/image/TorchOpt.png deleted file mode 100644 index 04a90032..00000000 Binary files a/image/TorchOpt.png and /dev/null differ diff --git a/image/diffmode.png b/image/diffmode.png new file mode 100644 index 00000000..e33df7a9 Binary files /dev/null and b/image/diffmode.png differ diff --git a/image/time.png b/image/time.png deleted file mode 100644 index 6d246d2c..00000000 Binary files a/image/time.png and /dev/null differ diff --git a/image/torchviz_torchopt.jpg b/image/torchviz-vs-torchopt.jpg similarity index 100% rename from image/torchviz_torchopt.jpg rename to image/torchviz-vs-torchopt.jpg diff --git a/pyproject.toml b/pyproject.toml index 47af443f..f3e917af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,13 +1,16 @@ # Package ###################################################################### [build-system] -requires = ["setuptools", "torch >= 1.12", "numpy", "pybind11"] +# Sync with project.dependencies +requires = ["setuptools", "torch >= 1.13", "numpy", "pybind11 >= 2.10.1"] build-backend = "setuptools.build_meta" [project] name = "torchopt" -description = "A Jax-style optimizer for PyTorch." +description = "An efficient library for differentiable optimization for PyTorch." readme = "README.md" +# Change this if wheels for `torch` is available +# Search "requires-python" and update all corresponding items requires-python = ">= 3.7" authors = [ { name = "TorchOpt Contributors" }, @@ -29,12 +32,16 @@ keywords = [ classifiers = [ "Development Status :: 4 - Beta", "License :: OSI Approved :: Apache Software License", + # Sync with requires-python "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Operating System :: Microsoft :: Windows", "Operating System :: POSIX :: Linux", + "Operating System :: MacOS", "Environment :: GPU", "Environment :: GPU :: NVIDIA CUDA", "Intended Audience :: Developers", @@ -44,11 +51,12 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dependencies = [ - "torch >= 1.12", - "optree", + # See also build-system.requires and project.requires-python + "torch >= 1.13", + "optree >= 0.4.1", "numpy", "graphviz", - "typing-extensions", + "typing-extensions >= 4.0.0", ] dynamic = ["version"] @@ -61,9 +69,10 @@ Documentation = "https://torchopt.readthedocs.io" [project.optional-dependencies] lint = [ "isort", - "black >= 22.6.0", - "pylint", - "mypy", + "black[jupyter] >= 22.6.0", + "pylint[spelling] >= 2.15.0", + "mypy >= 0.990", + "types-setuptools", "flake8", "flake8-bugbear", "doc8 < 1.0.0a0", @@ -73,10 +82,12 @@ lint = [ "pre-commit", ] test = [ - 'functorch >= 0.2', 'pytest', 'pytest-cov', 'pytest-xdist', + 'jax[cpu] >= 0.3', + 'jaxopt', + 'optax', ] [tool.setuptools.packages.find] @@ -85,19 +96,21 @@ include = ["torchopt", "torchopt.*"] # Wheel builder ################################################################ # Reference: https://cibuildwheel.readthedocs.io [tool.cibuildwheel] -archs = ["x86_64"] +archs = ["auto64"] build = "*manylinux*" -skip = "pp*" +skip = "pp* *musllinux*" build-frontend = "pip" build-verbosity = 3 environment.USE_FP16 = "ON" environment.CUDACXX = "/usr/local/cuda/bin/nvcc" environment.TORCH_CUDA_ARCH_LIST = "Common" -environment.DEFAULT_CUDA_VERSION = "11.6" -environment.DEFAULT_TEST_TORCH_SPECS = "cpu cu113 cu116" +environment.DEFAULT_CUDA_VERSION = "11.7" +environment.DEFAULT_TEST_TORCH_SPECS = "cpu cu116" environment-pass = ["CUDA_VERSION", "TEST_TORCH_SPECS"] container-engine = "docker" +test-extras = ["test"] +[tool.cibuildwheel.linux] before-all = """ CUDA_VERSION="${CUDA_VERSION:-"${DEFAULT_CUDA_VERSION}"}" if [[ "${CUDA_VERSION}" == "None" || "${CUDA_VERSION}" == "none" ]]; then @@ -111,32 +124,8 @@ before-all = """ yum install -y nvidia-driver-latest-libs "cuda-minimal-build-${CUDA_PKG_SUFFIX}" fi echo "cat torchopt/version.py"; cat torchopt/version.py - """ -test-extras = ["test"] -test-command = """ - SITE_PACKAGES="$(python -c 'print(__import__("sysconfig").get_path("purelib"))')" - TORCH_LIB_PATH="${SITE_PACKAGES}/torch/lib" - echo "LD_LIBRARY_PATH='${LD_LIBRARY_PATH}'" - echo "ls ${TORCH_LIB_PATH}"; ls -lh "${TORCH_LIB_PATH}" - find "${SITE_PACKAGES}/torchopt" -name "*.so" -print0 | - xargs -0 -I '{}' bash -c "echo 'ldd {}'; ldd '{}'; echo 'patchelf --print-rpath {}'; patchelf --print-rpath '{}'" - make -C "{project}" test || exit 1 - TORCH_VERSION="$(python -c 'print(__import__("torch").__version__.partition("+")[0])')" - TEST_TORCH_SPECS="${TEST_TORCH_SPECS:-"${DEFAULT_TEST_TORCH_SPECS}"}" - for spec in ${TEST_TORCH_SPECS}; do - python -m pip uninstall -y torch - export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/${spec}" - echo "PIP_EXTRA_INDEX_URL='${PIP_EXTRA_INDEX_URL}'" - python -m pip install "torch==${TORCH_VERSION}" - echo "ls ${TORCH_LIB_PATH}"; ls -lh "${TORCH_LIB_PATH}" - find "${SITE_PACKAGES}/torchopt" -name "*.so" -print0 | - xargs -0 -I '{}' bash -c "echo 'ldd {}'; ldd '{}'; echo 'patchelf --print-rpath {}'; patchelf --print-rpath '{}'" - make -C "{project}" test || exit 1 - done - rm -rf ~/.pip/cache ~/.cache/pip - """ - -[tool.cibuildwheel.linux] + touch .first-python +""" repair-wheel-command = """ python -m pip install -r requirements.txt SITE_PACKAGES="$(python -c 'print(__import__("sysconfig").get_path("purelib"))')" @@ -148,7 +137,32 @@ repair-wheel-command = """ python -m auditwheel lddtree "{wheel}" python -m auditwheel repair --no-copy-site-libs --wheel-dir="{dest_dir}" "{wheel}" ) - """ +""" +test-command = """ + SITE_PACKAGES="$(python -c 'print(__import__("sysconfig").get_path("purelib"))')" + TORCH_LIB_PATH="${SITE_PACKAGES}/torch/lib" + echo "LD_LIBRARY_PATH='${LD_LIBRARY_PATH}'" + echo "ls ${TORCH_LIB_PATH}"; ls -lh "${TORCH_LIB_PATH}" + find "${SITE_PACKAGES}/torchopt" -name "*.so" -print0 | + xargs -0 -I '{}' bash -c "echo 'ldd {}'; ldd '{}'; echo 'patchelf --print-rpath {}'; patchelf --print-rpath '{}'" + make -C "{project}" test || exit 1 + TORCH_VERSION="$(python -c 'print(__import__("torch").__version__.partition("+")[0])')" + if [[ -f .first-python ]]; then + TEST_TORCH_SPECS="${TEST_TORCH_SPECS:-"${DEFAULT_TEST_TORCH_SPECS}"}" + for spec in ${TEST_TORCH_SPECS}; do + python -m pip uninstall -y torch + export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/${spec}" + echo "PIP_EXTRA_INDEX_URL='${PIP_EXTRA_INDEX_URL}'" + python -m pip install "torch==${TORCH_VERSION}" + echo "ls ${TORCH_LIB_PATH}"; ls -lh "${TORCH_LIB_PATH}" + find "${SITE_PACKAGES}/torchopt" -name "*.so" -print0 | + xargs -0 -I '{}' bash -c "echo 'ldd {}'; ldd '{}'; echo 'patchelf --print-rpath {}'; patchelf --print-rpath '{}'" + make -C "{project}" test || exit 1 + done + rm -f .first-python + fi + rm -rf ~/.pip/cache ~/.cache/pip +""" # Linter tools ################################################################# @@ -156,27 +170,32 @@ repair-wheel-command = """ safe = true line-length = 100 skip-string-normalization = true -target-version = ["py37", "py38", "py39", "py310"] +# Sync with requires-python +target-version = ["py37", "py38", "py39", "py310", "py311"] [tool.isort] +atomic = true profile = "black" src_paths = ["torchopt", "examples", "tests"] +extra_standard_library = ["typing_extensions"] indent = 4 line_length = 100 lines_after_imports = 2 multi_line_output = 3 [tool.mypy] +# Sync with requires-python +python_version = 3.7 +pretty = true +show_error_codes = true +show_error_context = true +show_traceback = true allow_redefinition = true check_untyped_defs = true disallow_incomplete_defs = false disallow_untyped_defs = false ignore_missing_imports = true no_implicit_optional = true -pretty = true -show_error_codes = true -show_error_context = true -show_traceback = true strict_equality = true strict_optional = true warn_no_return = true diff --git a/requirements.txt b/requirements.txt index a2ced2f2..961ddf73 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ -torch >= 1.12 -optree +# Sync with project.dependencies +torch >= 1.13 +optree >= 0.4.1 numpy graphviz -typing-extensions +typing-extensions >= 4.0.0 diff --git a/setup.py b/setup.py index e0df95db..75f32750 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,10 @@ import os import pathlib +import platform +import re import shutil import sys +import sysconfig from setuptools import setup @@ -14,15 +17,17 @@ from setuptools.command.build_ext import build_ext HERE = pathlib.Path(__file__).absolute().parent +VERSION_FILE = HERE / 'torchopt' / 'version.py' -sys.path.insert(0, str(HERE / 'torchopt')) +sys.path.insert(0, str(VERSION_FILE.parent)) import version # noqa class CMakeExtension(Extension): - def __init__(self, name, source_dir='.', **kwargs): + def __init__(self, name, source_dir='.', target=None, **kwargs): super().__init__(name, sources=[], **kwargs) self.source_dir = os.path.abspath(source_dir) + self.target = target if target is not None else name.rpartition('.')[-1] class cmake_build_ext(build_ext): @@ -31,38 +36,42 @@ def build_extension(self, ext): super().build_extension(ext) return - import pybind11 from torch.utils import cpp_extension cmake = shutil.which('cmake') if cmake is None: raise RuntimeError('Cannot find CMake executable.') - build_temp = pathlib.Path(self.build_temp) + ext_path = pathlib.Path(self.get_ext_fullpath(ext.name)).absolute() + build_temp = pathlib.Path(self.build_temp).absolute() build_temp.mkdir(parents=True, exist_ok=True) config = 'Debug' if self.debug else 'Release' - extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) - print(self.get_ext_fullpath(ext.name)) - - PYTHON_INCLUDE_DIR = ';'.join(self.include_dirs) - TORCH_INCLUDE_PATH = ';'.join(cpp_extension.include_paths()) - TORCH_LIBRARY_PATH = ';'.join(cpp_extension.library_paths()) - cmake_args = [ f'-DCMAKE_BUILD_TYPE={config}', - f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{config.upper()}={extdir}', - f'-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY_{config.upper()}={self.build_temp}', + f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{config.upper()}={ext_path.parent}', + f'-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY_{config.upper()}={build_temp}', f'-DPYTHON_EXECUTABLE={sys.executable}', - f'-DPYBIND11_CMAKE_DIR={pybind11.get_cmake_dir()}', - f'-DPYTHON_INCLUDE_DIR={PYTHON_INCLUDE_DIR}', - f'-DTORCH_INCLUDE_PATH={TORCH_INCLUDE_PATH}', - f'-DTORCH_LIBRARY_PATH={TORCH_LIBRARY_PATH}', + f'-DPYTHON_INCLUDE_DIR={sysconfig.get_path("platinclude")}', + f'-DTORCH_INCLUDE_PATH={";".join(cpp_extension.include_paths())}', + f'-DTORCH_LIBRARY_PATH={";".join(cpp_extension.library_paths())}', ] - build_args = ['--config', config] + if platform.system() == 'Darwin': + # Cross-compile support for macOS - respect ARCHFLAGS if set + archs = re.findall(r'-arch (\S+)', os.environ.get('ARCHFLAGS', '')) + if archs: + cmake_args.append(f'-DCMAKE_OSX_ARCHITECTURES={";".join(archs)}') + + try: + import pybind11 + + cmake_args.append(f'-DPYBIND11_CMAKE_DIR={pybind11.get_cmake_dir()}') + except ImportError: + pass + build_args = ['--config', config] if ( 'CMAKE_BUILD_PARALLEL_LEVEL' not in os.environ and hasattr(self, 'parallel') @@ -72,6 +81,8 @@ def build_extension(self, ext): else: build_args.append('--parallel') + build_args.extend([f'--target={ext.target}', '--']) + try: os.chdir(build_temp) self.spawn(['cmake', ext.source_dir] + cmake_args) @@ -81,10 +92,53 @@ def build_extension(self, ext): os.chdir(HERE) -setup( - version=version.__version__, - package_data={'sharedlib': ['*.so', '*.pyd']}, - include_package_data=True, +CIBUILDWHEEL = os.getenv('CIBUILDWHEEL', '0') == '1' +LINUX = platform.system() == 'Linux' +MACOS = platform.system() == 'Darwin' +WINDOWS = platform.system() == 'Windows' +ext_kwargs = dict( cmdclass={'build_ext': cmake_build_ext}, - ext_modules=[CMakeExtension('torchopt._C', source_dir=HERE)], + ext_modules=[ + CMakeExtension( + 'torchopt._C', + source_dir=HERE, + optional=not (LINUX and CIBUILDWHEEL), + ) + ], ) + +TORCHOPT_NO_EXTENSIONS = ( + bool(os.getenv('TORCHOPT_NO_EXTENSIONS', '')) or WINDOWS or (MACOS and CIBUILDWHEEL) +) +if TORCHOPT_NO_EXTENSIONS: + ext_kwargs.clear() + + +VERSION_CONTENT = None + +try: + if not version.__release__: + try: + VERSION_CONTENT = VERSION_FILE.read_text(encoding='UTF-8') + VERSION_FILE.write_text( + data=re.sub( + r"""__version__\s*=\s*('[^']+'|"[^"]+")""", + f"__version__ = '{version.__version__}'", + string=VERSION_CONTENT, + ), + encoding='UTF-8', + ) + except OSError: + VERSION_CONTENT = None + + setup( + name='torchopt', + version=version.__version__, + package_data={'sharedlib': ['*.so', '*.pyd']}, + include_package_data=True, + **ext_kwargs, + ) +finally: + if VERSION_CONTENT is not None: + with VERSION_FILE.open(mode='wt', encoding='UTF-8', newline='') as file: + file.write(VERSION_CONTENT) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 6e3bebc9..2f4ae731 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -23,10 +23,10 @@ endif() list(APPEND torchopt_csrc "${adam_op_src}") -pybind11_add_module(_C THIN_LTO "${torchopt_csrc}") +pybind11_add_module(_C MODULE THIN_LTO "${torchopt_csrc}") target_link_libraries( _C PRIVATE - ${TORCH_LIBRARIES} + "${TORCH_LIBRARIES}" OpenMP::OpenMP_CXX ) diff --git a/src/adam_op/adam_op.cpp b/src/adam_op/adam_op.cpp index 01412126..18bb5d27 100644 --- a/src/adam_op/adam_op.cpp +++ b/src/adam_op/adam_op.cpp @@ -162,19 +162,19 @@ void buildSubmodule(py::module &mod) { // NOLINT py::arg("eps"), py::arg("eps_root"), py::arg("count")); - m.def("forwardMu", + m.def("forward_mu", &adamForwardMu, "Adam forward mu", py::arg("updates"), py::arg("mu"), py::arg("b1")); - m.def("forwardNu", + m.def("forward_nu", &adamForwardNu, "Adam forward nu", py::arg("updates"), py::arg("nu"), py::arg("b2")); - m.def("forwardUpdates", + m.def("forward_updates", &adamForwardUpdates, "Adam forward updates", py::arg("new_mu"), @@ -184,21 +184,21 @@ void buildSubmodule(py::module &mod) { // NOLINT py::arg("eps"), py::arg("eps_root"), py::arg("count")); - m.def("backwardMu", + m.def("backward_mu", &adamBackwardMu, "Adam backward mu", py::arg("dmu"), py::arg("updates"), py::arg("mu"), py::arg("b1")); - m.def("backwardNu", + m.def("backward_nu", &adamBackwardNu, "Adam backward nu", py::arg("dnu"), py::arg("updates"), py::arg("nu"), py::arg("b1")); - m.def("backwardUpdates", + m.def("backward_updates", &adamBackwardUpdates, "Adam backward updates", py::arg("dupdates"), diff --git a/src/adam_op/adam_op_impl_cpu.cpp b/src/adam_op/adam_op_impl_cpu.cpp index 82accd8c..cf734c4f 100644 --- a/src/adam_op/adam_op_impl_cpu.cpp +++ b/src/adam_op/adam_op_impl_cpu.cpp @@ -27,6 +27,8 @@ using std::size_t; namespace adam_op { +constexpr size_t MIN_NUMEL_USE_OMP = 1000; + template void adamForwardInplaceCPUKernel(const other_t b1, const other_t inv_one_minus_pow_b1, @@ -38,7 +40,9 @@ void adamForwardInplaceCPUKernel(const other_t b1, scalar_t *__restrict__ updates_ptr, scalar_t *__restrict__ mu_ptr, scalar_t *__restrict__ nu_ptr) { -#pragma omp parallel for num_threads(omp_get_num_procs()) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t mu = mu_ptr[tid]; @@ -90,7 +94,9 @@ void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr, const other_t b1, const size_t n, scalar_t *__restrict__ mu_out_ptr) { -#pragma omp parallel for num_threads(omp_get_num_procs()) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t mu = mu_ptr[tid]; @@ -122,12 +128,14 @@ void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr, const other_t b2, const size_t n, scalar_t *__restrict__ nu_out_ptr) { -#pragma omp parallel for num_threads(omp_get_num_procs()) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t nu = nu_ptr[tid]; - const scalar_t nu_out = b2 * nu + (1 - b2) * pow(updates, 2); + const scalar_t nu_out = b2 * nu + (1 - b2) * updates * updates; nu_out_ptr[tid] = nu_out; } } @@ -158,7 +166,9 @@ void adamForwardUpdatesCPUKernel(const scalar_t *__restrict__ new_mu_ptr, const other_t eps_root, const size_t n, scalar_t *__restrict__ updates_out_ptr) { -#pragma omp parallel for num_threads(omp_get_num_procs()) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT for (size_t tid = 0; tid < n; ++tid) { const scalar_t new_mu = new_mu_ptr[tid]; const scalar_t new_nu = new_nu_ptr[tid]; @@ -176,14 +186,11 @@ torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu, const pyfloat_t eps_root, const pyuint_t count) { using other_t = pyfloat_t; + const other_t inv_one_minus_pow_b1 = 1 / (1 - std::pow(b1, count)); + const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count)); auto updates_out = torch::empty_like(new_mu); - const other_t one_minus_pow_b1 = 1 - std::pow(b1, count); - const other_t inv_one_minus_pow_b1 = 1 / one_minus_pow_b1; - const other_t one_minus_pow_b2 = 1 - std::pow(b2, count); - const other_t inv_one_minus_pow_b2 = 1 / one_minus_pow_b2; - const size_t n = getTensorPlainSize(new_mu); AT_DISPATCH_SCALAR_TYPES(new_mu.scalar_type(), "adamForwardUpdatesCPU", ([&] { adamForwardUpdatesCPUKernel( @@ -205,7 +212,9 @@ void adamBackwardMuCPUKernel(const scalar_t *__restrict__ dmu_ptr, const size_t n, scalar_t *__restrict__ dupdates_out_ptr, scalar_t *__restrict__ dmu_out_ptr) { -#pragma omp parallel for num_threads(omp_get_num_procs()) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT for (size_t tid = 0; tid < n; ++tid) { const scalar_t dmu = dmu_ptr[tid]; @@ -240,7 +249,9 @@ void adamBackwardNuCPUKernel(const scalar_t *__restrict__ dnu_ptr, const size_t n, scalar_t *__restrict__ dupdates_out_ptr, scalar_t *__restrict__ dnu_out_ptr) { -#pragma omp parallel for num_threads(omp_get_num_procs()) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT for (size_t tid = 0; tid < n; ++tid) { const scalar_t dnu = dnu_ptr[tid]; const scalar_t updates = updates_ptr[tid]; @@ -279,7 +290,9 @@ void adamBackwardUpdatesCPUKernel(const scalar_t *__restrict__ dupdates_ptr, const size_t n, scalar_t *__restrict__ dnew_mu_out_ptr, scalar_t *__restrict__ dnew_nu_out_ptr) { -#pragma omp parallel for num_threads(omp_get_num_procs()) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT for (size_t tid = 0; tid < n; ++tid) { const scalar_t dupdates = dupdates_ptr[tid]; const scalar_t updates = updates_ptr[tid]; @@ -309,14 +322,12 @@ TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor &dupdates, const pyfloat_t b2, const pyuint_t count) { using other_t = pyfloat_t; + const other_t one_minus_pow_b1 = 1 - std::pow(b1, count); + const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count)); auto dmu_out = torch::empty_like(new_mu); auto dnu_out = torch::empty_like(new_nu); - const other_t one_minus_pow_b1 = 1 - std::pow(b1, count); - const other_t one_minus_pow_b2 = 1 - std::pow(b2, count); - const other_t inv_one_minus_pow_b2 = 1 / one_minus_pow_b2; - const size_t n = getTensorPlainSize(dupdates); AT_DISPATCH_SCALAR_TYPES(dupdates.scalar_type(), "adamBackwardUpdatesCPU", ([&] { adamBackwardUpdatesCPUKernel( diff --git a/src/adam_op/adam_op_impl_cuda.cu b/src/adam_op/adam_op_impl_cuda.cu index c77d1790..4b65869f 100644 --- a/src/adam_op/adam_op_impl_cuda.cu +++ b/src/adam_op/adam_op_impl_cuda.cu @@ -24,7 +24,10 @@ namespace torchopt { namespace adam_op { -template +constexpr int UNROLL_SIZE = 4; +constexpr int BLOCK_SIZE = 256; + +template __global__ void adamForwardInplaceCUDAKernel(const other_t b1, const other_t inv_one_minus_pow_b1, const other_t b2, @@ -35,22 +38,26 @@ __global__ void adamForwardInplaceCUDAKernel(const other_t b1, scalar_t *__restrict__ updates_ptr, scalar_t *__restrict__ mu_ptr, scalar_t *__restrict__ nu_ptr) { - unsigned tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= n) { - return; + const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size; +#pragma unroll + for (int i = 0; i < unroll_size; ++i) { + size_t tid = toffset + i; + if (tid >= n) { + return; + } + const scalar_t updates = updates_ptr[tid]; + const scalar_t mu = mu_ptr[tid]; + const scalar_t nu = nu_ptr[tid]; + + const scalar_t mu_out = b1 * mu + (1 - b1) * updates; + const scalar_t nu_out = b2 * nu + (1 - b2) * updates * updates; + const scalar_t updates_out = + mu_out * inv_one_minus_pow_b1 / (sqrt(nu_out * inv_one_minus_pow_b2 + eps_root) + eps); + + mu_ptr[tid] = mu_out; + nu_ptr[tid] = nu_out; + updates_ptr[tid] = updates_out; } - const scalar_t updates = updates_ptr[tid]; - const scalar_t mu = mu_ptr[tid]; - const scalar_t nu = nu_ptr[tid]; - - const scalar_t mu_out = b1 * mu + (1 - b1) * updates; - const scalar_t nu_out = b2 * nu + (1 - b2) * updates * updates; - const scalar_t updates_out = - mu_out * inv_one_minus_pow_b1 / (sqrt(nu_out * inv_one_minus_pow_b2 + eps_root) + eps); - - mu_ptr[tid] = mu_out; - nu_ptr[tid] = nu_out; - updates_ptr[tid] = updates_out; } TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates, @@ -66,39 +73,61 @@ TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates, const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count)); const size_t n = getTensorPlainSize(updates); - const dim3 block(std::min(n, size_t(256))); - const dim3 grid((n - 1) / block.x + 1); - AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardInplaceCUDA", ([&] { - adamForwardInplaceCUDAKernel - <<>>(scalar_t(b1), - scalar_t(inv_one_minus_pow_b1), - scalar_t(b2), - scalar_t(inv_one_minus_pow_b2), - scalar_t(eps), - scalar_t(eps_root), - n, - updates.data_ptr(), - mu.data_ptr(), - nu.data_ptr()); - })); + if (n < BLOCK_SIZE * UNROLL_SIZE) { + const dim3 block(std::min(n, size_t(BLOCK_SIZE))); + const dim3 grid((n - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardInplaceCUDA", ([&] { + adamForwardInplaceCUDAKernel + <<>>(scalar_t(b1), + scalar_t(inv_one_minus_pow_b1), + scalar_t(b2), + scalar_t(inv_one_minus_pow_b2), + scalar_t(eps), + scalar_t(eps_root), + n, + updates.data_ptr(), + mu.data_ptr(), + nu.data_ptr()); + })); + } else { + const dim3 block(std::min(n / UNROLL_SIZE, size_t(BLOCK_SIZE))); + const dim3 grid((n / UNROLL_SIZE - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardInplaceCUDA", ([&] { + adamForwardInplaceCUDAKernel + <<>>(scalar_t(b1), + scalar_t(inv_one_minus_pow_b1), + scalar_t(b2), + scalar_t(inv_one_minus_pow_b2), + scalar_t(eps), + scalar_t(eps_root), + n, + updates.data_ptr(), + mu.data_ptr(), + nu.data_ptr()); + })); + } return TensorArray<3>{updates, mu, nu}; } -template +template __global__ void adamForwardMuCUDAKernel(const scalar_t *__restrict__ updates_ptr, const scalar_t *__restrict__ mu_ptr, const other_t b1, const size_t n, scalar_t *__restrict__ mu_out_ptr) { - size_t tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= n) { - return; + const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size; +#pragma unroll + for (int i = 0; i < unroll_size; ++i) { + size_t tid = toffset + i; + if (tid >= n) { + return; + } + + const scalar_t updates = updates_ptr[tid]; + const scalar_t mu = mu_ptr[tid]; + const scalar_t mu_out = b1 * mu + (1 - b1) * updates; + mu_out_ptr[tid] = mu_out; } - - const scalar_t updates = updates_ptr[tid]; - const scalar_t mu = mu_ptr[tid]; - const scalar_t mu_out = b1 * mu + (1 - b1) * updates; - mu_out_ptr[tid] = mu_out; } torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates, @@ -107,35 +136,52 @@ torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates, auto mu_out = torch::empty_like(mu); const size_t n = getTensorPlainSize(updates); - const dim3 block(std::min(n, size_t(256))); - const dim3 grid((n - 1) / block.x + 1); - AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardMuCUDA", ([&] { - adamForwardMuCUDAKernel - <<>>(updates.data_ptr(), - mu.data_ptr(), - scalar_t(b1), - n, - mu_out.data_ptr()); - })); + if (n < BLOCK_SIZE * UNROLL_SIZE) { + const dim3 block(std::min(n, size_t(BLOCK_SIZE))); + const dim3 grid((n - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardMuCUDA", ([&] { + adamForwardMuCUDAKernel + <<>>(updates.data_ptr(), + mu.data_ptr(), + scalar_t(b1), + n, + mu_out.data_ptr()); + })); + } else { + const dim3 block(std::min(n / UNROLL_SIZE, size_t(BLOCK_SIZE))); + const dim3 grid((n / UNROLL_SIZE - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardMuCUDA", ([&] { + adamForwardMuCUDAKernel + <<>>(updates.data_ptr(), + mu.data_ptr(), + scalar_t(b1), + n, + mu_out.data_ptr()); + })); + } return mu_out; } -template +template __global__ void adamForwardNuCUDAKernel(const scalar_t *__restrict__ updates_ptr, const scalar_t *__restrict__ nu_ptr, const other_t b2, const size_t n, scalar_t *__restrict__ nu_out_ptr) { - size_t tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= n) { - return; + const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size; +#pragma unroll + for (int i = 0; i < unroll_size; ++i) { + size_t tid = toffset + i; + if (tid >= n) { + return; + } + + const scalar_t updates = updates_ptr[tid]; + const scalar_t nu = nu_ptr[tid]; + + const scalar_t nu_out = b2 * nu + (1 - b2) * updates * updates; + nu_out_ptr[tid] = nu_out; } - - const scalar_t updates = updates_ptr[tid]; - const scalar_t nu = nu_ptr[tid]; - - const scalar_t nu_out = b2 * nu + (1 - b2) * pow(updates, 2); - nu_out_ptr[tid] = nu_out; } torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates, @@ -144,20 +190,33 @@ torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates, auto nu_out = torch::empty_like(nu); const size_t n = getTensorPlainSize(updates); - const dim3 block(std::min(n, size_t(256))); - const dim3 grid((n - 1) / block.x + 1); - AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardNuCUDA", ([&] { - adamForwardNuCUDAKernel - <<>>(updates.data_ptr(), - nu.data_ptr(), - scalar_t(b2), - n, - nu_out.data_ptr()); - })); + if (n < BLOCK_SIZE * UNROLL_SIZE) { + const dim3 block(std::min(n, size_t(BLOCK_SIZE))); + const dim3 grid((n - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardNuCUDA", ([&] { + adamForwardNuCUDAKernel + <<>>(updates.data_ptr(), + nu.data_ptr(), + scalar_t(b2), + n, + nu_out.data_ptr()); + })); + } else { + const dim3 block(std::min(n / UNROLL_SIZE, size_t(BLOCK_SIZE))); + const dim3 grid((n / UNROLL_SIZE - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardNuCUDA", ([&] { + adamForwardNuCUDAKernel + <<>>(updates.data_ptr(), + nu.data_ptr(), + scalar_t(b2), + n, + nu_out.data_ptr()); + })); + } return nu_out; } -template +template __global__ void adamForwardUpdatesCUDAKernel(const scalar_t *__restrict__ new_mu_ptr, const scalar_t *__restrict__ new_nu_ptr, const other_t inv_one_minus_pow_b1, @@ -166,16 +225,20 @@ __global__ void adamForwardUpdatesCUDAKernel(const scalar_t *__restrict__ new_mu const other_t eps_root, const size_t n, scalar_t *__restrict__ updates_out_ptr) { - size_t tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= n) { - return; + const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size; +#pragma unroll + for (int i = 0; i < unroll_size; ++i) { + size_t tid = toffset + i; + if (tid >= n) { + return; + } + + const scalar_t new_mu = new_mu_ptr[tid]; + const scalar_t new_nu = new_nu_ptr[tid]; + const scalar_t mu_hat = new_mu * inv_one_minus_pow_b1; + const scalar_t nu_hat = new_nu * inv_one_minus_pow_b2; + updates_out_ptr[tid] = mu_hat / (sqrt(nu_hat + eps_root) + eps); } - - const scalar_t new_mu = new_mu_ptr[tid]; - const scalar_t new_nu = new_nu_ptr[tid]; - const scalar_t mu_hat = new_mu * inv_one_minus_pow_b1; - const scalar_t nu_hat = new_nu * inv_one_minus_pow_b2; - updates_out_ptr[tid] = mu_hat / (sqrt(nu_hat + eps_root) + eps); } torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu, @@ -186,46 +249,64 @@ torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu, const pyfloat_t eps_root, const pyuint_t count) { using other_t = pyfloat_t; + const other_t inv_one_minus_pow_b1 = 1 / (1 - std::pow(b1, count)); + const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count)); auto updates_out = torch::empty_like(new_mu); - const other_t one_minus_pow_b1 = 1 - std::pow(b1, count); - const other_t inv_one_minus_pow_b1 = 1 / one_minus_pow_b1; - const other_t one_minus_pow_b2 = 1 - std::pow(b2, count); - const other_t inv_one_minus_pow_b2 = 1 / one_minus_pow_b2; - const size_t n = getTensorPlainSize(new_mu); - const dim3 block(std::min(n, size_t(256))); - const dim3 grid((n - 1) / block.x + 1); - AT_DISPATCH_SCALAR_TYPES_CUDA(new_mu.scalar_type(), "adamForwardUpdatesCUDA", ([&] { - adamForwardUpdatesCUDAKernel - <<>>(new_mu.data_ptr(), - new_nu.data_ptr(), - scalar_t(inv_one_minus_pow_b1), - scalar_t(inv_one_minus_pow_b2), - scalar_t(eps), - scalar_t(eps_root), - n, - updates_out.data_ptr()); - })); + if (n < BLOCK_SIZE * UNROLL_SIZE) { + const dim3 block(std::min(n, size_t(BLOCK_SIZE))); + const dim3 grid((n - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(new_mu.scalar_type(), "adamForwardUpdatesCUDA", ([&] { + adamForwardUpdatesCUDAKernel + <<>>(new_mu.data_ptr(), + new_nu.data_ptr(), + scalar_t(inv_one_minus_pow_b1), + scalar_t(inv_one_minus_pow_b2), + scalar_t(eps), + scalar_t(eps_root), + n, + updates_out.data_ptr()); + })); + } else { + const dim3 block(std::min(n / UNROLL_SIZE, size_t(BLOCK_SIZE))); + const dim3 grid((n / UNROLL_SIZE - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(new_mu.scalar_type(), "adamForwardUpdatesCUDA", ([&] { + adamForwardUpdatesCUDAKernel + <<>>(new_mu.data_ptr(), + new_nu.data_ptr(), + scalar_t(inv_one_minus_pow_b1), + scalar_t(inv_one_minus_pow_b2), + scalar_t(eps), + scalar_t(eps_root), + n, + updates_out.data_ptr()); + })); + } + return updates_out; } -template +template __global__ void adamBackwardMuCUDAKernel(const scalar_t *__restrict__ dmu_ptr, const other_t b1, const size_t n, scalar_t *__restrict__ dupdates_out_ptr, scalar_t *__restrict__ dmu_out_ptr) { - size_t tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= n) { - return; + const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size; +#pragma unroll + for (int i = 0; i < unroll_size; ++i) { + size_t tid = toffset + i; + if (tid >= n) { + return; + } + + const scalar_t dmu = dmu_ptr[tid]; + + dupdates_out_ptr[tid] = (1 - b1) * dmu; + dmu_out_ptr[tid] = b1 * dmu; } - - const scalar_t dmu = dmu_ptr[tid]; - - dupdates_out_ptr[tid] = (1 - b1) * dmu; - dmu_out_ptr[tid] = b1 * dmu; } TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu, @@ -236,36 +317,53 @@ TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu, auto dmu_out = torch::empty_like(mu); const size_t n = getTensorPlainSize(dmu); - const dim3 block(std::min(n, size_t(256))); - const dim3 grid((n - 1) / block.x + 1); - AT_DISPATCH_SCALAR_TYPES_CUDA(dmu.scalar_type(), "adamBackwardMuCUDA", ([&] { - adamBackwardMuCUDAKernel - <<>>(dmu.data_ptr(), - scalar_t(b1), - n, - dupdates_out.data_ptr(), - dmu_out.data_ptr()); - })); + if (n < BLOCK_SIZE * UNROLL_SIZE) { + const dim3 block(std::min(n, size_t(BLOCK_SIZE))); + const dim3 grid((n - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(dmu.scalar_type(), "adamBackwardMuCUDA", ([&] { + adamBackwardMuCUDAKernel + <<>>(dmu.data_ptr(), + scalar_t(b1), + n, + dupdates_out.data_ptr(), + dmu_out.data_ptr()); + })); + } else { + const dim3 block(std::min(n / UNROLL_SIZE, size_t(BLOCK_SIZE))); + const dim3 grid((n / UNROLL_SIZE - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(dmu.scalar_type(), "adamBackwardMuCUDA", ([&] { + adamBackwardMuCUDAKernel + <<>>(dmu.data_ptr(), + scalar_t(b1), + n, + dupdates_out.data_ptr(), + dmu_out.data_ptr()); + })); + } return TensorArray<2>{std::move(dupdates_out), std::move(dmu_out)}; } -template +template __global__ void adamBackwardNuCUDAKernel(const scalar_t *__restrict__ dnu_ptr, const scalar_t *__restrict__ updates_ptr, const other_t b2, const size_t n, scalar_t *__restrict__ dupdates_out_ptr, scalar_t *__restrict__ dnu_out_ptr) { - size_t tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= n) { - return; + const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size; +#pragma unroll + for (int i = 0; i < unroll_size; ++i) { + size_t tid = toffset + i; + if (tid >= n) { + return; + } + + const scalar_t dnu = dnu_ptr[tid]; + const scalar_t updates = updates_ptr[tid]; + + dupdates_out_ptr[tid] = 2 * (1 - b2) * updates * dnu; + dnu_out_ptr[tid] = b2 * dnu; } - - const scalar_t dnu = dnu_ptr[tid]; - const scalar_t updates = updates_ptr[tid]; - - dupdates_out_ptr[tid] = 2 * (1 - b2) * updates * dnu; - dnu_out_ptr[tid] = b2 * dnu; } TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu, @@ -276,21 +374,35 @@ TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu, auto dnu_out = torch::empty_like(nu); const size_t n = getTensorPlainSize(dnu); - const dim3 block(std::min(n, size_t(256))); - const dim3 grid((n - 1) / block.x + 1); - AT_DISPATCH_SCALAR_TYPES_CUDA(dnu.scalar_type(), "adamForwardNuCUDA", ([&] { - adamBackwardNuCUDAKernel - <<>>(dnu.data_ptr(), - updates.data_ptr(), - scalar_t(b2), - n, - dupdates_out.data_ptr(), - dnu_out.data_ptr()); - })); + if (n < BLOCK_SIZE * UNROLL_SIZE) { + const dim3 block(std::min(n, size_t(BLOCK_SIZE))); + const dim3 grid((n - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(dnu.scalar_type(), "adamForwardNuCUDA", ([&] { + adamBackwardNuCUDAKernel + <<>>(dnu.data_ptr(), + updates.data_ptr(), + scalar_t(b2), + n, + dupdates_out.data_ptr(), + dnu_out.data_ptr()); + })); + } else { + const dim3 block(std::min(n / UNROLL_SIZE, size_t(BLOCK_SIZE))); + const dim3 grid((n / UNROLL_SIZE - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(dnu.scalar_type(), "adamForwardNuCUDA", ([&] { + adamBackwardNuCUDAKernel + <<>>(dnu.data_ptr(), + updates.data_ptr(), + scalar_t(b2), + n, + dupdates_out.data_ptr(), + dnu_out.data_ptr()); + })); + } return TensorArray<2>{std::move(dupdates_out), std::move(dnu_out)}; } -template +template __global__ void adamBackwardUpdatesCUDAKernel(const scalar_t *__restrict__ dupdates_ptr, const scalar_t *__restrict__ updates_ptr, const scalar_t *__restrict__ new_mu_ptr, @@ -299,28 +411,32 @@ __global__ void adamBackwardUpdatesCUDAKernel(const scalar_t *__restrict__ dupda const size_t n, scalar_t *__restrict__ dnew_mu_out_ptr, scalar_t *__restrict__ dnew_nu_out_ptr) { - size_t tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= n) { - return; + const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size; +#pragma unroll + for (int i = 0; i < unroll_size; ++i) { + size_t tid = toffset + i; + if (tid >= n) { + return; + } + + const scalar_t dupdates = dupdates_ptr[tid]; + const scalar_t updates = updates_ptr[tid]; + const scalar_t new_mu = new_mu_ptr[tid]; + + if (new_mu == 0) { + dnew_mu_out_ptr[tid] = 0; + dnew_nu_out_ptr[tid] = 0; + return; + } + + const scalar_t updates_div_new_mu = updates / new_mu; + + const scalar_t denominator = updates_div_new_mu * one_minus_pow_b1; + + dnew_mu_out_ptr[tid] = dupdates * updates_div_new_mu; + dnew_nu_out_ptr[tid] = + -dupdates * updates * denominator * 0.5 * inv_one_minus_pow_b2 * denominator; } - - const scalar_t dupdates = dupdates_ptr[tid]; - const scalar_t updates = updates_ptr[tid]; - const scalar_t new_mu = new_mu_ptr[tid]; - - if (new_mu == 0) { - dnew_mu_out_ptr[tid] = 0; - dnew_nu_out_ptr[tid] = 0; - return; - } - - const scalar_t updates_div_new_mu = updates / new_mu; - - const scalar_t denominator = updates_div_new_mu * one_minus_pow_b1; - - dnew_mu_out_ptr[tid] = dupdates * updates_div_new_mu; - dnew_nu_out_ptr[tid] = - -dupdates * updates * denominator * 0.5 * inv_one_minus_pow_b2 * denominator; } TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates, @@ -331,28 +447,42 @@ TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates, const pyfloat_t b2, const pyuint_t count) { using other_t = pyfloat_t; + const other_t one_minus_pow_b1 = 1 - std::pow(b1, count); + const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count)); auto dmu_out = torch::empty_like(new_mu); auto dnu_out = torch::empty_like(new_nu); - const other_t one_minus_pow_b1 = 1 - std::pow(b1, count); - const other_t one_minus_pow_b2 = 1 - std::pow(b2, count); - const other_t inv_one_minus_pow_b2 = 1 / one_minus_pow_b2; - const size_t n = getTensorPlainSize(dupdates); - const dim3 block(std::min(n, size_t(256))); - const dim3 grid((n - 1) / block.x + 1); - AT_DISPATCH_SCALAR_TYPES_CUDA(dupdates.scalar_type(), "adamBackwardUpdatesCUDA", ([&] { - adamBackwardUpdatesCUDAKernel - <<>>(dupdates.data_ptr(), - updates.data_ptr(), - new_mu.data_ptr(), - scalar_t(one_minus_pow_b1), - scalar_t(inv_one_minus_pow_b2), - n, - dmu_out.data_ptr(), - dnu_out.data_ptr()); - })); + if (n < BLOCK_SIZE * UNROLL_SIZE) { + const dim3 block(std::min(n, size_t(BLOCK_SIZE))); + const dim3 grid((n - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(dupdates.scalar_type(), "adamBackwardUpdatesCUDA", ([&] { + adamBackwardUpdatesCUDAKernel + <<>>(dupdates.data_ptr(), + updates.data_ptr(), + new_mu.data_ptr(), + scalar_t(one_minus_pow_b1), + scalar_t(inv_one_minus_pow_b2), + n, + dmu_out.data_ptr(), + dnu_out.data_ptr()); + })); + } else { + const dim3 block(std::min(n / UNROLL_SIZE, size_t(BLOCK_SIZE))); + const dim3 grid((n / UNROLL_SIZE - 1) / block.x + 1); + AT_DISPATCH_SCALAR_TYPES_CUDA(dupdates.scalar_type(), "adamBackwardUpdatesCUDA", ([&] { + adamBackwardUpdatesCUDAKernel + <<>>(dupdates.data_ptr(), + updates.data_ptr(), + new_mu.data_ptr(), + scalar_t(one_minus_pow_b1), + scalar_t(inv_one_minus_pow_b2), + n, + dmu_out.data_ptr(), + dnu_out.data_ptr()); + })); + } return TensorArray<2>{std::move(dmu_out), std::move(dnu_out)}; } diff --git a/tests/helpers.py b/tests/helpers.py index d34ad41e..6c7c4f01 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -23,6 +23,7 @@ import pytest import torch import torch.nn as nn +import torch.types from torch.utils import data @@ -34,6 +35,14 @@ MODEL_HIDDEN_SIZE = 64 +def dtype_numpy2torch(dtype: np.dtype) -> torch.dtype: + return torch.tensor(np.zeros(1, dtype=dtype)).dtype + + +def dtype_torch2numpy(dtype: torch.dtype) -> np.dtype: + return torch.zeros(1, dtype=dtype).numpy().dtype + + def parametrize(**argvalues) -> pytest.mark.parametrize: arguments = list(argvalues) @@ -46,6 +55,8 @@ def parametrize(**argvalues) -> pytest.mark.parametrize: argvalues = list(itertools.product(*tuple(map(argvalues.get, arguments)))) first_product = argvalues[0] argvalues.extend((dtype,) + first_product[1:] for dtype in dtypes[1:]) + else: + argvalues = list(itertools.product(*tuple(map(argvalues.get, arguments)))) ids = tuple( '-'.join(f'{arg}({val})' for arg, val in zip(arguments, values)) for values in argvalues @@ -69,45 +80,59 @@ def seed_everything(seed: int) -> None: pass +class MyLinear(nn.Module): + def __init__( + self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None + ) -> None: + super().__init__() + self.linear = nn.Linear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + dtype=dtype, + ) + self.unused_module = nn.Linear(1, 1, bias=False) + self.unused_parameter = nn.Parameter(torch.zeros(1, 1), requires_grad=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + @torch.no_grad() def get_models( - device: Optional[Union[str, torch.device]] = None, dtype: torch.dtype = torch.float32 + device: torch.types.Device = None, dtype: torch.dtype = torch.float32 ) -> Tuple[nn.Module, nn.Module, nn.Module, data.DataLoader]: seed_everything(seed=42) model_base = nn.Sequential( - nn.Linear( + MyLinear( in_features=MODEL_NUM_INPUTS, out_features=MODEL_HIDDEN_SIZE, bias=True, - dtype=dtype, ), nn.BatchNorm1d( num_features=MODEL_HIDDEN_SIZE, track_running_stats=True, - dtype=dtype, ), nn.ReLU(), nn.Linear( in_features=MODEL_HIDDEN_SIZE, out_features=MODEL_HIDDEN_SIZE, bias=True, - dtype=dtype, ), nn.BatchNorm1d( num_features=MODEL_HIDDEN_SIZE, track_running_stats=True, - dtype=dtype, ), nn.ReLU(), nn.Linear( in_features=MODEL_HIDDEN_SIZE, out_features=MODEL_NUM_CLASSES, - bias=True, - dtype=dtype, + bias=False, ), nn.Softmax(dim=-1), - ) + ).to(dtype=dtype) for name, param in model_base.named_parameters(recurse=True): if name.endswith('weight') and param.ndim >= 2: nn.init.orthogonal_(param) @@ -123,6 +148,7 @@ def get_models( dataset = data.TensorDataset( torch.randint(0, 1, (BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS)), + # torch.empty((BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS), dtype=dtype).uniform_(-1.0, +1.0), torch.randint(0, MODEL_NUM_CLASSES, (BATCH_SIZE * NUM_UPDATES,)), ) loader = data.DataLoader(dataset, BATCH_SIZE, shuffle=False) @@ -174,8 +200,8 @@ def assert_all_close( from torch.testing._comparison import get_tolerances rtol, atol = get_tolerances(actual, expected, rtol=rtol, atol=atol) - rtol *= 4 * NUM_UPDATES - atol *= 4 * NUM_UPDATES + rtol *= 5 * NUM_UPDATES + atol *= 5 * NUM_UPDATES torch.testing.assert_close( actual, diff --git a/tests/requirements.txt b/tests/requirements.txt index d02db980..b8c70827 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,18 +1,23 @@ ---extra-index-url https://download.pytorch.org/whl/cu116 -torch >= 1.12 -functorch >= 0.2 +--extra-index-url https://download.pytorch.org/whl/cu117 +torch >= 1.13 --requirement ../requirements.txt +jax[cpu] >= 0.3 +jaxopt +optax + pytest pytest-cov pytest-xdist isort -black >= 22.6.0 -pylint -mypy +black[jupyter] >= 22.6.0 +pylint[spelling] >= 2.15.0 +mypy >= 0.990 +types-setuptools flake8 flake8-bugbear +# https://github.com/PyCQA/doc8/issues/112 doc8 < 1.0.0a0 pydocstyle pyenchant diff --git a/tests/test_alias.py b/tests/test_alias.py index 6f37e939..50b42835 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -32,7 +32,7 @@ nesterov=[False, True], inplace=[True, False], weight_decay=[0.0, 1e-2], - maximize=[False], # TODO: test maximize after PyTorch 1.13 + maximize=[False, True], ) def test_sgd( dtype: torch.dtype, @@ -76,7 +76,7 @@ def test_sgd( loss = F.cross_entropy(pred, ys) loss_ref = F.cross_entropy(pred_ref, ys) - grads = torch.autograd.grad(loss, params) + grads = torch.autograd.grad(loss, params, allow_unused=True) updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) params = torchopt.apply_updates(params, updates, inplace=inplace) @@ -134,7 +134,7 @@ def test_adam( loss = F.cross_entropy(pred, ys) loss_ref = F.cross_entropy(pred_ref, ys) - grads = torch.autograd.grad(loss, params) + grads = torch.autograd.grad(loss, params, allow_unused=True) updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) params = torchopt.apply_updates(params, updates, inplace=inplace) @@ -192,7 +192,7 @@ def test_adamw( loss = F.cross_entropy(pred, ys) loss_ref = F.cross_entropy(pred_ref, ys) - grads = torch.autograd.grad(loss, params) + grads = torch.autograd.grad(loss, params, allow_unused=True) updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) params = torchopt.apply_updates(params, updates, inplace=inplace) @@ -251,7 +251,7 @@ def test_adam_accelerated_cpu( loss = F.cross_entropy(pred, ys) loss_ref = F.cross_entropy(pred_ref, ys) - grads = torch.autograd.grad(loss, params) + grads = torch.autograd.grad(loss, params, allow_unused=True) updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) params = torchopt.apply_updates(params, updates, inplace=inplace) @@ -313,7 +313,7 @@ def test_adam_accelerated_cuda( loss = F.cross_entropy(pred, ys) loss_ref = F.cross_entropy(pred_ref, ys) - grads = torch.autograd.grad(loss, params) + grads = torch.autograd.grad(loss, params, allow_unused=True) updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) params = torchopt.apply_updates(params, updates, inplace=inplace) @@ -374,7 +374,7 @@ def test_rmsprop( loss = F.cross_entropy(pred, ys) loss_ref = F.cross_entropy(pred_ref, ys) - grads = torch.autograd.grad(loss, params) + grads = torch.autograd.grad(loss, params, allow_unused=True) updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) params = torchopt.apply_updates(params, updates, inplace=inplace) diff --git a/tests/test_clip.py b/tests/test_clip.py index 420cfdaa..f8d3b289 100644 --- a/tests/test_clip.py +++ b/tests/test_clip.py @@ -30,7 +30,7 @@ dampening=[0.0, 0.5], nesterov=[False, True], weight_decay=[0.0, 1e-2], - maximize=[False], # TODO: test maximize after PyTorch 1.13 + maximize=[False, True], ) def test_sgd( dtype: torch.dtype, diff --git a/tests/test_implicit.py b/tests/test_implicit.py new file mode 100644 index 00000000..ac61b3be --- /dev/null +++ b/tests/test_implicit.py @@ -0,0 +1,681 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import copy +from collections import OrderedDict +from types import FunctionType +from typing import Tuple + +import functorch +import jax +import jax.numpy as jnp +import jaxopt +import numpy as np +import optax +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.types +from torch.utils import data + +import helpers +import torchopt +from torchopt import pytree +from torchopt.diff.implicit import ImplicitMetaGradientModule + + +BATCH_SIZE = 8 +NUM_UPDATES = 3 + +MODEL_NUM_INPUTS = 10 +MODEL_NUM_CLASSES = 10 + + +class FcNet(nn.Module): + def __init__(self, dim, out): + super().__init__() + self.fc = nn.Linear(in_features=dim, out_features=out, bias=True) + nn.init.ones_(self.fc.weight) + nn.init.zeros_(self.fc.bias) + + def forward(self, x): + return self.fc(x) + + +def get_model_jax(dtype: np.dtype = np.float32) -> Tuple[FunctionType, OrderedDict]: + helpers.seed_everything(seed=42) + + def func(params, x): + return x @ params['weight'] + params['bias'] + + params = OrderedDict( + [ + ('weight', jnp.ones((MODEL_NUM_INPUTS, MODEL_NUM_CLASSES), dtype=dtype)), + ('bias', jnp.zeros((MODEL_NUM_CLASSES,), dtype=dtype)), + ] + ) + return func, params + + +@torch.no_grad() +def get_model_torch( + device: torch.types.Device = None, dtype: torch.dtype = torch.float32 +) -> Tuple[nn.Module, data.DataLoader]: + helpers.seed_everything(seed=42) + + model = FcNet(MODEL_NUM_INPUTS, MODEL_NUM_CLASSES).to(dtype=dtype) + + if device is not None: + model = model.to(device=torch.device(device)) + + dataset = data.TensorDataset( + torch.randint(0, 1, (BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS)), + torch.randint(0, MODEL_NUM_CLASSES, (BATCH_SIZE * NUM_UPDATES,)), + ) + loader = data.DataLoader(dataset, BATCH_SIZE, shuffle=False) + + return model, loader + + +@torch.no_grad() +def get_rr_dataset_torch() -> data.DataLoader: + helpers.seed_everything(seed=42) + + BATCH_SIZE = 1024 + NUM_UPDATES = 4 + dataset = data.TensorDataset( + torch.randn((BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS)), + torch.randn((BATCH_SIZE * NUM_UPDATES,)), + torch.randn((BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS)), + torch.randn((BATCH_SIZE * NUM_UPDATES,)), + ) + loader = data.DataLoader(dataset, BATCH_SIZE, shuffle=False) + + return loader + + +@helpers.parametrize( + dtype=[torch.float64, torch.float32], + lr=[1e-3, 1e-4], + inner_lr=[2e-2, 2e-3], + inner_update=[20, 50, 100], +) +def test_imaml_solve_normal_cg( + dtype: torch.dtype, lr: float, inner_lr: float, inner_update: int +) -> None: + np_dtype = helpers.dtype_torch2numpy(dtype) + + jax_model, jax_params = get_model_jax(dtype=np_dtype) + model, loader = get_model_torch(device='cpu', dtype=dtype) + + fmodel, params = functorch.make_functional(model) + optim = torchopt.sgd(lr) + optim_state = optim.init(params) + + optim_jax = optax.sgd(lr) + optim_state_jax = optim_jax.init(jax_params) + + def imaml_objective_torchopt(params, meta_params, data): + x, y, f = data + y_pred = f(params, x) + regularization_loss = 0 + for p1, p2 in zip(params, meta_params): + regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2)) + loss = F.cross_entropy(y_pred, y) + regularization_loss + return loss + + @torchopt.diff.implicit.custom_root( + functorch.grad(imaml_objective_torchopt, argnums=0), + argnums=1, + has_aux=True, + solve=torchopt.linear_solve.solve_normal_cg(), + ) + def inner_solver_torchopt(params, meta_params, data): + # Initial functional optimizer based on TorchOpt + x, y, f = data + optimizer = torchopt.sgd(lr=inner_lr) + opt_state = optimizer.init(params) + with torch.enable_grad(): + # Temporarily enable gradient computation for conducting the optimization + for _ in range(inner_update): + pred = f(params, x) + loss = F.cross_entropy(pred, y) # compute loss + # Compute regularization loss + regularization_loss = 0 + for p1, p2 in zip(params, meta_params): + regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2)) + final_loss = loss + regularization_loss + grads = torch.autograd.grad(final_loss, params) # compute gradients + updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates + params = torchopt.apply_updates(params, updates, inplace=True) + return params, (0, {'a': 1, 'b': 2}) + + def imaml_objective_jax(params, meta_params, x, y): + y_pred = jax_model(params, x) + loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(y_pred, y)) + regularization_loss = 0 + for p1, p2 in zip(params.values(), meta_params.values()): + regularization_loss += 0.5 * jnp.sum(jnp.square(p1 - p2)) + loss = loss + regularization_loss + return loss + + @jaxopt.implicit_diff.custom_root( + jax.grad(imaml_objective_jax, argnums=0), + has_aux=True, + solve=jaxopt.linear_solve.solve_normal_cg, + ) + def inner_solver_jax(params, meta_params, x, y): + """Solve ridge regression by conjugate gradient.""" + # Initial functional optimizer based on torchopt + optimizer = optax.sgd(inner_lr) + opt_state = optimizer.init(params) + + def compute_loss(params, meta_params, x, y): + pred = jax_model(params, x) + loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(pred, y)) + # Compute regularization loss + regularization_loss = 0 + for p1, p2 in zip(params.values(), meta_params.values()): + regularization_loss += 0.5 * jnp.sum(jnp.square(p1 - p2)) + final_loss = loss + regularization_loss + return final_loss + + for i in range(inner_update): + grads = jax.grad(compute_loss)(params, meta_params, x, y) # compute gradients + updates, opt_state = optimizer.update(grads, opt_state) # get updates + params = optax.apply_updates(params, updates) + return params, (0, {'a': 1, 'b': 2}) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + data = (xs, ys, fmodel) + inner_params = pytree.tree_map( + lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), params + ) + optimal_inner_params, aux = inner_solver_torchopt(inner_params, params, data) + assert aux == (0, {'a': 1, 'b': 2}) + outer_loss = fmodel(optimal_inner_params, xs).mean() + + grads = torch.autograd.grad(outer_loss, params) + updates, optim_state = optim.update(grads, optim_state) + params = torchopt.apply_updates(params, updates) + + xs = xs.numpy() + ys = ys.numpy() + + def outer_level(p, xs, ys): + optimal_params, aux = inner_solver_jax(copy.deepcopy(p), p, xs, ys) + assert aux == (0, {'a': 1, 'b': 2}) + outer_loss = jax_model(optimal_params, xs).mean() + return outer_loss + + grads_jax = jax.grad(outer_level, argnums=0)(jax_params, xs, ys) + updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax) # get updates + jax_params = optax.apply_updates(jax_params, updates_jax) + + jax_params_as_tensor = tuple( + nn.Parameter(torch.tensor(np.asarray(jax_params[j]), dtype=dtype)) for j in jax_params + ) + + for p, p_ref in zip(params, jax_params_as_tensor): + helpers.assert_all_close(p, p_ref) + + +@helpers.parametrize( + dtype=[torch.float64, torch.float32], + lr=[1e-3, 1e-4], + inner_lr=[2e-2, 2e-3], + inner_update=[20, 50, 100], + ns=[False, True], +) +def test_imaml_solve_inv( + dtype: torch.dtype, + lr: float, + inner_lr: float, + inner_update: int, + ns: bool, +) -> None: + np_dtype = helpers.dtype_torch2numpy(dtype) + + jax_model, jax_params = get_model_jax(dtype=np_dtype) + model, loader = get_model_torch(device='cpu', dtype=dtype) + + fmodel, params = functorch.make_functional(model) + optim = torchopt.sgd(lr) + optim_state = optim.init(params) + + optim_jax = optax.sgd(lr) + optim_state_jax = optim_jax.init(jax_params) + + def imaml_objective_torchopt(params, meta_params, data): + x, y, f = data + y_pred = f(params, x) + regularization_loss = 0 + for p1, p2 in zip(params, meta_params): + regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2)) + loss = F.cross_entropy(y_pred, y) + regularization_loss + return loss + + @torchopt.diff.implicit.custom_root( + functorch.grad(imaml_objective_torchopt, argnums=0), + argnums=1, + solve=torchopt.linear_solve.solve_inv(ns=ns), + ) + def inner_solver_torchopt(params, meta_params, data): + # Initial functional optimizer based on TorchOpt + x, y, f = data + optimizer = torchopt.sgd(lr=inner_lr) + opt_state = optimizer.init(params) + with torch.enable_grad(): + # Temporarily enable gradient computation for conducting the optimization + for _ in range(inner_update): + pred = f(params, x) + loss = F.cross_entropy(pred, y) # compute loss + # Compute regularization loss + regularization_loss = 0 + for p1, p2 in zip(params, meta_params): + regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2)) + final_loss = loss + regularization_loss + grads = torch.autograd.grad(final_loss, params) # compute gradients + updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates + params = torchopt.apply_updates(params, updates, inplace=True) + return params + + def imaml_objective_jax(params, meta_params, x, y): + y_pred = jax_model(params, x) + loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(y_pred, y)) + regularization_loss = 0 + for p1, p2 in zip(params.values(), meta_params.values()): + regularization_loss += 0.5 * jnp.sum(jnp.square(p1 - p2)) + loss = loss + regularization_loss + return loss + + @jaxopt.implicit_diff.custom_root( + jax.grad(imaml_objective_jax, argnums=0), + solve=jaxopt.linear_solve.solve_normal_cg, + ) + def inner_solver_jax(params, meta_params, x, y): + """Solve ridge regression by conjugate gradient.""" + # Initial functional optimizer based on torchopt + optimizer = optax.sgd(inner_lr) + opt_state = optimizer.init(params) + + def compute_loss(params, meta_params, x, y): + pred = jax_model(params, x) + loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(pred, y)) + # Compute regularization loss + regularization_loss = 0 + for p1, p2 in zip(params.values(), meta_params.values()): + regularization_loss += 0.5 * jnp.sum(jnp.square(p1 - p2)) + final_loss = loss + regularization_loss + return final_loss + + for i in range(inner_update): + grads = jax.grad(compute_loss)(params, meta_params, x, y) # compute gradients + updates, opt_state = optimizer.update(grads, opt_state) # get updates + params = optax.apply_updates(params, updates) + return params + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + data = (xs, ys, fmodel) + inner_params = pytree.tree_map( + lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), params + ) + optimal_inner_params = inner_solver_torchopt(inner_params, params, data) + outer_loss = fmodel(optimal_inner_params, xs).mean() + + grads = torch.autograd.grad(outer_loss, params) + updates, optim_state = optim.update(grads, optim_state) + params = torchopt.apply_updates(params, updates) + + xs = xs.numpy() + ys = ys.numpy() + + def outer_level(p, xs, ys): + optimal_params = inner_solver_jax(copy.deepcopy(p), p, xs, ys) + outer_loss = jax_model(optimal_params, xs).mean() + return outer_loss + + grads_jax = jax.grad(outer_level, argnums=0)(jax_params, xs, ys) + updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax) # get updates + jax_params = optax.apply_updates(jax_params, updates_jax) + + jax_params_as_tensor = tuple( + nn.Parameter(torch.tensor(np.asarray(jax_params[j]), dtype=dtype)) for j in jax_params + ) + + for p, p_ref in zip(params, jax_params_as_tensor): + helpers.assert_all_close(p, p_ref) + + +@helpers.parametrize( + dtype=[torch.float64, torch.float32], + lr=[1e-3, 1e-4], + inner_lr=[2e-2, 2e-3], + inner_update=[20, 50, 100], +) +def test_imaml_module(dtype: torch.dtype, lr: float, inner_lr: float, inner_update: int) -> None: + np_dtype = helpers.dtype_torch2numpy(dtype) + + jax_model, jax_params = get_model_jax(dtype=np_dtype) + model, loader = get_model_torch(device='cpu', dtype=dtype) + + class InnerNet(ImplicitMetaGradientModule): + def __init__(self, meta_model): + super().__init__() + self.meta_model = meta_model + self.model = torchopt.module_clone(meta_model, by='deepcopy', detach_buffers=True) + + def forward(self, x): + return self.model(x) + + def objective(self, x, y): + y_pred = self.model(x) + loss = F.cross_entropy(y_pred, y) + regularization_loss = 0 + for p1, p2 in zip(self.parameters(), self.meta_parameters()): + regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2)) + loss = loss + regularization_loss + return loss + + def solve(self, x, y): + params = tuple(self.parameters()) + optim_inner = torchopt.SGD(params, lr=inner_lr) + with torch.enable_grad(): + # Temporarily enable gradient computation for conducting the optimization + for _ in range(inner_update): + loss = self.objective(x, y) + optim_inner.zero_grad() + loss.backward(inputs=params) + optim_inner.step() + + return self, (0, {'a': 1, 'b': 2}) + + outer_optim = torchopt.SGD(model.parameters(), lr) + + optim_jax = optax.sgd(lr) + optim_state_jax = optim_jax.init(jax_params) + + def imaml_objective_jax(params, meta_params, x, y): + y_pred = jax_model(params, x) + loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(y_pred, y)) + regularization_loss = 0 + for p1, p2 in zip(params.values(), meta_params.values()): + regularization_loss += 0.5 * jnp.sum(jnp.square(p1 - p2)) + loss = loss + regularization_loss + return loss + + @jaxopt.implicit_diff.custom_root(jax.grad(imaml_objective_jax, argnums=0), has_aux=True) + def inner_solver_jax(params, meta_params, x, y): + """Solve ridge regression by conjugate gradient.""" + # Initial functional optimizer based on torchopt + optimizer = optax.sgd(inner_lr) + opt_state = optimizer.init(params) + + def compute_loss(params, meta_params, x, y): + pred = jax_model(params, x) + loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(pred, y)) + # Compute regularization loss + regularization_loss = 0 + for p1, p2 in zip(params.values(), meta_params.values()): + regularization_loss += 0.5 * jnp.sum(jnp.square(p1 - p2)) + final_loss = loss + regularization_loss + return final_loss + + for i in range(inner_update): + grads = jax.grad(compute_loss)(params, meta_params, x, y) # compute gradients + updates, opt_state = optimizer.update(grads, opt_state) # get updates + params = optax.apply_updates(params, updates) + return params, (0, {'a': 1, 'b': 2}) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + inner_model = InnerNet(model) + optimal_inner_model, aux = inner_model.solve(xs, ys) + assert aux == (0, {'a': 1, 'b': 2}) + outer_loss = optimal_inner_model(xs).mean() + + outer_optim.zero_grad() + outer_loss.backward() + outer_optim.step() + + xs = xs.numpy() + ys = ys.numpy() + + def outer_level(p, xs, ys): + optimal_params, aux = inner_solver_jax(copy.deepcopy(p), p, xs, ys) + assert aux == (0, {'a': 1, 'b': 2}) + outer_loss = jax_model(optimal_params, xs).mean() + return outer_loss + + grads_jax = jax.grad(outer_level, argnums=0)(jax_params, xs, ys) + updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax) # get updates + jax_params = optax.apply_updates(jax_params, updates_jax) + + jax_params_as_tensor = tuple( + nn.Parameter(torch.tensor(np.asarray(jax_params[j]), dtype=dtype)) for j in jax_params + ) + + for p, p_ref in zip(model.parameters(), jax_params_as_tensor): + helpers.assert_all_close(p, p_ref) + + +@helpers.parametrize( + dtype=[torch.float64, torch.float32], + lr=[1e-3, 1e-4], +) +def test_rr_solve_cg( + dtype: torch.dtype, + lr: float, +) -> None: + helpers.seed_everything(42) + np_dtype = helpers.dtype_torch2numpy(dtype) + input_size = 10 + + init_params_torch = torch.randn(input_size, dtype=dtype) + l2reg_torch = torch.rand(1, dtype=dtype).squeeze_().requires_grad_(True) + + init_params_jax = jnp.array(init_params_torch.detach().numpy(), dtype=np_dtype) + l2reg_jax = jnp.array(l2reg_torch.detach().numpy(), dtype=np_dtype) + + loader = get_rr_dataset_torch() + + optim = torchopt.sgd(lr) + optim_state = optim.init(l2reg_torch) + + optim_jax = optax.sgd(lr) + optim_state_jax = optim_jax.init(l2reg_jax) + + def ridge_objective_torch(params, l2reg, data): + """Ridge objective function.""" + X_tr, y_tr = data + residuals = X_tr @ params - y_tr + regularization_loss = 0.5 * l2reg * torch.sum(torch.square(params)) + return 0.5 * torch.mean(torch.square(residuals)) + regularization_loss + + @torchopt.diff.implicit.custom_root(functorch.grad(ridge_objective_torch, argnums=0), argnums=1) + def ridge_solver_torch_cg(params, l2reg, data): + """Solve ridge regression by conjugate gradient.""" + X_tr, y_tr = data + + def matvec(u): + return X_tr.T @ (X_tr @ u) + + solve = torchopt.linear_solve.solve_cg( + ridge=len(y_tr) * l2reg.item(), + init=params, + maxiter=20, + ) + + return solve(matvec=matvec, b=X_tr.T @ y_tr) + + def ridge_objective_jax(params, l2reg, X_tr, y_tr): + """Ridge objective function.""" + residuals = X_tr @ params - y_tr + regularization_loss = 0.5 * l2reg * jnp.sum(jnp.square(params)) + return 0.5 * jnp.mean(jnp.square(residuals)) + regularization_loss + + @jaxopt.implicit_diff.custom_root(jax.grad(ridge_objective_jax, argnums=0)) + def ridge_solver_jax_cg(params, l2reg, X_tr, y_tr): + """Solve ridge regression by conjugate gradient.""" + + def matvec(u): + return X_tr.T @ (X_tr @ u) + + return jaxopt.linear_solve.solve_cg( + matvec=matvec, + b=X_tr.T @ y_tr, + ridge=len(y_tr) * l2reg.item(), + init=params, + maxiter=20, + ) + + for xs, ys, xq, yq in loader: + xs = xs.to(dtype=dtype) + ys = ys.to(dtype=dtype) + xq = xq.to(dtype=dtype) + yq = yq.to(dtype=dtype) + + w_fit = ridge_solver_torch_cg(init_params_torch, l2reg_torch, (xs, ys)) + outer_loss = F.mse_loss(xq @ w_fit, yq) + + grads, *_ = torch.autograd.grad(outer_loss, l2reg_torch) + updates, optim_state = optim.update(grads, optim_state) + l2reg_torch = torchopt.apply_updates(l2reg_torch, updates) + + xs = jnp.array(xs.numpy(), dtype=np_dtype) + ys = jnp.array(ys.numpy(), dtype=np_dtype) + xq = jnp.array(xq.numpy(), dtype=np_dtype) + yq = jnp.array(yq.numpy(), dtype=np_dtype) + + def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq): + w_fit = ridge_solver_jax_cg(params_jax, l2reg_jax, xs, ys) + y_pred = xq @ w_fit + loss_value = jnp.mean(jnp.square(y_pred - yq)) + return loss_value + + grads_jax = jax.grad(outer_level, argnums=1)(init_params_jax, l2reg_jax, xs, ys, xq, yq) + updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax) # get updates + l2reg_jax = optax.apply_updates(l2reg_jax, updates_jax) + + l2reg_jax_as_tensor = torch.tensor(np.asarray(l2reg_jax), dtype=dtype) + helpers.assert_all_close(l2reg_torch, l2reg_jax_as_tensor) + + +@helpers.parametrize( + dtype=[torch.float64, torch.float32], + lr=[1e-3, 1e-4], + ns=[True, False], +) +def test_rr_solve_inv( + dtype: torch.dtype, + lr: float, + ns: bool, +) -> None: + if dtype == torch.float64 and ns: + pytest.skip('Neumann Series test skips torch.float64 due to numerical stability.') + helpers.seed_everything(42) + np_dtype = helpers.dtype_torch2numpy(dtype) + input_size = 10 + + init_params_torch = torch.randn(input_size, dtype=dtype) + l2reg_torch = torch.rand(1, dtype=dtype).squeeze_().requires_grad_(True) + + init_params_jax = jnp.array(init_params_torch.detach().numpy(), dtype=np_dtype) + l2reg_jax = jnp.array(l2reg_torch.detach().numpy(), dtype=np_dtype) + + loader = get_rr_dataset_torch() + + optim = torchopt.sgd(lr) + optim_state = optim.init(l2reg_torch) + + optim_jax = optax.sgd(lr) + optim_state_jax = optim_jax.init(l2reg_jax) + + def ridge_objective_torch(params, l2reg, data): + """Ridge objective function.""" + X_tr, y_tr = data + residuals = X_tr @ params - y_tr + regularization_loss = 0.5 * l2reg * torch.sum(torch.square(params)) + return 0.5 * torch.mean(torch.square(residuals)) + regularization_loss + + @torchopt.diff.implicit.custom_root(functorch.grad(ridge_objective_torch, argnums=0), argnums=1) + def ridge_solver_torch_inv(params, l2reg, data): + """Solve ridge regression by conjugate gradient.""" + X_tr, y_tr = data + + def matvec(u): + return X_tr.T @ (X_tr @ u) + + solve = torchopt.linear_solve.solve_inv( + matvec=matvec, + b=X_tr.T @ y_tr, + ridge=len(y_tr) * l2reg.item(), + ns=ns, + ) + + return solve(matvec=matvec, b=X_tr.T @ y_tr) + + def ridge_objective_jax(params, l2reg, X_tr, y_tr): + """Ridge objective function.""" + residuals = X_tr @ params - y_tr + regularization_loss = 0.5 * l2reg * jnp.sum(jnp.square(params)) + return 0.5 * jnp.mean(jnp.square(residuals)) + regularization_loss + + @jaxopt.implicit_diff.custom_root(jax.grad(ridge_objective_jax, argnums=0)) + def ridge_solver_jax_inv(params, l2reg, X_tr, y_tr): + """Solve ridge regression by conjugate gradient.""" + + def matvec(u): + return X_tr.T @ (X_tr @ u) + + return jaxopt.linear_solve.solve_inv( + matvec=matvec, + b=X_tr.T @ y_tr, + ridge=len(y_tr) * l2reg.item(), + ) + + for xs, ys, xq, yq in loader: + xs = xs.to(dtype=dtype) + ys = ys.to(dtype=dtype) + xq = xq.to(dtype=dtype) + yq = yq.to(dtype=dtype) + + w_fit = ridge_solver_torch_inv(init_params_torch, l2reg_torch, (xs, ys)) + outer_loss = F.mse_loss(xq @ w_fit, yq) + + grads, *_ = torch.autograd.grad(outer_loss, l2reg_torch) + updates, optim_state = optim.update(grads, optim_state) + l2reg_torch = torchopt.apply_updates(l2reg_torch, updates) + + xs = jnp.array(xs.numpy(), dtype=np_dtype) + ys = jnp.array(ys.numpy(), dtype=np_dtype) + xq = jnp.array(xq.numpy(), dtype=np_dtype) + yq = jnp.array(yq.numpy(), dtype=np_dtype) + + def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq): + w_fit = ridge_solver_jax_inv(params_jax, l2reg_jax, xs, ys) + y_pred = xq @ w_fit + loss_value = jnp.mean(jnp.square(y_pred - yq)) + return loss_value + + grads_jax = jax.grad(outer_level, argnums=1)(init_params_jax, l2reg_jax, xs, ys, xq, yq) + updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax) # get updates + l2reg_jax = optax.apply_updates(l2reg_jax, updates_jax) + + l2reg_jax_as_tensor = torch.tensor(np.asarray(l2reg_jax), dtype=dtype) + helpers.assert_all_close(l2reg_torch, l2reg_jax_as_tensor) diff --git a/torchopt/_src/typing.py b/tests/test_meta_optim.py similarity index 66% rename from torchopt/_src/typing.py rename to tests/test_meta_optim.py index b2104682..5916574e 100644 --- a/torchopt/_src/typing.py +++ b/tests/test_meta_optim.py @@ -13,16 +13,11 @@ # limitations under the License. # ============================================================================== -from typing import Any, Callable, Iterable, Mapping, TypeVar, Union +import helpers +import torchopt -from torch import Tensor +def test_filter_nones_in_params(): + model = helpers.get_models()[0] -Scalar = TypeVar('Scalar', float, int) -Numeric = Union[Tensor, Scalar] - -Schedule = Callable[[Numeric], Numeric] -ScalarOrSchedule = Union[float, Schedule] - -# mypy: ignore-errors -TensorTree = Union[Tensor, Iterable['TensorTree'], Mapping[Any, 'TensorTree']] + meta_adam = torchopt.MetaAdam(model) diff --git a/tests/test_optimizer.py b/tests/test_optim.py similarity index 85% rename from tests/test_optimizer.py rename to tests/test_optim.py index c0db3e34..fe1697c9 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optim.py @@ -13,8 +13,9 @@ # limitations under the License. # ============================================================================== -from typing import Tuple +from typing import Callable, Tuple +import functorch import pytest import torch import torch.nn.functional as F @@ -30,7 +31,7 @@ dampening=[0.0, 0.5], nesterov=[False, True], weight_decay=[0.0, 1e-2], - maximize=[False], # TODO: test maximize after PyTorch 1.13 + maximize=[False, True], ) def test_SGD( dtype: torch.dtype, @@ -364,3 +365,56 @@ def test_RMSProp( optim_ref.step() helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) + + +@helpers.parametrize( + dtype=[torch.float64, torch.float32], + lr=[1e-2, 1e-3], + optimizers=[ + (torchopt.sgd, torch.optim.SGD), + (torchopt.adam, torch.optim.Adam), + (torchopt.adamw, torch.optim.AdamW), + (torchopt.rmsprop, torch.optim.RMSprop), + ], + inplace=[True, False], + weight_decay=[0.0, 1e-2], +) +def test_FuncOptimizer( + dtype: torch.dtype, + lr: float, + optimizers: Tuple[Callable, torch.optim.Optimizer], + inplace: bool, + weight_decay: float, +) -> None: + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + torchopt_optimizer, torch_optimizer = optimizers + + fmodel, params, buffers = functorch.make_functional_with_buffers(model) + optim = torchopt.FuncOptimizer( + torchopt_optimizer( + lr=lr, + weight_decay=weight_decay, + ), + inplace=inplace, + ) + optim_ref = torch_optimizer( + model_ref.parameters(), + lr, + weight_decay=weight_decay, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = fmodel(params, buffers, xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + params = optim.step(loss, params) + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) diff --git a/tests/test_schedule.py b/tests/test_schedule.py index 971c0de4..67e3429a 100644 --- a/tests/test_schedule.py +++ b/tests/test_schedule.py @@ -13,8 +13,14 @@ # limitations under the License. # ============================================================================== +from typing import Callable, Tuple + +import functorch import numpy as np +import torch +import torch.nn.functional as F +import helpers import torchopt @@ -35,3 +41,64 @@ def test_linear_schedule() -> None: lr = schedule(i) lr_gt = init_value - gap_value * (i - transition_begin) / transition_steps assert np.allclose(lr, lr_gt) + + +@helpers.parametrize( + dtype=[torch.float64, torch.float32], + lr=[1e-2, 1e-3], + total_iters=[helpers.NUM_UPDATES, helpers.NUM_UPDATES * 2], + optimizers=[ + (torchopt.sgd, torch.optim.SGD), + (torchopt.adam, torch.optim.Adam), + (torchopt.adamw, torch.optim.AdamW), + (torchopt.rmsprop, torch.optim.RMSprop), + ], + inplace=[True, False], + weight_decay=[0.0, 1e-2], +) +def test_lr_linear_schedule( + dtype: torch.dtype, + lr: float, + total_iters: int, + optimizers: Tuple[Callable, torch.optim.Optimizer], + inplace: bool, + weight_decay: float, +) -> None: + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + torchopt_optimizer, torch_optimizer = optimizers + + fmodel, params, buffers = functorch.make_functional_with_buffers(model) + optim = torchopt_optimizer( + torchopt.schedule.linear_schedule( + init_value=lr, end_value=0.1 * lr, transition_steps=total_iters, transition_begin=0 + ), + weight_decay=weight_decay, + ) + optim_state = optim.init(params) + optim_ref = torch_optimizer( + model_ref.parameters(), + lr, + weight_decay=weight_decay, + ) + torch_scheduler = torch.optim.lr_scheduler.LinearLR( + optim_ref, start_factor=1.0, end_factor=0.1, total_iters=total_iters + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = fmodel(params, buffers, xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + grads = torch.autograd.grad(loss, params, allow_unused=True) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + params = torchopt.apply_updates(params, updates, inplace=inplace) + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + torch_scheduler.step() + + helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) diff --git a/tests/test_zero_order.py b/tests/test_zero_order.py new file mode 100644 index 00000000..32d3ae3b --- /dev/null +++ b/tests/test_zero_order.py @@ -0,0 +1,79 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import functorch +import torch +import torch.nn as nn +import torch.types + +import helpers +import torchopt + + +BATCH_SIZE = 8 +NUM_UPDATES = 5 + + +class FcNet(nn.Module): + def __init__(self, dim, out): + super().__init__() + self.fc = nn.Linear(in_features=dim, out_features=out, bias=True) + nn.init.ones_(self.fc.weight) + nn.init.zeros_(self.fc.bias) + + def forward(self, x): + return self.fc(x) + + +@helpers.parametrize( + dtype=[torch.float64, torch.float32], + lr=[1e-2, 1e-3], + method=['naive', 'forward', 'antithetic'], + sigma=[0.01, 0.1, 1], +) +def test_zero_order(dtype: torch.dtype, lr: float, method: str, sigma: float) -> None: + helpers.seed_everything(42) + input_size = 32 + output_size = 1 + batch_size = BATCH_SIZE + coef = 0.1 + num_iterations = NUM_UPDATES + num_samples = 500 + + model = FcNet(input_size, output_size) + + fmodel, params = functorch.make_functional(model) + x = torch.randn(batch_size, input_size) * coef + y = torch.randn(input_size) * coef + distribution = torch.distributions.Normal(loc=0, scale=1) + + @torchopt.diff.zero_order.zero_order( + distribution=distribution, method=method, argnums=0, sigma=sigma, num_samples=num_samples + ) + def forward_process(params, fn, x, y): + y_pred = fn(params, x) + loss = torch.mean((y - y_pred) ** 2) + return loss + + optimizer = torchopt.adam(lr=lr) + opt_state = optimizer.init(params) + + for i in range(num_iterations): + opt_state = optimizer.init(params) # init optimizer + loss = forward_process(params, fmodel, x, y) # compute loss + + grads = torch.autograd.grad(loss, params) # compute gradients + updates, opt_state = optimizer.update(grads, opt_state) # get updates + params = torchopt.apply_updates(params, updates) # update network parameters diff --git a/torchopt/_C/adam_op.pyi b/torchopt/_C/adam_op.pyi index 7b98a576..39d51a5a 100644 --- a/torchopt/_C/adam_op.pyi +++ b/torchopt/_C/adam_op.pyi @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== +# pylint: disable=all # isort: off from typing import Tuple @@ -29,9 +30,9 @@ def forward_( eps_root: float, count: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... -def forwardMu(updates: torch.Tensor, mu: torch.Tensor, b1: float) -> torch.Tensor: ... -def forwardNu(updates: torch.Tensor, nu: torch.Tensor, b2: float) -> torch.Tensor: ... -def forwardUpdates( +def forward_mu(updates: torch.Tensor, mu: torch.Tensor, b1: float) -> torch.Tensor: ... +def forward_nu(updates: torch.Tensor, nu: torch.Tensor, b2: float) -> torch.Tensor: ... +def forward_updates( new_mu: torch.Tensor, new_nu: torch.Tensor, b1: float, @@ -40,13 +41,13 @@ def forwardUpdates( eps_root: float, count: int, ) -> torch.Tensor: ... -def backwardMu( +def backward_mu( dmu: torch.Tensor, updates: torch.Tensor, mu: torch.Tensor, b1: float ) -> Tuple[torch.Tensor, torch.Tensor]: ... -def backwardNu( +def backward_nu( dnu: torch.Tensor, updates: torch.Tensor, nu: torch.Tensor, b2: float ) -> Tuple[torch.Tensor, torch.Tensor]: ... -def backwardUpdates( +def backward_updates( dupdates: torch.Tensor, updates: torch.Tensor, new_mu: torch.Tensor, diff --git a/torchopt/__init__.py b/torchopt/__init__.py index ab7a5a4d..db78f217 100644 --- a/torchopt/__init__.py +++ b/torchopt/__init__.py @@ -14,12 +14,27 @@ # ============================================================================== """TorchOpt: a high-performance optimizer library built upon PyTorch.""" -from torchopt._src import accelerated_op_available, clip, combine, hook, schedule, visual -from torchopt._src.alias import adam, adamw, rmsprop, sgd -from torchopt._src.clip import clip_grad_norm -from torchopt._src.combine import chain -from torchopt._src.optimizer import SGD, Adam, AdamW, Optimizer, RMSProp, RMSprop, meta -from torchopt._src.optimizer.meta import ( +from torchopt import ( + clip, + combine, + diff, + distributed, + hook, + linear_solve, + nn, + pytree, + schedule, + typing, + visual, +) +from torchopt.accelerated_op import is_available as accelerated_op_available +from torchopt.alias import adam, adamw, rmsprop, sgd +from torchopt.clip import clip_grad_norm +from torchopt.combine import chain +from torchopt.hook import register_hook +from torchopt.optim import SGD, Adam, AdamW, Optimizer, RMSProp, RMSprop, meta +from torchopt.optim.func import FuncOptimizer +from torchopt.optim.meta import ( MetaAdam, MetaAdamW, MetaOptimizer, @@ -27,23 +42,28 @@ MetaRMSprop, MetaSGD, ) -from torchopt._src.update import apply_updates -from torchopt._src.utils import extract_state_dict, recover_state_dict, stop_gradient +from torchopt.transform import nan_to_num +from torchopt.update import apply_updates +from torchopt.utils import ( + extract_state_dict, + module_clone, + module_detach_, + recover_state_dict, + stop_gradient, +) from torchopt.version import __version__ __all__ = [ 'accelerated_op_available', - 'clip', - 'combine', - 'hook', - 'schedule', - 'visual', + 'diff', 'adam', 'adamw', 'rmsprop', 'sgd', 'clip_grad_norm', + 'nan_to_num', + 'register_hook', 'chain', 'Optimizer', 'SGD', @@ -57,8 +77,11 @@ 'MetaAdamW', 'MetaRMSProp', 'MetaRMSprop', + 'FuncOptimizer', 'apply_updates', 'extract_state_dict', 'recover_state_dict', 'stop_gradient', + 'module_clone', + 'module_detach_', ] diff --git a/torchopt/_src/alias.py b/torchopt/_src/alias.py deleted file mode 100644 index 40b2e92d..00000000 --- a/torchopt/_src/alias.py +++ /dev/null @@ -1,468 +0,0 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -# This file is modified from: -# https://github.com/deepmind/optax/blob/master/optax/_src/alias.py -# ============================================================================== -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -# pylint: disable=invalid-name - -from typing import Any, Callable, Optional, Tuple, Union - -from torchopt._src import base, combine, transform -from torchopt._src.typing import ScalarOrSchedule - - -def _flip_sign_and_weight_decay(weight_decay: float = 0.0, maximize=False): - if not 0.0 <= weight_decay: # pylint: disable=unneeded-not - raise ValueError(f'Invalid weight_decay value: {weight_decay}') - - if not maximize and weight_decay == 0.0: - return base.identity() - - def init_fn(params): # pylint: disable=unused-argument - return base.EmptyState() - - if not maximize: # gradient descent - - def update_fn(updates, state, *, params=None, inplace=True): - assert params is not None, ( - 'Parameters are required for weight decay. ' - 'Call `update(updates, state, params=params)` instead.' - ) - - if inplace: - - def f(g, p): - if g is not None: - if g.requires_grad: - return g.add_(p, alpha=weight_decay) - return g.add_(p.data, alpha=weight_decay) - return None - - else: - - def f(g, p): - return g.add(p, alpha=weight_decay) if g is not None else None - - updates = transform.map_flattened(f, updates, params) - return updates, state - - else: # gradient ascent - - if weight_decay == 0.0: - # pylint: disable-next=unused-argument - def update_fn(updates, state, *, params=None, inplace=True): - if inplace: - - def f(g): - return g.neg_() if g is not None else None - - else: - - def f(g): - return g.neg() if g is not None else None - - updates = transform.map_flattened(f, updates) - return updates, state - - else: - - def update_fn(updates, state, *, params=None, inplace=True): - assert params is not None, ( - 'Parameters are required for weight decay. ' - 'Call `update(updates, state, params=params)` instead.' - ) - - if inplace: - - def f(g, p): - if g is not None: - if g.requires_grad: - return g.neg_().add_(p, alpha=weight_decay) - return g.neg_().add_(p.data, alpha=weight_decay) - return None - - else: - - def f(g, p): - return g.neg().add_(p, alpha=weight_decay) if g is not None else None - - updates = transform.map_flattened(f, updates, params) - return updates, state - - return base.GradientTransformation(init_fn, update_fn) - - -def _scale_by_neg_lr(lr: ScalarOrSchedule): - if not (callable(lr) or 0.0 <= lr): - raise ValueError(f'Invalid learning rate: {lr}') - - if callable(lr): - - def schedule_wrapper(count): - def f(scaled_lr): - return -scaled_lr - - return transform.map_flattened(f, lr(count)) # type: ignore[operator] - - return transform._scale_by_schedule( # pylint: disable=protected-access - schedule_wrapper, already_flattened=True - ) - return transform._scale(-lr, already_flattened=True) # pylint: disable=protected-access - - -# pylint: disable-next=too-many-arguments -def adam( - lr: ScalarOrSchedule = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), - eps: float = 1e-8, - weight_decay: float = 0.0, - *, - eps_root: float = 0.0, - moment_requires_grad: bool = False, - maximize: bool = False, - use_accelerated_op: bool = False, -) -> base.GradientTransformation: - """The functional Adam optimizer. - - Adam is an SGD variant with learning rate adaptation. The *learning rate* used for each weight - is computed from estimates of first- and second-order moments of the gradients (using suitable - exponential moving averages). - - References: - - Kingma et al, 2014: https://arxiv.org/abs/1412.6980 - - Args: - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to avoid - dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - moment_requires_grad: (default: :data:`False`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. - - Returns: - The corresponding :class:`GradientTransformation` instance. - """ - b1, b2 = betas - # pylint: disable=unneeded-not - if not (callable(lr) or 0.0 <= lr): - raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= b1 < 1.0: - raise ValueError(f'Invalid beta parameter at index 0: {b1}') - if not 0.0 <= b2 < 1.0: - raise ValueError(f'Invalid beta parameter at index 1: {b2}') - if not 0.0 <= weight_decay: - raise ValueError(f'Invalid weight_decay value: {weight_decay}') - # pylint: enable=unneeded-not - - if use_accelerated_op: - adam_scaler = transform._scale_by_accelerated_adam # pylint: disable=protected-access - else: - adam_scaler = transform._scale_by_adam # pylint: disable=protected-access - - return transform.with_flattened_tree( - combine.chain( - _flip_sign_and_weight_decay(weight_decay=weight_decay, maximize=maximize), - adam_scaler( - b1=b1, - b2=b2, - eps=eps, - eps_root=eps_root, - moment_requires_grad=moment_requires_grad, - already_flattened=True, - ), - _scale_by_neg_lr(lr), - ) - ) - - -# pylint: disable-next=too-many-arguments -def adamw( - lr: ScalarOrSchedule = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), - eps: float = 1e-8, - weight_decay: float = 1e-2, - *, - eps_root: float = 0.0, - mask: Optional[Union[Any, Callable[['base.Params'], Any]]] = None, - moment_requires_grad: bool = False, - maximize: bool = False, - use_accelerated_op: bool = False, -) -> base.GradientTransformation: - """Adam with weight decay regularization. - - AdamW uses weight decay to regularize learning towards small weights, as - this leads to better generalization. In SGD you can also use L2 regularization - to implement this as an additive loss term, however L2 regularization - does not behave as intended for adaptive gradient algorithms such as Adam. - - References: - - Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101 - - Args: - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`1e-2`) - Strength of the weight decay regularization. Note that this weight decay is multiplied - with the learning rate. This is consistent with other frameworks such as PyTorch, but - different from (Loshchilov et al, 2019) where the weight decay is only multiplied with - the "schedule multiplier", but not the base learning rate. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to avoid - dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - mask: (default: :data:`None`) - A tree with same structure as (or a prefix of) the params PyTree, or a Callable that - returns such a pytree given the params/updates. The leaves should be booleans, - :data:`True` for leaves/subtrees you want to apply the weight decay to, and - :data:`False` for those you want to skip. Note that the Adam gradient - transformations are applied to all parameters. - moment_requires_grad: (default: :data:`False`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. - - Returns: - The corresponding :class:`GradientTransformation` instance. - """ - b1, b2 = betas - # pylint: disable=unneeded-not - if not (callable(lr) or 0.0 <= lr): - raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= b1 < 1.0: - raise ValueError(f'Invalid beta parameter at index 0: {b1}') - if not 0.0 <= b2 < 1.0: - raise ValueError(f'Invalid beta parameter at index 1: {b2}') - if not 0.0 <= weight_decay: - raise ValueError(f'Invalid weight_decay value: {weight_decay}') - # pylint: enable=unneeded-not - - if use_accelerated_op: - adam_scaler = transform._scale_by_accelerated_adam # pylint: disable=protected-access - else: - adam_scaler = transform._scale_by_adam # pylint: disable=protected-access - - return transform.with_flattened_tree( - combine.chain( - _flip_sign_and_weight_decay(weight_decay=0.0, maximize=maximize), - adam_scaler( - b1=b1, - b2=b2, - eps=eps, - eps_root=eps_root, - moment_requires_grad=moment_requires_grad, - already_flattened=True, - ), - transform._add_decayed_weights( # pylint: disable=protected-access - weight_decay=weight_decay, - mask=mask, - already_flattened=True, - ), - _scale_by_neg_lr(lr), - ) - ) - - -# pylint: disable-next=too-many-arguments -def rmsprop( - lr: ScalarOrSchedule = 1e-2, - alpha: float = 0.9, - eps: float = 1e-8, - weight_decay: float = 0.0, - momentum: float = 0.0, - centered: bool = False, - *, - initial_scale: float = 0.0, - nesterov: bool = False, - maximize: bool = False, -) -> base.GradientTransformation: - """The functional version of the RMSProp optimizer. - - RMSProp is an SGD variant with learning rate adaptation. The *learning rate* used for each - weight is scaled by a suitable estimate of the magnitude of the gradients on previous steps. - Several variants of RMSProp can be found in the literature. This alias provides an easy to - configure RMSProp optimizer that can be used to switch between several of these variants. - - References: - - Tieleman and Hinton, 2012: http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf - - Graves, 2013: https://arxiv.org/abs/1308.0850 - - Args: - lr: (default: :const:`1e-2`) - This is a fixed global scaling factor. - alpha: (default: :const:`0.99`) - Smoothing constant, the decay used to track the magnitude of previous gradients. - eps: (default: :const:`1e-8`) - A small numerical constant to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - centered: (default: :data:`False`) - If :data:`True`, use the variance of the past gradients to rescale the latest - gradients. - initial_scale: (default: :data:`0.0`) - Initialization of accumulators tracking the magnitude of previous updates. PyTorch - uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When reproducing results from a - paper, verify the value used by the authors. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - - Returns: - The corresponding :class:`GradientTransformation` instance. - """ - # pylint: disable=unneeded-not - if not (callable(lr) or 0.0 <= lr): - raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= alpha: - raise ValueError(f'Invalid alpha value: {alpha}') - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= momentum: - raise ValueError(f'Invalid momentum value: {momentum}') - if not 0.0 <= weight_decay: - raise ValueError(f'Invalid weight_decay value: {weight_decay}') - # pylint: enable=unneeded-not - - if centered: - rmsprop_scaler = transform._scale_by_stddev # pylint: disable=protected-access - else: - rmsprop_scaler = transform._scale_by_rms # pylint: disable=protected-access - - return transform.with_flattened_tree( - combine.chain( - _flip_sign_and_weight_decay(weight_decay=weight_decay, maximize=maximize), - rmsprop_scaler( - alpha=alpha, - eps=eps, - initial_scale=initial_scale, - already_flattened=True, - ), - transform._trace( # pylint: disable=protected-access - momentum=momentum, - nesterov=nesterov, - already_flattened=True, - ), - _scale_by_neg_lr(lr), - ) - ) - - -def sgd( - lr: ScalarOrSchedule, - momentum: float = 0.0, - dampening: float = 0.0, - weight_decay: float = 0.0, - nesterov: bool = False, - *, - moment_requires_grad: bool = False, - maximize: bool = False, -) -> base.GradientTransformation: - """The functional version of the canonical Stochastic Gradient Descent optimizer. - - This implements stochastic gradient descent. It also includes support for momentum, and nesterov - acceleration, as these are standard practice when using stochastic gradient descent to train - deep neural networks. - - References: - - Sutskever et al, 2013: http://proceedings.mlr.press/v28/sutskever13.pdf - - Args: - lr: This is a fixed global scaling factor. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - dampening: (default: :const:`0.0`) - Dampening for momentum. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - moment_requires_grad: (default: :data:`False`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - - Returns: - The corresponding :class:`GradientTransformation` instance. - """ - # pylint: disable=unneeded-not - if not (callable(lr) or 0.0 <= lr): - raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= momentum: - raise ValueError(f'Invalid momentum value: {momentum}') - if not 0.0 <= weight_decay: - raise ValueError(f'Invalid weight_decay value: {weight_decay}') - if nesterov and (momentum <= 0.0 or dampening != 0.0): - raise ValueError('Nesterov momentum requires a momentum and zero dampening') - # pylint: enable=unneeded-not - - return transform.with_flattened_tree( - combine.chain( - _flip_sign_and_weight_decay(weight_decay=weight_decay, maximize=maximize), - transform._trace( # pylint: disable=protected-access - momentum=momentum, - dampening=dampening, - nesterov=nesterov, - moment_requires_grad=moment_requires_grad, - already_flattened=True, - ), - _scale_by_neg_lr(lr), - ) - ) diff --git a/torchopt/_src/transform.py b/torchopt/_src/transform.py deleted file mode 100644 index 15bf11ed..00000000 --- a/torchopt/_src/transform.py +++ /dev/null @@ -1,897 +0,0 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -# This file is modified from: -# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py -# ============================================================================== -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -# pylint: disable=invalid-name - -from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Union - -import torch - -from torchopt._src import base -from torchopt._src.typing import Schedule -from torchopt._src.utils import pytree - - -ScaleState = base.EmptyState -INT32_MAX = torch.iinfo(torch.int32).max -TRIPLE_PYTREEDEF = pytree.tree_structure((0, 1, 2)) - - -def map_flattened(func: Callable, *args: Any) -> List[Any]: - """Apply a function to each element of a flattened list.""" - return list(map(func, *args)) - - -def with_flattened_tree(inner: base.GradientTransformation) -> base.GradientTransformation: - # pylint: disable-next=line-too-long - """Wraps around the inner transformation that manipulates the flattened tree structure (:class:``list``).""" - - def init_fn(params): - return inner.init(pytree.tree_leaves(params)) - - def update_fn(updates, state, *, params=None, inplace=True): - flattened_updates, treedef = pytree.tree_flatten(updates) - if params is not None: - params = pytree.tree_leaves(params) - - flattened_updates, state = inner.update( - flattened_updates, state, params=params, inplace=inplace - ) - updates = pytree.tree_unflatten(treedef, flattened_updates) - - return updates, state - - return base.GradientTransformation(init_fn, update_fn) - - -def inc_count(updates: base.Updates, count: Sequence[torch.Tensor]) -> Sequence[torch.Tensor]: - """Increments int counter by one. - - Returns: - A counter incremeted by one, or max_int if the maximum precision is reached. - """ - return _inc_count(updates=updates, count=count, already_flattened=False) - - -def _inc_count( - updates: base.Updates, count: Sequence[torch.Tensor], *, already_flattened: bool = False -) -> Sequence[torch.Tensor]: - def f(c, g): - return c + (c != INT32_MAX).to(torch.int32) if g is not None else c - - if already_flattened: - return map_flattened(f, count, updates) - return pytree.tree_map(f, count, updates) - - -def scale(step_size: float) -> base.GradientTransformation: - """Scale updates by some fixed scalar ``step_size``. - - Args: - step_size: A scalar corresponding to a fixed scaling factor for updates. - - Returns: - An ``(init_fn, update_fn)`` tuple. - """ - return _scale(step_size=step_size, already_flattened=False) - - -def _scale(step_size: float, *, already_flattened: bool = False) -> base.GradientTransformation: - if already_flattened: - tree_map = map_flattened - else: - tree_map = pytree.tree_map - - def init_fn(params): # pylint: disable=unused-argument - return ScaleState() - - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument - if inplace: - - def f(g): - return g.mul_(step_size) if g is not None else None - - else: - - def f(g): - return g.mul(step_size) if g is not None else None - - updates = tree_map(f, updates) - return updates, state - - return base.GradientTransformation(init_fn, update_fn) - - -class ScaleByScheduleState(NamedTuple): - """Maintains count for scale scheduling.""" - - count: Sequence[torch.Tensor] # type: ignore - - -def scale_by_schedule(step_size_fn: Schedule) -> base.GradientTransformation: - """Scale updates using a custom schedule for the ``step_size``. - - Args: - step_size_fn: - A function that takes an update count as input and proposes the ``step_size`` to - multiply the updates by. - - Returns: - An ``(init_fn, update_fn)`` tuple. - """ - return _scale_by_schedule(step_size_fn=step_size_fn, already_flattened=False) - - -def _scale_by_schedule( - step_size_fn: Schedule, *, already_flattened: bool = False -) -> base.GradientTransformation: - if already_flattened: - tree_map = map_flattened - else: - tree_map = pytree.tree_map - - def init_fn(params): - zero = tree_map( # count init - lambda t: torch.zeros(1, dtype=torch.int32, device=t.device), params - ) - return ScaleByScheduleState(count=zero) - - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument - step_size = step_size_fn(state.count) - - if inplace: - - def f(g): - return g.mul_(step_size) if g is not None else None - - else: - - def f(g): - return g.mul(step_size) if g is not None else None - - updates = tree_map(f, updates) - return updates, ScaleByScheduleState(count=inc_count(updates, state.count)) - - return base.GradientTransformation(init_fn, update_fn) - - -def _update_moment(updates, moments, decay, *, order, inplace=True, already_flattened=False): - """Compute the exponential moving average of the ``order``-th moment.""" - assert order in (1, 2) - - if inplace: - - if order == 2: - - def f(g, t): - return t.mul_(decay).addcmul_(g, g, value=1 - decay) if g is not None else t - - else: - - def f(g, t): - return t.mul_(decay).add_(g, alpha=1 - decay) if g is not None else t - - else: - - if order == 2: - - def f(g, t): - return t.mul(decay).addcmul_(g, g, value=1 - decay) if g is not None else t - - else: - - def f(g, t): - return t.mul(decay).add_(g, alpha=1 - decay) if g is not None else t - - if already_flattened: - return map_flattened(f, updates, moments) - return pytree.tree_map(f, updates, moments) - - -class ScaleByAdamState(NamedTuple): - """State for the Adam algorithm.""" - - mu: base.Updates - nu: base.Updates - count: Sequence[torch.Tensor] # type: ignore - - -def _bias_correction(moment, decay, count, *, already_flattened=False): - """Perform bias correction. This becomes a no-op as count goes to infinity.""" - - def f(t, c): - return t.div(1 - decay**c) - - if already_flattened: - return map_flattened(f, moment, count) - return pytree.tree_map(f, moment, count) - - -def scale_by_adam( - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - moment_requires_grad: bool = False, -) -> base.GradientTransformation: - """Rescale updates according to the Adam algorithm. - - References: - [Kingma et al, 2014](https://arxiv.org/abs/1412.6980) - - Args: - b1: (default: :const:`0.9`) - Decay rate for the exponentially weighted average of grads. - b2: (default: :const:`0.999`) - Decay rate for the exponentially weighted average of squared grads. - eps: (default: :const:`1e-8`) - Term added to the denominator to improve numerical stability. - eps_root: (default: :const:`0.0`) - Term added to the denominator inside the square-root to improve - numerical stability when back-propagating gradients through the rescaling. - moment_requires_grad: (default: :data:`False`) - if :data:`True`, states will be created with flag `requires_grad = True`. - - Returns: - An (init_fn, update_fn) tuple. - """ - return _scale_by_adam( - b1=b1, - b2=b2, - eps=eps, - eps_root=eps_root, - moment_requires_grad=moment_requires_grad, - already_flattened=False, - ) - - -def _scale_by_adam( - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - moment_requires_grad: bool = False, - *, - already_flattened: bool = False, -) -> base.GradientTransformation: - # pylint: disable=unneeded-not - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= b1 < 1.0: - raise ValueError(f'Invalid beta parameter at index 0: {b1}') - if not 0.0 <= b2 < 1.0: - raise ValueError(f'Invalid beta parameter at index 1: {b2}') - # pylint: enable=unneeded-not - - if already_flattened: - tree_map = map_flattened - else: - tree_map = pytree.tree_map - - def init_fn(params): - zero = tree_map( # count init - lambda t: torch.zeros(1, dtype=torch.int32, device=t.device), params - ) - mu = tree_map( # first moment - lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params - ) - nu = tree_map( # second moment - lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params - ) - return ScaleByAdamState(mu=mu, nu=nu, count=zero) - - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument - mu = _update_moment( - updates, state.mu, b1, order=1, inplace=inplace, already_flattened=already_flattened - ) - nu = _update_moment( - updates, state.nu, b2, order=2, inplace=inplace, already_flattened=already_flattened - ) - count_inc = _inc_count(updates, state.count, already_flattened=already_flattened) - mu_hat = _bias_correction(mu, b1, count_inc, already_flattened=already_flattened) - nu_hat = _bias_correction(nu, b2, count_inc, already_flattened=already_flattened) - - if inplace: - - def f(g, m, v): - return m.div_(v.add_(eps_root).sqrt_().add_(eps)) if g is not None else None - - else: - - def f(g, m, v): - return m.div(v.add(eps_root).sqrt_().add_(eps)) if g is not None else None - - updates = tree_map(f, updates, mu_hat, nu_hat) - return updates, ScaleByAdamState(mu=mu, nu=nu, count=count_inc) - - return base.GradientTransformation(init_fn, update_fn) - - -def scale_by_accelerated_adam( - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - moment_requires_grad: bool = False, -) -> base.GradientTransformation: - """Rescale updates according to the Adam algorithm. - - This function is accelerated by using some fused accelerated operators. - - References: - [Kingma et al, 2014](https://arxiv.org/abs/1412.6980) - - Args: - b1: (default: :const:`0.9`) - Decay rate for the exponentially weighted average of grads. - b2: (default: :const:`0.999`) - Decay rate for the exponentially weighted average of squared grads. - eps: (default: :const:`1e-8`) - Term added to the denominator to improve numerical stability. - eps_root: (default: :const:`0.0`) - Term added to the denominator inside the square-root to improve - numerical stability when back-propagating gradients through the rescaling. - moment_requires_grad: (default: :data:`False`) - if :data:`True`, states will be created with flag `requires_grad = True`. - - Returns: - An (init_fn, update_fn) tuple. - """ - return _scale_by_accelerated_adam( - b1=b1, - b2=b2, - eps=eps, - eps_root=eps_root, - moment_requires_grad=moment_requires_grad, - already_flattened=False, - ) - - -def _scale_by_accelerated_adam( - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - moment_requires_grad: bool = False, - *, - already_flattened: bool = False, -) -> base.GradientTransformation: - # pylint: disable=unneeded-not - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= b1 < 1.0: - raise ValueError(f'Invalid beta parameter at index 0: {b1}') - if not 0.0 <= b2 < 1.0: - raise ValueError(f'Invalid beta parameter at index 1: {b2}') - # pylint: enable=unneeded-not - - from torchopt._src.accelerated_op import AdamOp # pylint: disable=import-outside-toplevel - - if already_flattened: - tree_map = map_flattened - - # pylint: disable-next=unused-argument - def update_fn(updates, state, *, params=None, inplace=True): - count_inc = _inc_count(updates, state.count, already_flattened=True) - - op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace) - out = map_flattened(op, state.mu, state.nu, updates, count_inc) - - new_mu, new_nu, new_updates = tuple(zip(*out)) # transpose - return new_updates, ScaleByAdamState(mu=new_mu, nu=new_nu, count=count_inc) - - else: - tree_map = pytree.tree_map - - # pylint: disable-next=unused-argument - def update_fn(updates, state, *, params=None, inplace=True): - count_inc = _inc_count(updates, state.count, already_flattened=False) - - treedef = pytree.tree_structure(updates) - - op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace) - out = pytree.tree_map(op, state.mu, state.nu, updates, count_inc) - - new_mu, new_nu, new_updates = pytree.tree_transpose(treedef, TRIPLE_PYTREEDEF, out) - return new_updates, ScaleByAdamState(mu=new_mu, nu=new_nu, count=count_inc) - - def init_fn(params): - zero = tree_map( # count init - lambda t: torch.zeros(1, dtype=torch.int32, device=t.device), params - ) - mu = tree_map( # first moment - lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params - ) - nu = tree_map( # second moment - lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params - ) - return ScaleByAdamState(mu=mu, nu=nu, count=zero) - - return base.GradientTransformation(init_fn, update_fn) - - -class TraceState(NamedTuple): - """Holds an aggregation of past updates.""" - - trace: base.Params - - -def trace( - momentum: float = 0.9, - dampening: float = 0.0, - nesterov: bool = False, - moment_requires_grad: bool = False, -) -> base.GradientTransformation: - """Compute a trace of past updates. - - Note: `trace` and `ema` have very similar but distinct updates; - `trace = decay * trace + t`, while `ema = decay * ema + (1 - decay) * t`. - Both are frequently found in the optimization literature. - - Args: - momentum: (default: :const:`0.9`) - The decay rate for the trace of past updates. - dampening: (default: :const:`0.0`) - Dampening for momentum. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - moment_requires_grad: (default: :data:`False`) - if :data:`True`, states will be created with flag `requires_grad = True`. - - Returns: - An (init_fn, update_fn) tuple. - """ - return _trace( - momentum=momentum, - dampening=dampening, - nesterov=nesterov, - moment_requires_grad=moment_requires_grad, - already_flattened=False, - ) - - -def _trace( - momentum: float = 0.9, - dampening: float = 0.0, - nesterov: bool = False, - moment_requires_grad: bool = False, - *, - already_flattened: bool = False, -) -> base.GradientTransformation: - # pylint: disable=unneeded-not - if not 0.0 <= momentum: - raise ValueError(f'Invalid momentum value: {momentum}') - if nesterov and (momentum <= 0.0 or dampening != 0.0): - raise ValueError('Nesterov momentum requires a momentum and zero dampening') - # pylint: enable=unneeded-not - - if momentum == 0.0: - return base.identity() - - if already_flattened: - tree_map = map_flattened - else: - tree_map = pytree.tree_map - - def init_fn(params): - return TraceState( - trace=tree_map( - lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params - ) - ) - - first_call = True - - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument - nonlocal first_call - - if nesterov: - if inplace: - - def f1(g, t): - if first_call: - return t.add_(g) - return t.mul_(momentum).add_(g) - - def f2(g, t): - return g.add_(t, alpha=momentum) - - new_trace = tree_map(f1, updates, state.trace) - updates = tree_map(f2, updates, new_trace) - else: - - def f1(g, t): - if first_call: - return t.add(g) - return t.mul(momentum).add_(g) - - def f2(g, t): - return g.add(t, alpha=momentum) - - new_trace = tree_map(f1, updates, state.trace) - updates = tree_map(f2, updates, new_trace) - else: - if inplace: - - def f(g, t): - if first_call: - return t.add(g) - return t.mul_(momentum).add_(g, alpha=1.0 - dampening) - - def copy_(g, t): - return g.copy_(t) - - new_trace = tree_map(f, updates, state.trace) - updates = tree_map(copy_, updates, new_trace) - else: - - def f(g, t): - if first_call: - return t.add(g) - return t.mul(momentum).add_(g, alpha=1.0 - dampening) - - new_trace = tree_map(f, updates, state.trace) - updates = tree_map(torch.clone, new_trace) - - first_call = False - return updates, TraceState(trace=new_trace) - - return base.GradientTransformation(init_fn, update_fn) - - -class ScaleByRmsState(NamedTuple): - """State for exponential root mean-squared (RMS)-normalized updates.""" - - nu: base.Updates - - -def scale_by_rms( - alpha: float = 0.9, eps: float = 1e-8, initial_scale: float = 0.0 -) -> base.GradientTransformation: - """Rescale updates by the root of the exp. moving avg of the square. - - References: - [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) - - Args: - alpha: (default: :const:`0.9`) - Decay rate for the exponentially weighted average of squared grads. - eps: (default: :const:`1e-8`) - Term added to the denominator to improve numerical stability. - initial_scale: (default: :const:`0.0`) - Initial value for second moment - - Returns: - An (init_fn, update_fn) tuple. - """ - return _scale_by_rms( - alpha=alpha, - eps=eps, - initial_scale=initial_scale, - already_flattened=False, - ) - - -def _scale_by_rms( - alpha: float = 0.9, - eps: float = 1e-8, - initial_scale: float = 0.0, - *, - already_flattened: bool = False, -) -> base.GradientTransformation: - # pylint: disable=unneeded-not - if not 0.0 <= alpha: - raise ValueError(f'Invalid alpha value: {alpha}') - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - # pylint: enable=unneeded-not - - if already_flattened: - tree_map = map_flattened - else: - tree_map = pytree.tree_map - - def init_fn(params): - nu = tree_map(lambda n: torch.full_like(n, initial_scale), params) # second moment - return ScaleByRmsState(nu=nu) - - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument - nu = _update_moment( - updates, state.nu, alpha, order=2, inplace=inplace, already_flattened=already_flattened - ) - - if inplace: - - def f(g, n): - return g.div_(n.sqrt().add_(eps)) - - else: - - def f(g, n): - return g.div(n.sqrt().add_(eps)) - - updates = tree_map(f, updates, nu) - return updates, ScaleByRmsState(nu=nu) - - return base.GradientTransformation(init_fn, update_fn) - - -class ScaleByRStdDevState(NamedTuple): - """State for centered exponential moving average of squares of updates.""" - - mu: base.Updates - nu: base.Updates - - -def scale_by_stddev( - alpha: float = 0.9, eps: float = 1e-8, initial_scale: float = 0.0 -) -> base.GradientTransformation: - """Rescale updates by the root of the centered exp. moving average of squares. - - References: - [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) - - Args: - alpha: (default: :const:`0.9`) - Decay rate for the exponentially weighted average of squared grads. - eps: (default: :const:`1e-8`) - Term added to the denominator to improve numerical stability. - initial_scale: (default: :const:`0.0`) - Initial value for second moment - - Returns: - An (init_fn, update_fn) tuple. - """ - return _scale_by_stddev( - alpha=alpha, - eps=eps, - initial_scale=initial_scale, - already_flattened=False, - ) - - -def _scale_by_stddev( - alpha: float = 0.9, - eps: float = 1e-8, - initial_scale: float = 0.0, - *, - already_flattened: bool = False, -) -> base.GradientTransformation: - # pylint: disable=unneeded-not - if not 0.0 <= alpha: - raise ValueError(f'Invalid alpha value: {alpha}') - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - # pylint: enable=unneeded-not - - if already_flattened: - tree_map = map_flattened - else: - tree_map = pytree.tree_map - - def init_fn(params): - mu = tree_map(torch.zeros_like, params) # first moment - nu = tree_map(lambda n: torch.full_like(n, initial_scale), params) # second moment - return ScaleByRStdDevState(mu=mu, nu=nu) - - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument - mu = _update_moment( - updates, state.mu, alpha, order=1, inplace=inplace, already_flattened=already_flattened - ) - nu = _update_moment( - updates, state.nu, alpha, order=2, inplace=inplace, already_flattened=already_flattened - ) - - if inplace: - - def f(g, m, n): - return g.div_(n.addcmul(m, m, value=-1.0).sqrt_().add_(eps)) - - else: - - def f(g, m, n): - return g.div(n.addcmul(m, m, value=-1.0).sqrt_().add_(eps)) - - updates = tree_map(f, updates, mu, nu) - return updates, ScaleByRStdDevState(mu=mu, nu=nu) - - return base.GradientTransformation(init_fn, update_fn) - - -class MaskedState(NamedTuple): - """Maintains inner transform state for masked transformations.""" - - inner_state: Any - - -class MaskedNode(NamedTuple): - """A node used to mask out unspecified parts of a tree. - - This node is ignored when mapping functions across the tree e.g. using - :func:`pytree.tree_map` since it is a container without children. It can - therefore be used to mask out parts of a tree. - """ - - -def masked( - inner: base.GradientTransformation, - mask: Union[Any, Callable[[base.Params], Any]], -) -> base.GradientTransformation: - """Mask updates so only some are transformed, the rest are passed through. - - For example, it is common to skip weight decay for BatchNorm scale and all - bias parameters. In many networks, these are the only parameters with only - one dimension. So, you may create a mask function to mask these out as - follows:: - mask_fn = lambda p: pytree.tree_map(lambda x: x.ndim != 1, p) - weight_decay = torchopt.masked(torchopt.add_decayed_weights(0.001), mask_fn) - You may alternatively create the mask pytree upfront:: - mask = pytree.tree_map(lambda x: x.ndim != 1, params) - weight_decay = torchopt.masked(torchopt.add_decayed_weights(0.001), mask) - For the ``inner`` transform, state will only be stored for the parameters that - have a mask value of ``True``. - - Args: - inner: Inner transformation to mask. - mask: a PyTree with same structure as (or a prefix of) the params PyTree, or - a Callable that returns such a pytree given the params/updates. The leaves - should be booleans, ``True`` for leaves/subtrees you want to apply the - transformation to, and ``False`` for those you want to skip. The mask must - be static for the gradient transformation to be jit-compilable. - - Returns: - New GradientTransformation wrapping ``inner``. - """ - return _masked( - inner=inner, - mask=mask, - already_flattened=False, - ) - - -def _masked( - inner: base.GradientTransformation, - mask: Union[Any, Callable[[base.Params], Any]], - *, - already_flattened: bool = False, -) -> base.GradientTransformation: - - if already_flattened: - tree_map = map_flattened - else: - tree_map = pytree.tree_map - - def tree_mask(params, mask_tree): - return tree_map(lambda p, m: p if m else MaskedNode(), params, mask_tree) - - def init_fn(params): - mask_tree = mask(params) if callable(mask) else mask - masked_params = tree_mask(params, mask_tree) - return MaskedState(inner_state=inner.init(masked_params)) - - def update_fn(updates, state, params=None, inplace=True): # pylint: disable=unused-argument - mask_tree = mask(updates) if callable(mask) else mask - masked_updates = tree_mask(updates, mask_tree) - masked_params = None if params is None else tree_mask(params, mask_tree) - - new_masked_updates, new_inner_state = inner.update( - masked_updates, state.inner_state, params=masked_params, inplace=inplace - ) - - new_updates = tree_map( - lambda new_u, old_u, m: new_u if m else old_u, new_masked_updates, updates, mask_tree - ) - return new_updates, MaskedState(inner_state=new_inner_state) - - return base.GradientTransformation(init_fn, update_fn) - - -AddDecayedWeightsState = base.EmptyState - - -# mypy: ignore-errors -def add_decayed_weights( - weight_decay: float = 0.0, - mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, -) -> base.GradientTransformation: - """Add parameter scaled by `weight_decay`. - - Args: - weight_decay: a scalar weight decay rate. - mask: a tree with same structure as (or a prefix of) the params PyTree, - or a Callable that returns such a pytree given the params/updates. - The leaves should be booleans, `True` for leaves/subtrees you want to - apply the transformation to, and `False` for those you want to skip. - - Returns: - An (init_fn, update_fn) tuple. - """ - return _add_decayed_weights( - weight_decay=weight_decay, - mask=mask, - already_flattened=False, - ) - - -# mypy: ignore-errors -def _add_decayed_weights( - weight_decay: float = 0.0, - mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, - *, - already_flattened: bool = False, -) -> base.GradientTransformation: - if not 0.0 <= weight_decay: # pylint: disable=unneeded-not - raise ValueError(f'Invalid weight_decay value: {weight_decay}') - - if weight_decay == 0.0 and mask is None: - return base.identity() - - if already_flattened: - tree_map = map_flattened - else: - tree_map = pytree.tree_map - - def init_fn(params): # pylint: disable=unused-argument - return AddDecayedWeightsState() - - def update_fn(updates, state, params=None, inplace=True): # pylint: disable=unused-argument - assert params is not None, ( - 'Parameters are required for weight decay. ' - 'Call `update(updates, state, params=params)` instead.' - ) - - if inplace: - - def f(g, p): - if g is not None: - if g.requires_grad: - return g.add_(p, alpha=weight_decay) - return g.add_(p.data, alpha=weight_decay) - return None - - else: - - def f(g, p): - return g.add(p, alpha=weight_decay) if g is not None else None - - updates = tree_map(f, updates, params) - return updates, state - - # If mask is not `None`, apply mask to the gradient transformation. - # E.g. it is common to skip weight decay on bias units and batch stats. - if mask is not None: - return _masked( - inner=base.GradientTransformation(init_fn, update_fn), - mask=mask, - already_flattened=already_flattened, - ) - return base.GradientTransformation(init_fn, update_fn) diff --git a/torchopt/_src/utils.py b/torchopt/_src/utils.py deleted file mode 100644 index 6bfd5bbe..00000000 --- a/torchopt/_src/utils.py +++ /dev/null @@ -1,194 +0,0 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from typing import Dict, List, NamedTuple, Union - -import optree as pytree -import torch -import torch.nn as nn - - -class _ModuleState(NamedTuple): - params: List[Dict] - visual_contents: Union[None, Dict] = None - - -# mypy: ignore-errors -def stop_gradient(target): - """Stop the gradient for the input object. - - Since a tensor use :attr:`grad_fn` to connect itself with the previous computation graph, the - back-propagated gradient will flow over the tensor and continue flow to the tensors that is - connected by :attr:`grad_fn`. Some algorithms requires manually detaching tensors from the - computation graph. - - Note that the :func:`stop_gradient` operation is in-place. - - Args: - target: The target that to be detached from the computation graph, it could be a - :class:`nn.Module`, :class:`torchopt.MetaOptimizer`, state of the - :class:`torchopt.MetaOptimizer`, or just a plain list of tensors. - inplace: If :data:`True`, the target will be detached in-place. if :data:`Frue`, this - function will return a detached copy of the target. The in-place operation is fast and - memory efficient but may raise back-propagation error. - """ - # pylint: disable-next=import-outside-toplevel,cyclic-import - from torchopt._src.optimizer.meta.base import MetaOptimizer - - def f(obj): - if isinstance(obj, torch.Tensor): - requires_grad = obj.requires_grad - obj.detach_().requires_grad_(requires_grad) - - if isinstance(target, _ModuleState): - true_target = target.params - elif isinstance(target, nn.Module): - true_target = tuple(target.parameters()) - elif isinstance(target, MetaOptimizer): - true_target = pytree.tree_leaves(target.state_dict()) - else: - true_target = target - - pytree.tree_map(f, true_target) - - -# pylint: disable-next=too-many-branches,too-many-locals -def extract_state_dict(mod, copy=False, *, with_buffer=True, enable_visual=False, visual_prefix=''): - """Extract target state. - - Since a tensor use :attr:`grad_fn` to connect itself with the previous computation graph, the - back-propagated gradient will flow over the tensor and continue flow to the tensors that is - connected by :attr:`grad_fn`. Some algorithms requires manually detaching tensors from the - computation graph. - - Note that the extracted state is a reference, which means any in-place operator will affect the - target that the state is extracted from. - - Args: - mod: It could be a :class:`nn.Module` or :class:`torchopt.MetaOptimizer`. - with_buffer: - Extract buffer together with parameters, this argument is only used if the input target - is :class:`nn.Module`. - enable_visual: - Add additional annotations, which could be used in computation graph visualization. - Currently, this flag only has effect on :class:`nn.Module` but we will support - :class:`torchopt.MetaOptimizer` later. - visual_prefix: Prefix for the visualization annotations. - - Returns: - State extracted of the input object. - """ - # pylint: disable=import-outside-toplevel,cyclic-import - from torchopt._src.optimizer.meta.base import MetaOptimizer - - if isinstance(mod, nn.Module): # pylint: disable=no-else-return - if enable_visual: - visual_contents = {} - - for k, v in mod.named_parameters(): # pylint: disable=invalid-name - if v.grad_fn is not None: - visual_contents.update({v.grad_fn: (visual_prefix + k, v)}) - else: - visual_contents.update({v: visual_prefix + k}) - else: - visual_contents = None - - params = [] - - def get_variable(t): - if copy: - requires_grad = t.requires_grad - return t.clone().detach_().requires_grad_(requires_grad) - return t - - def _update(term): - if len(term) != 0: - params.append({k: get_variable(v) for k, v in term.items()}) - - # pylint: disable=protected-access - _update(mod._parameters) - if with_buffer: - _update(mod._buffers) - for module in mod.modules(): - if module is mod: - continue - _update(module._parameters) - if with_buffer: - _update(module._buffers) - return _ModuleState(params=tuple(params), visual_contents=visual_contents) - - elif isinstance(mod, MetaOptimizer): - state = mod.state_dict() - if copy: - - def get_variable(t): - if not isinstance(t, torch.Tensor): - return t - requires_grad = t.requires_grad - return t.clone().detach_().requires_grad_(requires_grad) - - return pytree.tree_map(get_variable, state) - - return state - - raise RuntimeError(f'Unexpected class of {mod}') - - -def _extract_container(mod, with_buffer=True): - if isinstance(mod, nn.Module): - containers = [] - - def _update(term): - if len(term) != 0: - containers.append(term) - - # pylint: disable=protected-access - _update(mod._parameters) - if with_buffer: - _update(mod._buffers) - for module in mod.modules(): - if module is mod: - continue - _update(module._parameters) - if with_buffer: - _update(module._buffers) - return tuple(containers) - - raise RuntimeError(f'Unexpected class of {mod}') - - -def recover_state_dict(mod, state): - """Recover state. - - This function is compatible for the ``extract_state``. - - Note that the recovering process is not in-place, so the tensors of the object will not be - modified. - - Args: - mod: Target that need to recover. - state: The recovering state. - """ - # pylint: disable-next=import-outside-toplevel,cyclic-import - from torchopt._src.optimizer.meta.base import MetaOptimizer - - if isinstance(mod, nn.Module): - target_container = _extract_container(mod) - for target, source in zip(target_container, state.params): - target.update(source) - elif isinstance(mod, MetaOptimizer): - mod.load_state_dict(state) - else: - raise RuntimeError(f'Unexpected class of {mod}') diff --git a/torchopt/_src/accelerated_op/__init__.py b/torchopt/accelerated_op/__init__.py similarity index 85% rename from torchopt/_src/accelerated_op/__init__.py rename to torchopt/accelerated_op/__init__.py index 4c7f1cd9..874174f2 100644 --- a/torchopt/_src/accelerated_op/__init__.py +++ b/torchopt/accelerated_op/__init__.py @@ -12,16 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""The accelerated Ops.""" from typing import Iterable, Optional, Union import torch -from torchopt._src.accelerated_op.adam_op import AdamOp +from torchopt.accelerated_op.adam_op import AdamOp -def accelerated_op_available( - devices: Optional[Union[str, torch.device, Iterable[Union[str, torch.device]]]] = None +def is_available( + devices: Optional[Union[int, str, torch.device, Iterable[Union[int, str, torch.device]]]] = None ) -> bool: """Check the availability of accelerated optimizer.""" op = AdamOp() @@ -30,7 +31,7 @@ def accelerated_op_available( devices = [torch.device('cuda'), torch.device('cpu')] elif isinstance(devices, torch.device): devices = [devices] - elif isinstance(devices, str): + elif isinstance(devices, (int, str)): devices = [torch.device(devices)] try: diff --git a/torchopt/_src/__init__.py b/torchopt/accelerated_op/_src/__init__.py similarity index 91% rename from torchopt/_src/__init__.py rename to torchopt/accelerated_op/_src/__init__.py index 75b3cf8d..bbf0b4cd 100644 --- a/torchopt/_src/__init__.py +++ b/torchopt/accelerated_op/_src/__init__.py @@ -12,5 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - -from torchopt._src.accelerated_op import accelerated_op_available +"""The Python implementation of accelerated ops.""" diff --git a/torchopt/accelerated_op/_src/adam_op.py b/torchopt/accelerated_op/_src/adam_op.py new file mode 100644 index 00000000..65752446 --- /dev/null +++ b/torchopt/accelerated_op/_src/adam_op.py @@ -0,0 +1,116 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Python implementation of accelerated AdamOp.""" + +# pylint: disable=invalid-name,too-many-arguments,unused-argument + +from typing import Tuple + +import torch + + +def forward_( + updates: torch.Tensor, + mu: torch.Tensor, + nu: torch.Tensor, + b1: float, + b2: float, + eps: float, + eps_root: float, + count: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Adam forward inplace.""" + inv_one_minus_pow_b1 = 1.0 / (1.0 - pow(b1, count)) + inv_one_minus_pow_b2 = 1.0 / (1.0 - pow(b2, count)) + + mu = mu.mul_(b1).add_(updates, alpha=1.0 - b1) + nu = nu.mul_(b2).add_(updates.square(), alpha=1.0 - b2) + updates.copy_( + mu.mul(inv_one_minus_pow_b1).div_( + nu.mul(inv_one_minus_pow_b2).add_(eps_root).sqrt_().add_(eps) + ) + ) + return updates, mu, nu + + +def forward_mu(updates: torch.Tensor, mu: torch.Tensor, b1: float) -> torch.Tensor: + """Adam forward mu.""" + return mu.mul(b1).add_(updates, alpha=1.0 - b1) + + +def forward_nu(updates: torch.Tensor, nu: torch.Tensor, b2: float) -> torch.Tensor: + """Adam forward nu.""" + return nu.mul(b2).add_(updates.square(), alpha=1.0 - b2) + + +def forward_updates( + new_mu: torch.Tensor, + new_nu: torch.Tensor, + b1: float, + b2: float, + eps: float, + eps_root: float, + count: int, +) -> torch.Tensor: + """Adam forward updates.""" + inv_one_minus_pow_b1 = 1.0 / (1.0 - pow(b1, count)) + inv_one_minus_pow_b2 = 1.0 / (1.0 - pow(b2, count)) + return new_mu.mul(inv_one_minus_pow_b1).div_( + new_nu.mul(inv_one_minus_pow_b2).add_(eps_root).sqrt_().add_(eps) + ) + + +def backward_mu( + dmu: torch.Tensor, updates: torch.Tensor, mu: torch.Tensor, b1: float +) -> Tuple[torch.Tensor, torch.Tensor]: + """Adam backward mu.""" + dupdates = dmu.mul(1.0 - b1) + dmu = dmu.mul(b1) + return dupdates, dmu + + +def backward_nu( + dnu: torch.Tensor, updates: torch.Tensor, nu: torch.Tensor, b2: float +) -> Tuple[torch.Tensor, torch.Tensor]: + """Adam backward nu.""" + dupdates = updates.mul(dnu).mul_(2.0 * (1.0 - b2)) + dnu = dnu.mul(b2) + return dupdates, dnu + + +def backward_updates( + dupdates: torch.Tensor, + updates: torch.Tensor, + new_mu: torch.Tensor, + new_nu: torch.Tensor, + b1: float, + b2: float, + count: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Adam backward updates.""" + one_minus_pow_b1 = 1.0 - pow(b1, count) + inv_one_minus_pow_b2 = 1.0 / (1.0 - pow(b2, count)) + + updates_div_new_mu = updates.div(new_mu) + denominator = updates_div_new_mu.mul_(one_minus_pow_b1) + dnew_mu_out = dupdates.mul(updates_div_new_mu) + dnew_nu_out = ( + dupdates.mul(updates).mul_(denominator.square_()).mul_(-0.5 * inv_one_minus_pow_b2) + ) + + mask = new_mu == 0 + dnew_mu_out[mask] = 0 + dnew_nu_out[mask] = 0 + return dnew_mu_out, dnew_nu_out diff --git a/torchopt/_src/accelerated_op/adam_op.py b/torchopt/accelerated_op/adam_op.py similarity index 89% rename from torchopt/_src/accelerated_op/adam_op.py rename to torchopt/accelerated_op/adam_op.py index 00261c1a..56792487 100644 --- a/torchopt/_src/accelerated_op/adam_op.py +++ b/torchopt/accelerated_op/adam_op.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""The accelerated AdamOp.""" # pylint: disable=c-extension-no-member,invalid-name @@ -19,7 +20,11 @@ import torch -from torchopt._C import adam_op # pylint: disable=no-name-in-module + +try: + from torchopt._C import adam_op # pylint: disable=no-name-in-module +except ImportError: + from torchopt.accelerated_op._src import adam_op # type: ignore[no-redef] class AdamOp: # pylint: disable=too-few-public-methods @@ -30,14 +35,13 @@ class MuOp(torch.autograd.Function): # pylint: disable=abstract-method @staticmethod def jvp(ctx: Any, *grad_inputs: Any) -> Any: - # pylint: disable-next=line-too-long """Defines a formula for differentiating the operation with forward mode automatic differentiation.""" @staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: """Performs the operation.""" updates, mu, b1 = args - new_mu = adam_op.forwardMu(updates, mu, b1) + new_mu = adam_op.forward_mu(updates, mu, b1) ctx.save_for_backward(updates, mu) ctx.b1 = b1 return new_mu @@ -49,7 +53,7 @@ def backward(ctx: Any, *args: Any) -> Any: dmu = args[0] updates, mu = ctx.saved_tensors b1 = ctx.b1 - result = adam_op.backwardMu(dmu, updates, mu, b1) + result = adam_op.backward_mu(dmu, updates, mu, b1) return result[0], result[1], None class NuOp(torch.autograd.Function): # pylint: disable=abstract-method @@ -57,14 +61,13 @@ class NuOp(torch.autograd.Function): # pylint: disable=abstract-method @staticmethod def jvp(ctx: Any, *grad_inputs: Any) -> Any: - # pylint: disable-next=line-too-long """Defines a formula for differentiating the operation with forward mode automatic differentiation.""" @staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: """Performs the operation.""" updates, nu, b2 = args - new_nu = adam_op.forwardNu(updates, nu, b2) + new_nu = adam_op.forward_nu(updates, nu, b2) ctx.save_for_backward(updates, nu) ctx.b2 = b2 return new_nu @@ -76,7 +79,7 @@ def backward(ctx: Any, *args: Any) -> Any: dnu = args[0] updates, nu = ctx.saved_tensors b2 = ctx.b2 - result = adam_op.backwardNu(dnu, updates, nu, b2) + result = adam_op.backward_nu(dnu, updates, nu, b2) return result[0], result[1], None class UpdatesOp(torch.autograd.Function): # pylint: disable=abstract-method @@ -84,14 +87,13 @@ class UpdatesOp(torch.autograd.Function): # pylint: disable=abstract-method @staticmethod def jvp(ctx: Any, *grad_inputs: Any) -> Any: - # pylint: disable-next=line-too-long """Defines a formula for differentiating the operation with forward mode automatic differentiation.""" @staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: """Performs the operation.""" new_mu, new_nu, (b1, b2, eps, eps_root, count) = args - new_updates = adam_op.forwardUpdates(new_mu, new_nu, b1, b2, eps, eps_root, count) + new_updates = adam_op.forward_updates(new_mu, new_nu, b1, b2, eps, eps_root, count) ctx.save_for_backward(new_updates, new_mu, new_nu) ctx.others = (b1, b2, eps, eps_root, count) return new_updates @@ -103,7 +105,7 @@ def backward(ctx: Any, *args: Any) -> Any: dupdates = args[0] updates, new_mu, new_nu = ctx.saved_tensors b1, b2, _, _, count = ctx.others - result = adam_op.backwardUpdates(dupdates, updates, new_mu, new_nu, b1, b2, count) + result = adam_op.backward_updates(dupdates, updates, new_mu, new_nu, b1, b2, count) return result[0], result[1], None # pylint: disable-next=too-many-arguments diff --git a/torchopt/_src/combine.py b/torchopt/alias/__init__.py similarity index 69% rename from torchopt/_src/combine.py rename to torchopt/alias/__init__.py index 00e90bc1..b00b3c35 100644 --- a/torchopt/_src/combine.py +++ b/torchopt/alias/__init__.py @@ -29,22 +29,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +r"""The aliases of preset :class:`GradientTransformation`\s for optimizers.""" -from torchopt._src import base +from torchopt.alias.adam import adam +from torchopt.alias.adamw import adamw +from torchopt.alias.rmsprop import rmsprop +from torchopt.alias.sgd import sgd -def chain(*args: base.GradientTransformation) -> base.GradientTransformation: - """Applies a list of chainable update transformations. - - Given a sequence of chainable transforms, :func:`chain` returns an :func:`init_fn` that - constructs a ``state`` by concatenating the states of the individual transforms, and returns an - :func:`update_fn` which chains the update transformations feeding the appropriate state to each. - - Args: - *args: - A sequence of chainable ``(init_fn, update_fn)`` tuples. - - Returns: - A single ``(init_fn, update_fn)`` tuple. - """ - return base.ChainedGradientTransformation(*args) +__all__ = ['adam', 'adamw', 'rmsprop', 'sgd'] diff --git a/torchopt/alias/adam.py b/torchopt/alias/adam.py new file mode 100644 index 00000000..637b40c7 --- /dev/null +++ b/torchopt/alias/adam.py @@ -0,0 +1,123 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/alias.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset :class:`GradientTransformation` for the Adam optimizer.""" + +from typing import Tuple + +from torchopt.alias.utils import flip_sign_and_add_weight_decay, scale_by_neg_lr +from torchopt.combine import chain_flat +from torchopt.transform import scale_by_accelerated_adam, scale_by_adam +from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +__all__ = ['adam'] + + +# pylint: disable-next=too-many-arguments +def adam( + lr: ScalarOrSchedule = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + *, + eps_root: float = 0.0, + moment_requires_grad: bool = False, + maximize: bool = False, + use_accelerated_op: bool = False, +) -> GradientTransformation: + """The functional Adam optimizer. + + Adam is an SGD variant with learning rate adaptation. The *learning rate* used for each weight + is computed from estimates of first- and second-order moments of the gradients (using suitable + exponential moving averages). + + References: + - Kingma et al, 2014: https://arxiv.org/abs/1412.6980 + + Args: + lr: (default: :const:`1e-3`) + This is a fixed global scaling factor. + betas: (default: :const:`(0.9, 0.999)`) + Coefficients used for computing running averages of gradient and its square. + eps: (default: :const:`1e-8`) + A small constant applied to denominator outside of the square root (as in the Adam + paper) to avoid dividing by zero when rescaling. + weight_decay: (default: :const:`0.0`) + Weight decay, add L2 penalty to parameters. + eps_root: (default: :data:`0.0`) + A small constant applied to denominator inside the square root (as in RMSProp), to avoid + dividing by zero when rescaling. This is needed for example when computing + (meta-)gradients through Adam. + moment_requires_grad: (default: :data:`False`) + If :data:`True` the momentums will be created with flag ``requires_grad=True``, this + flag is often used in Meta-Learning algorithms. + maximize: (default: :data:`False`) + Maximize the params based on the objective, instead of minimizing. + use_accelerated_op: (default: :data:`False`) + If :data:`True` use our implemented fused operator. + + Returns: + The corresponding :class:`GradientTransformation` instance. + + See Also: + The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. + """ + b1, b2 = betas # pylint: disable=invalid-name + # pylint: disable=unneeded-not + if not (callable(lr) or 0.0 <= lr): + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= b1 < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {b1}') + if not 0.0 <= b2 < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {b2}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + # pylint: enable=unneeded-not + + if use_accelerated_op: + adam_scaler = scale_by_accelerated_adam.flat # type: ignore[attr-defined] + else: + adam_scaler = scale_by_adam.flat # type: ignore[attr-defined] + + return chain_flat( + flip_sign_and_add_weight_decay(weight_decay=weight_decay, maximize=maximize), + adam_scaler( + b1=b1, + b2=b2, + eps=eps, + eps_root=eps_root, + moment_requires_grad=moment_requires_grad, + ), + scale_by_neg_lr(lr), + ) diff --git a/torchopt/alias/adamw.py b/torchopt/alias/adamw.py new file mode 100644 index 00000000..b088be60 --- /dev/null +++ b/torchopt/alias/adamw.py @@ -0,0 +1,135 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/alias.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset :class:`GradientTransformation` for the AdamW optimizer.""" + +from typing import Any, Callable, Optional, Tuple, Union + +from torchopt.alias.utils import flip_sign_and_add_weight_decay, scale_by_neg_lr +from torchopt.combine import chain_flat +from torchopt.transform import add_decayed_weights, scale_by_accelerated_adam, scale_by_adam +from torchopt.typing import GradientTransformation, Params, ScalarOrSchedule + + +__all__ = ['adamw'] + + +# pylint: disable-next=too-many-arguments +def adamw( + lr: ScalarOrSchedule = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + *, + eps_root: float = 0.0, + mask: Optional[Union[Any, Callable[[Params], Any]]] = None, + moment_requires_grad: bool = False, + maximize: bool = False, + use_accelerated_op: bool = False, +) -> GradientTransformation: + """Adam with weight decay regularization. + + AdamW uses weight decay to regularize learning towards small weights, as + this leads to better generalization. In SGD you can also use L2 regularization + to implement this as an additive loss term, however L2 regularization + does not behave as intended for adaptive gradient algorithms such as Adam. + + References: + - Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101 + + Args: + lr: (default: :const:`1e-3`) + This is a fixed global scaling factor. + betas: (default: :const:`(0.9, 0.999)`) + Coefficients used for computing running averages of gradient and its square. + eps: (default: :const:`1e-8`) + A small constant applied to denominator outside of the square root (as in the Adam + paper) to avoid dividing by zero when rescaling. + weight_decay: (default: :const:`1e-2`) + Strength of the weight decay regularization. Note that this weight decay is multiplied + with the learning rate. This is consistent with other frameworks such as PyTorch, but + different from (Loshchilov et al, 2019) where the weight decay is only multiplied with + the "schedule multiplier", but not the base learning rate. + eps_root: (default: :data:`0.0`) + A small constant applied to denominator inside the square root (as in RMSProp), to avoid + dividing by zero when rescaling. This is needed for example when computing + (meta-)gradients through Adam. + mask: (default: :data:`None`) + A tree with same structure as (or a prefix of) the params PyTree, or a Callable that + returns such a pytree given the params/updates. The leaves should be booleans, + :data:`True` for leaves/subtrees you want to apply the weight decay to, and + :data:`False` for those you want to skip. Note that the Adam gradient + transformations are applied to all parameters. + moment_requires_grad: (default: :data:`False`) + If :data:`True` the momentums will be created with flag ``requires_grad=True``, this + flag is often used in Meta-Learning algorithms. + maximize: (default: :data:`False`) + Maximize the params based on the objective, instead of minimizing. + use_accelerated_op: (default: :data:`False`) + If :data:`True` use our implemented fused operator. + + Returns: + The corresponding :class:`GradientTransformation` instance. + + See Also: + The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. + """ + b1, b2 = betas # pylint: disable=invalid-name + # pylint: disable=unneeded-not + if not (callable(lr) or 0.0 <= lr): + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= b1 < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {b1}') + if not 0.0 <= b2 < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {b2}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + # pylint: enable=unneeded-not + + if use_accelerated_op: + adam_scaler = scale_by_accelerated_adam.flat # type: ignore[attr-defined] + else: + adam_scaler = scale_by_adam.flat # type: ignore[attr-defined] + + return chain_flat( + flip_sign_and_add_weight_decay(weight_decay=0.0, maximize=maximize), + adam_scaler( + b1=b1, + b2=b2, + eps=eps, + eps_root=eps_root, + moment_requires_grad=moment_requires_grad, + ), + add_decayed_weights.flat(weight_decay=weight_decay, mask=mask), # type: ignore[attr-defined] + scale_by_neg_lr(lr), + ) diff --git a/torchopt/alias/rmsprop.py b/torchopt/alias/rmsprop.py new file mode 100644 index 00000000..6d2ddeb3 --- /dev/null +++ b/torchopt/alias/rmsprop.py @@ -0,0 +1,124 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/alias.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset :class:`GradientTransformation` for the RMSProp optimizer.""" + +from torchopt.alias.utils import flip_sign_and_add_weight_decay, scale_by_neg_lr +from torchopt.combine import chain_flat +from torchopt.transform import scale_by_rms, scale_by_stddev, trace +from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +__all__ = ['rmsprop'] + + +# pylint: disable-next=too-many-arguments +def rmsprop( + lr: ScalarOrSchedule = 1e-2, + alpha: float = 0.99, + eps: float = 1e-8, + weight_decay: float = 0.0, + momentum: float = 0.0, + centered: bool = False, + *, + initial_scale: float = 0.0, + nesterov: bool = False, + maximize: bool = False, +) -> GradientTransformation: + """The functional version of the RMSProp optimizer. + + RMSProp is an SGD variant with learning rate adaptation. The *learning rate* used for each + weight is scaled by a suitable estimate of the magnitude of the gradients on previous steps. + Several variants of RMSProp can be found in the literature. This alias provides an easy to + configure RMSProp optimizer that can be used to switch between several of these variants. + + References: + - Tieleman and Hinton, 2012: http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf + - Graves, 2013: https://arxiv.org/abs/1308.0850 + + Args: + lr: (default: :const:`1e-2`) + This is a fixed global scaling factor. + alpha: (default: :const:`0.99`) + Smoothing constant, the decay used to track the magnitude of previous gradients. + eps: (default: :const:`1e-8`) + A small numerical constant to avoid dividing by zero when rescaling. + weight_decay: (default: :const:`0.0`) + Weight decay, add L2 penalty to parameters. + momentum: (default: :const:`0.0`) + The decay rate used by the momentum term. The momentum is not used when it is set to + :const:`0.0`. + centered: (default: :data:`False`) + If :data:`True`, use the variance of the past gradients to rescale the latest + gradients. + initial_scale: (default: :data:`0.0`) + Initialization of accumulators tracking the magnitude of previous updates. PyTorch + uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When reproducing results from a + paper, verify the value used by the authors. + nesterov: (default: :data:`False`) + Whether to use Nesterov momentum. + maximize: (default: :data:`False`) + Maximize the params based on the objective, instead of minimizing. + + Returns: + The corresponding :class:`GradientTransformation` instance. + + See Also: + The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. + """ + # pylint: disable=unneeded-not + if not (callable(lr) or 0.0 <= lr): + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= alpha: + raise ValueError(f'Invalid alpha value: {alpha}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= momentum: + raise ValueError(f'Invalid momentum value: {momentum}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + # pylint: enable=unneeded-not + + if centered: + rmsprop_scaler = scale_by_stddev.flat # type: ignore[attr-defined] + else: + rmsprop_scaler = scale_by_rms.flat # type: ignore[attr-defined] + + return chain_flat( + flip_sign_and_add_weight_decay(weight_decay=weight_decay, maximize=maximize), + rmsprop_scaler( + alpha=alpha, + eps=eps, + initial_scale=initial_scale, + ), + trace.flat(momentum=momentum, nesterov=nesterov), # type: ignore[attr-defined] + scale_by_neg_lr(lr), + ) diff --git a/torchopt/alias/sgd.py b/torchopt/alias/sgd.py new file mode 100644 index 00000000..af87587f --- /dev/null +++ b/torchopt/alias/sgd.py @@ -0,0 +1,105 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/alias.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset :class:`GradientTransformation` for the SGD optimizer.""" + +from torchopt.alias.utils import flip_sign_and_add_weight_decay, scale_by_neg_lr +from torchopt.combine import chain_flat +from torchopt.transform import trace +from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +__all__ = ['sgd'] + + +def sgd( + lr: ScalarOrSchedule, + momentum: float = 0.0, + dampening: float = 0.0, + weight_decay: float = 0.0, + nesterov: bool = False, + *, + moment_requires_grad: bool = False, + maximize: bool = False, +) -> GradientTransformation: + """The functional version of the canonical Stochastic Gradient Descent optimizer. + + This implements stochastic gradient descent. It also includes support for momentum, and nesterov + acceleration, as these are standard practice when using stochastic gradient descent to train + deep neural networks. + + References: + - Sutskever et al, 2013: http://proceedings.mlr.press/v28/sutskever13.pdf + + Args: + lr: This is a fixed global scaling factor. + momentum: (default: :const:`0.0`) + The decay rate used by the momentum term. The momentum is not used when it is set to + :const:`0.0`. + weight_decay: (default: :const:`0.0`) + Weight decay, add L2 penalty to parameters. + dampening: (default: :const:`0.0`) + Dampening for momentum. + nesterov: (default: :data:`False`) + Whether to use Nesterov momentum. + moment_requires_grad: (default: :data:`False`) + If :data:`True` the momentums will be created with flag ``requires_grad=True``, this + flag is often used in Meta-Learning algorithms. + maximize: (default: :data:`False`) + Maximize the params based on the objective, instead of minimizing. + + Returns: + The corresponding :class:`GradientTransformation` instance. + + See Also: + The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. + """ + # pylint: disable=unneeded-not + if not (callable(lr) or 0.0 <= lr): + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= momentum: + raise ValueError(f'Invalid momentum value: {momentum}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + if nesterov and (momentum <= 0.0 or dampening != 0.0): + raise ValueError('Nesterov momentum requires a momentum and zero dampening') + # pylint: enable=unneeded-not + + return chain_flat( + flip_sign_and_add_weight_decay(weight_decay=weight_decay, maximize=maximize), + trace.flat( # type: ignore[attr-defined] + momentum=momentum, + dampening=dampening, + nesterov=nesterov, + moment_requires_grad=moment_requires_grad, + ), + scale_by_neg_lr(lr), + ) diff --git a/torchopt/alias/utils.py b/torchopt/alias/utils.py new file mode 100644 index 00000000..3ba3b6dc --- /dev/null +++ b/torchopt/alias/utils.py @@ -0,0 +1,116 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r"""Utilities for the aliases of preset :class:`GradientTransformation`\s for optimizers.""" + +from torchopt.base import EmptyState, GradientTransformation, identity +from torchopt.transform import scale, scale_by_schedule +from torchopt.transform.utils import tree_map_flat +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['flip_sign_and_add_weight_decay', 'scale_by_neg_lr'] + + +def flip_sign_and_add_weight_decay(weight_decay: float = 0.0, maximize=False): + """Flips the sign of the updates and adds weight decay.""" + if not 0.0 <= weight_decay: # pylint: disable=unneeded-not + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + + if not maximize and weight_decay == 0.0: + return identity() + + def init_fn(params): # pylint: disable=unused-argument + return EmptyState() + + if not maximize: # gradient descent + + def update_fn(updates, state, *, params=None, inplace=True): + assert params is not None, ( + 'Parameters are required for weight decay. ' + 'Call `update(updates, state, params=params)` instead.' + ) + + if inplace: + + def f(g, p): + if g.requires_grad: + return g.add_(p, alpha=weight_decay) + return g.add_(p.data, alpha=weight_decay) + + else: + + def f(g, p): + return g.add(p, alpha=weight_decay) + + updates = tree_map_flat(f, updates, params) + return updates, state + + else: # gradient ascent + + if weight_decay == 0.0: + # pylint: disable-next=unused-argument + def update_fn(updates, state, *, params=None, inplace=True): + if inplace: + + def f(g): + return g.neg_() + + else: + + def f(g): + return g.neg() + + updates = tree_map_flat(f, updates) + return updates, state + + else: + + def update_fn(updates, state, *, params=None, inplace=True): + assert params is not None, ( + 'Parameters are required for weight decay. ' + 'Call `update(updates, state, params=params)` instead.' + ) + + if inplace: + + def f(g, p): + if g is not None: + if g.requires_grad: + return g.neg_().add_(p, alpha=weight_decay) + return g.neg_().add_(p.data, alpha=weight_decay) + return None + + else: + + def f(g, p): + return g.neg().add_(p, alpha=weight_decay) + + updates = tree_map_flat(f, updates, params) + return updates, state + + return GradientTransformation(init_fn, update_fn) + + +def scale_by_neg_lr(lr: ScalarOrSchedule): + """Scales the updates by the negative learning rate.""" + if not (callable(lr) or 0.0 <= lr): + raise ValueError(f'Invalid learning rate: {lr}') + + if callable(lr): + + def schedule_wrapper(count): + return -lr(count) # type: ignore[operator] + + return scale_by_schedule.flat(schedule_wrapper) # type: ignore[attr-defined] + return scale.flat(-lr) # type: ignore[attr-defined] diff --git a/torchopt/_src/base.py b/torchopt/base.py similarity index 89% rename from torchopt/_src/base.py rename to torchopt/base.py index f17bf00f..5706957e 100644 --- a/torchopt/_src/base.py +++ b/torchopt/base.py @@ -29,32 +29,35 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""The base classes for gradient transformation.""" import itertools from abc import abstractmethod -from typing import Callable, NamedTuple, Optional, Tuple +from typing import TYPE_CHECKING, Callable, NamedTuple, Optional, Tuple +from typing_extensions import Protocol # Python 3.8+ -from torchopt._src.typing import Numeric, TensorTree +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates -try: - from typing import Protocol -except ImportError: - from typing_extensions import Protocol # type: ignore[misc] - -OptState = TensorTree # States are arbitrary nests of `torch.Tensor`. -# Parameters are arbitrary nests of `torch.Tensor`. -Params = TensorTree -Updates = Params # Gradient updates are of the same type as parameters. - -Schedule = Callable[[Numeric], Numeric] +__all__ = [ + 'EmptyState', + 'UninitializedState', + 'GradientTransformation', + 'ChainedGradientTransformation', + 'identity', +] class EmptyState(NamedTuple): """An empty state for the simplest stateless transformations.""" +class UninitializedState(NamedTuple): + """A state that is not initialized yet.""" + + class TransformInitFn(Protocol): # pylint: disable=too-few-public-methods """A callable type for the :func:`init` step of a :class:`GradientTransformation`. @@ -64,8 +67,8 @@ class TransformInitFn(Protocol): # pylint: disable=too-few-public-methods """ @abstractmethod - def __call__(self, params: Params) -> OptState: - """The `init` function. + def __call__(self, params: 'Params') -> 'OptState': + """The ``init`` function. Args: params: @@ -90,13 +93,13 @@ class TransformUpdateFn(Protocol): # pylint: disable=too-few-public-methods @abstractmethod def __call__( self, - updates: Updates, - state: OptState, + updates: 'Updates', + state: 'OptState', *, - params: Optional[Params] = None, + params: Optional['Params'] = None, inplace: bool = True, - ) -> Tuple[Updates, OptState]: - """The `update` function. + ) -> Tuple['Updates', 'OptState']: + """The ``update`` function. Args: updates: A tree of candidate updates. @@ -188,7 +191,7 @@ def update_fn(updates, state, *, params=None, inplace=True): instance.transformations = transformations return instance - def __str__(self): + def __str__(self) -> str: """Returns a string representation of the chained gradient transformation.""" return '{}(\n {}\n)'.format( self.__class__.__name__, ',\n '.join(repr(t) for t in self.transformations) @@ -229,19 +232,18 @@ def __new__(cls): return super().__new__(cls, init=cls.init_fn, update=cls.update_fn) @staticmethod - def init_fn(params: Params) -> OptState: # pylint: disable=unused-argument + def init_fn(params: 'Params') -> 'OptState': # pylint: disable=unused-argument """Returns empty state.""" return EmptyState() @staticmethod - # pylint: disable-next=unused-argument def update_fn( - updates: Updates, - state: OptState, + updates: 'Updates', + state: 'OptState', *, - params: Optional[Params] = None, - inplace: bool = True, - ) -> Tuple[Updates, OptState]: + params: Optional['Params'] = None, # pylint: disable=unused-argument + inplace: bool = True, # pylint: disable=unused-argument + ) -> Tuple['Updates', 'OptState']: """Returns updates unchanged.""" return updates, state diff --git a/torchopt/_src/clip.py b/torchopt/clip.py similarity index 61% rename from torchopt/_src/clip.py rename to torchopt/clip.py index 31d54797..29c26032 100644 --- a/torchopt/_src/clip.py +++ b/torchopt/clip.py @@ -15,24 +15,35 @@ # This file is modified from: # https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/clip_grad.py # ============================================================================== +"""Utilities for gradient clipping.""" + +from typing import Union import torch -from torch._six import inf -from torchopt._src import base -from torchopt._src.utils import pytree +from torchopt import pytree +from torchopt.base import EmptyState, GradientTransformation + + +__all__ = ['clip_grad_norm'] -ClipState = base.EmptyState +ClipState = EmptyState def clip_grad_norm( - max_norm: float, norm_type: float = 2.0, error_if_nonfinite: bool = False -) -> base.GradientTransformation: + max_norm: Union[float, int], + norm_type: Union[float, int] = 2.0, + error_if_nonfinite: bool = False, +) -> GradientTransformation: """Clips gradient norm of an iterable of parameters. Args: - max_delta: The maximum absolute value for each element in the update. + max_norm (float or int): The maximum absolute value for each element in the update. + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if :data:`True`, an error is thrown if the total norm of the + gradients from :attr:`updates` is ``nan``, ``inf``, or ``-inf``. Returns: An ``(init_fn, update_fn)`` tuple. @@ -42,15 +53,12 @@ def init_fn(params): # pylint: disable=unused-argument return ClipState() def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument - available_updates = [] - for g in updates: - if g is not None: - available_updates.append(g) + available_updates = pytree.tree_leaves(updates) if len(available_updates) == 0: - return torch.tensor(0.0) + return updates, state device = available_updates[0].device with torch.no_grad(): - if norm_type == inf: + if norm_type == torch.inf: norms = [p.abs().max().to(device) for p in available_updates] total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms)) else: @@ -64,22 +72,23 @@ def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable= f'non-finite, so it cannot be clipped. To disable this error and scale the ' f'gradients by the non-finite norm anyway, set `error_if_nonfinite=False`' ) - clip_coef = max_norm / (float(total_norm) + 1e-6) - # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but - # doing so avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device - # synchronization when the gradients do not reside in CPU memory. - clip_coef_clamped = min(clip_coef, 1.0) + clip_coefficient = max_norm / (float(total_norm) + 1e-6) + # Note: multiplying by the clamped coefficient is redundant when the coefficient is + # clamped to 1, but doing so avoids a `if clip_coefficient < 1:` conditional which + # can require a CPU <=> device synchronization when the gradients do not reside in + # CPU memory. + clip_coefficient_clamped = min(clip_coefficient, 1.0) if inplace: def f(g): - return g.mul_(clip_coef_clamped) if g is not None else None + return g.mul_(clip_coefficient_clamped) else: def f(g): - return g.mul(clip_coef_clamped) if g is not None else None + return g.mul(clip_coefficient_clamped) new_updates = pytree.tree_map(f, updates) return new_updates, state - return base.GradientTransformation(init_fn, update_fn) + return GradientTransformation(init_fn, update_fn) diff --git a/torchopt/combine.py b/torchopt/combine.py new file mode 100644 index 00000000..26f66214 --- /dev/null +++ b/torchopt/combine.py @@ -0,0 +1,98 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/alias.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities to define a chained transformation.""" + +from torchopt import pytree +from torchopt.base import ChainedGradientTransformation, GradientTransformation, identity +from torchopt.typing import Updates + + +__all__ = ['chain', 'chain_flat'] + + +def chain(*transformations: GradientTransformation) -> GradientTransformation: + """Applies a list of chainable update transformations. + + Given a sequence of chainable transforms, :func:`chain` returns an :func:`init_fn` that + constructs a ``state`` by concatenating the states of the individual transforms, and returns an + :func:`update_fn` which chains the update transformations feeding the appropriate state to each. + + Args: + *transformations: + A sequence of chainable ``(init_fn, update_fn)`` tuples. + + Returns: + A single ``(init_fn, update_fn)`` tuple. + """ + if len(transformations) == 0: + return identity() + if len(transformations) == 1: + return transformations[0] + return ChainedGradientTransformation(*transformations) + + +def chain_flat(*transformations: GradientTransformation) -> GradientTransformation: + """Wraps around the inner transformations that manipulates the flattened tree structure (:class:``list``). + + Args: + *transformations: + A sequence of chainable ``(init_fn, update_fn)`` tuples. + + Returns: + A single ``(init_fn, update_fn)`` tuple. + """ + if len(transformations) == 0: + return identity() + if len(transformations) == 1: + inner = transformations[0] + else: + inner = chain(*transformations) + + def init_fn(params): + return inner.init(pytree.tree_leaves(params, none_is_leaf=True)) + + def update_fn(updates, state, *, params=None, inplace=True): + flat_updates, treespec = pytree.tree_flatten(updates, none_is_leaf=True) + if params is not None: + flat_params = pytree.tree_leaves(params, none_is_leaf=True) + else: + flat_params = None + + flat_updates, state = inner.update(flat_updates, state, params=flat_params, inplace=inplace) + updates: Updates + updates = pytree.tree_unflatten(treespec, flat_updates) + return updates, state + + return GradientTransformation(init_fn, update_fn) + + +chain.flat = chain_flat # type: ignore[attr-defined] diff --git a/torchopt/_src/optimizer/__init__.py b/torchopt/diff/__init__.py similarity index 70% rename from torchopt/_src/optimizer/__init__.py rename to torchopt/diff/__init__.py index 8501fb15..45674fcf 100644 --- a/torchopt/_src/optimizer/__init__.py +++ b/torchopt/diff/__init__.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Differentiable Gradient Estimation.""" -from torchopt._src.optimizer import meta -from torchopt._src.optimizer.adam import Adam -from torchopt._src.optimizer.adamw import AdamW -from torchopt._src.optimizer.base import Optimizer -from torchopt._src.optimizer.rmsprop import RMSProp, RMSprop -from torchopt._src.optimizer.sgd import SGD +from torchopt.diff import implicit, zero_order +from torchopt.diff.implicit import ImplicitMetaGradientModule diff --git a/torchopt/_src/optimizer/meta/__init__.py b/torchopt/diff/implicit/__init__.py similarity index 69% rename from torchopt/_src/optimizer/meta/__init__.py rename to torchopt/diff/implicit/__init__.py index ec227474..4e50b615 100644 --- a/torchopt/_src/optimizer/meta/__init__.py +++ b/torchopt/diff/implicit/__init__.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Implicit Meta-Gradient.""" -from torchopt._src.optimizer.meta.adam import MetaAdam -from torchopt._src.optimizer.meta.adamw import MetaAdamW -from torchopt._src.optimizer.meta.base import MetaOptimizer -from torchopt._src.optimizer.meta.rmsprop import MetaRMSProp, MetaRMSprop -from torchopt._src.optimizer.meta.sgd import MetaSGD +from torchopt.diff.implicit import nn +from torchopt.diff.implicit.decorator import custom_root +from torchopt.diff.implicit.nn import ImplicitMetaGradientModule + + +__all__ = ['custom_root', 'ImplicitMetaGradientModule'] diff --git a/torchopt/diff/implicit/decorator.py b/torchopt/diff/implicit/decorator.py new file mode 100644 index 00000000..aaeda594 --- /dev/null +++ b/torchopt/diff/implicit/decorator.py @@ -0,0 +1,473 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implicit Meta-Gradient.""" + +# pylint: disable=invalid-name + +import functools +import inspect +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union + +import functorch +import torch +from torch.autograd import Function + +from torchopt import linear_solve, pytree +from torchopt.typing import ( + ListOfOptionalTensors, + ListOfTensors, + TensorOrTensors, + TupleOfOptionalTensors, + TupleOfTensors, +) + + +__all__ = ['custom_root'] + + +Args = Tuple[Any, ...] +KwArgs = Dict[str, Any] + + +class MaskedOptimalityFn: # pylint: disable=missing-class-docstring,too-few-public-methods + def __init__( + self, + optimality_fn: Callable[..., TensorOrTensors], + solution: TensorOrTensors, + output_is_tensor: bool, + argnums: Tuple[int, ...], + *args: Any, + ) -> None: + self.optimality_fn = optimality_fn + self.solution = solution + self.output_is_tensor = output_is_tensor + self.argnums = argnums + + pre_filled = [] + post_filled = [] + for idx, arg in enumerate(args): + if idx + 1 in argnums: # plus 1 because we exclude the first argument + post_filled.append(arg) + else: + pre_filled.append(arg) + self.len_args = len(pre_filled) + len(post_filled) + self.pre_filled = tuple(pre_filled) + self.post_filled = tuple(post_filled) + + def __call__(self, *args: Any) -> TensorOrTensors: + true_args = [] + pre_filled_counter = 0 + for idx in range(self.len_args): + if idx + 1 in self.argnums: # plus 1 because we exclude the first argument + arg = args[idx] + else: + arg = self.pre_filled[pre_filled_counter] + pre_filled_counter += 1 + true_args.append(arg) + if self.output_is_tensor: + return self.optimality_fn(self.solution[0], *true_args) + return self.optimality_fn(self.solution, *true_args) + + +# pylint: disable-next=too-many-arguments,too-many-locals,too-many-branches +def _root_vjp( + optimality_fn: Callable[..., TensorOrTensors], + solution: TupleOfTensors, + args: Args, + grad_outputs: TupleOfTensors, + output_is_tensor: bool, + argnums: Tuple[int, ...], + solve: Callable[..., TensorOrTensors] = linear_solve.solve_normal_cg(), +) -> TupleOfOptionalTensors: + + if output_is_tensor: + + def optimality_cond(solution: TupleOfTensors) -> TensorOrTensors: + return optimality_fn(solution[0], *args) + + else: + + def optimality_cond(solution: TupleOfTensors) -> TensorOrTensors: + return optimality_fn(solution, *args) + + _, optimality_cond_vjp_fn, *_ = functorch.vjp(optimality_cond, solution) + + # Compute the multiplication A^T u = (u^T A)^T. + if output_is_tensor: + + def matvec(u: TupleOfTensors) -> TupleOfTensors: + return optimality_cond_vjp_fn(u[0])[0] + + else: + + def matvec(u: TupleOfTensors) -> TupleOfTensors: + return optimality_cond_vjp_fn(u)[0] + + # The solution of A^T u = v, where + # A = jacobian(optimality_fn, argnums=0) + # v = -grad_outputs. + v: TupleOfTensors = pytree.tree_map(torch.neg, grad_outputs) # type: ignore[arg-type,assignment] + u: TupleOfTensors = solve(matvec, v) # type: ignore[assignment] + + masked_optimality_fn = MaskedOptimalityFn( + optimality_fn, solution, output_is_tensor, argnums, *args + ) + + _, optimality_vjp_fn, *_ = functorch.vjp( + masked_optimality_fn, *masked_optimality_fn.post_filled + ) + + output: TupleOfTensors + if output_is_tensor: + output = optimality_vjp_fn(u[0]) + else: + output = optimality_vjp_fn(u) + + # Prepend None as the vjp for init_params. + true_output: ListOfOptionalTensors = [None] + for idx in range(masked_optimality_fn.len_args): + if idx + 1 in argnums: # plus 1 because we exclude the first argument + true_output.append(output[idx]) + else: + true_output.append(None) + + return tuple(true_output) + + +def _extract_kwargs(kwarg_keys: Sequence[str], flat_args: Tuple[Any, ...]) -> Tuple[Args, KwArgs]: + nargs = len(flat_args) - len(kwarg_keys) + args, kwarg_vals = flat_args[:nargs], flat_args[nargs:] + kwargs = dict(zip(kwarg_keys, kwarg_vals)) + return args, kwargs + + +def _signature_bind(signature: inspect.Signature, *args: Any, **kwargs: Any) -> Tuple[Args, KwArgs]: + bound = signature.bind(*args, **kwargs) + bound.apply_defaults() + return bound.args, bound.kwargs + + +def _signature_bind_and_match( + signature: inspect.Signature, *args: Any, **kwargs: Any +) -> Tuple[Args, KwArgs, Callable[[Args], Tuple[Args, KwArgs]]]: + # We want to bind *args and **kwargs based on the provided signature, but also to associate the + # resulting positional arguments back. To achieve this, we lift arguments to a triple: + # + # (was_kwarg, ref, value) + # + # where ref is an index position (int) if the original argument was from *args and a dictionary + # key if the original argument was from **kwargs. After binding to the inspected signature, we + # use the tags to associate the resolved positional arguments back to their args and kwargs + # source. + + args = [(False, i, v) for i, v in enumerate(args)] + kwargs = {k: (True, k, v) for (k, v) in kwargs.items()} + bound = signature.bind(*args, **kwargs) + + mapping = [(was_kwarg, ref) for was_kwarg, ref, _ in bound.args] + + def map_args_back(out_args): + src_args = [None] * len(args) + src_kwargs = {} + for (was_kwarg, ref), out_arg in zip(mapping, out_args): + if was_kwarg: + src_kwargs[ref] = out_arg + else: + src_args[ref] = out_arg + return src_args, src_kwargs + + out_args = tuple(v for _, _, v in bound.args) + out_kwargs = {k: v for k, (_, _, v) in bound.kwargs.items()} + return out_args, out_kwargs, map_args_back + + +def _split_tensor_and_others( + mixed_tuple: Tuple[Any, ...], +) -> Tuple[pytree.PyTreeSpec, Tuple[bool, ...], TupleOfTensors, Tuple[Any, ...]]: + flattened: List[Any] + flattened, treespec = pytree.tree_flatten(mixed_tuple, none_is_leaf=True) # type: ignore[arg-type] + tensors: ListOfTensors = [] + non_tensors: List[Any] = [] + is_tensor_mask: List[bool] = [] + for item in flattened: + is_tensor = isinstance(item, torch.Tensor) + is_tensor_mask.append(is_tensor) + if is_tensor: + tensors.append(item.data) + else: + non_tensors.append(item) + return treespec, tuple(is_tensor_mask), tuple(tensors), tuple(non_tensors) + + +def _merge_tensor_and_others( + treespec: pytree.PyTreeSpec, + is_tensor_mask: Tuple[bool, ...], + tensors: TupleOfTensors, + non_tensors: Tuple[Any, ...], +) -> Tuple[Any, ...]: + tensor_counter = 0 + non_tensor_counter = 0 + results = [] + for is_tensor in is_tensor_mask: + if is_tensor: + results.append(tensors[tensor_counter]) + tensor_counter += 1 + else: + results.append(non_tensors[non_tensor_counter]) + non_tensor_counter += 1 + return pytree.tree_unflatten(treespec, results) # type: ignore[return-value] + + +# pylint: disable-next=too-many-arguments,too-many-statements +def _custom_root( + solver_fn: Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]], + optimality_fn: Callable[..., TensorOrTensors], + solve: Callable[..., TensorOrTensors], + argnums: Tuple[int, ...], + has_aux: bool, + reference_signature: Optional[Union[inspect.Signature, Callable]] = None, +) -> Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]]: + solver_fn_signature = inspect.signature(solver_fn) + + if reference_signature is None: + reference_signature = inspect.signature(optimality_fn) + elif not isinstance(reference_signature, inspect.Signature): + # If is a CompositeLinearFunction, accesses subfn. + # Otherwise, assumes a Callable. + fn = getattr(reference_signature, 'subfn', reference_signature) + reference_signature = inspect.signature(fn) + + def make_custom_vjp_solver_fn( + solver_fn: Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]], + kwarg_keys: Sequence[str], + args_signs: Tuple[Tuple[int, int, Optional[Union[Type[tuple], Type[list]]]], ...], + ) -> Type[Function]: + # pylint: disable-next=missing-class-docstring,abstract-method + class ImplicitMetaGradient(Function): + @staticmethod + def forward( # type: ignore[override] # pylint: disable=arguments-differ + ctx: Any, *flat_args: Any + ) -> Tuple[Any, ...]: + output, aux, output_is_tensor = None, None, False + + args = [] + for offset, nargs, arg_seq_type in args_signs: + if arg_seq_type is not None: + args.append(arg_seq_type(flat_args[offset : offset + nargs])) + else: + args.append(flat_args[offset]) + args = tuple(args) + + args, kwargs = _extract_kwargs(kwarg_keys, args) + output = solver_fn(*args, **kwargs) + if has_aux: + if not (isinstance(output, tuple) and len(output) == 2): + raise RuntimeError( + f'custom_root(optimality_fn)(solver_fn)(*args): output of function ' + f'solver_fn should be a tuple: (output, aux) if has_aux is True. ' + f'Got {output}' + ) + output, aux = output + if isinstance(output, torch.Tensor): + output_is_tensor = True + output = (output,) + elif not (isinstance(output, tuple) and all(map(torch.is_tensor, output))): + raise RuntimeError( + f'custom_root(optimality_fn)(solver_fn)(*args): output of function ' + f'solver_fn should be a torch.Tensor or a tuple of torch.Tensor. ' + f'Got {output}' + ) + + ( + args_treespec, + args_is_tensor_mask, + args_tensors, + args_non_tensors, + ) = _split_tensor_and_others(args) + ctx.args_treespec = args_treespec + ctx.args_is_tensor_mask = args_is_tensor_mask + ctx.args_non_tensors = args_non_tensors + + ctx.save_for_backward(*output, *args_tensors) + ctx.output_is_tensor = output_is_tensor + + return (*output, aux, output_is_tensor, type(output)) + + @staticmethod + def backward( # pylint: disable=too-many-locals + ctx: Any, *grad_outputs: Any + ) -> TupleOfTensors: + grad_outputs: TupleOfTensors = grad_outputs[:-3] + + saved_tensors = ctx.saved_tensors + output = saved_tensors[: len(grad_outputs)] + args_tensors = saved_tensors[len(grad_outputs) :] + args_treespec = ctx.args_treespec + args_is_tensor_mask = ctx.args_is_tensor_mask + args_non_tensors = ctx.args_non_tensors + args = _merge_tensor_and_others( + args_treespec, args_is_tensor_mask, args_tensors, args_non_tensors + ) + + args, kwargs = _extract_kwargs(kwarg_keys, args) + + bound_args, bound_kwargs, map_args_back = _signature_bind_and_match( + reference_signature, *args, **kwargs # type: ignore[arg-type] + ) + if bound_kwargs: + raise TypeError( + f'keyword arguments to solver_fn could not be resolved to positional ' + f'arguments based on the signature {reference_signature}. This can ' + f'happen under custom_root if optimality_fn takes catch-all **kwargs, or ' + f'under custom_fixed_point if fixed_point_fn takes catch-all **kwargs, ' + f'both of which are currently unsupported.' + ) + + # Compute VJPs w.r.t. args. + vjps = _root_vjp( + optimality_fn=optimality_fn, + solution=output, + args=bound_args[1:], + grad_outputs=grad_outputs, + output_is_tensor=ctx.output_is_tensor, + argnums=argnums, + solve=solve, + ) + + args_vjps, kwargs_vjps = map_args_back(vjps) + ordered_vjps = tuple(args_vjps) + tuple(kwargs_vjps[k] for k in kwargs.keys()) + true_vjps = [] + for (_, _, arg_seq_type), vjp in zip(args_signs, ordered_vjps): + if arg_seq_type is not None: + true_vjps.extend(vjp) + else: + true_vjps.append(vjp) + return tuple(true_vjps) + + return ImplicitMetaGradient + + @functools.wraps(solver_fn) + def wrapped_solver_fn( + *args: Any, **kwargs: Any + ) -> Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]: + args, kwargs = _signature_bind(solver_fn_signature, *args, **kwargs) + keys, vals = list(kwargs.keys()), list(kwargs.values()) + + args_signs: List[Tuple[int, int, Optional[Union[Type[tuple], Type[list]]]]] = [] + flat_args: List[Any] = [] + args_offset = 0 + for idx, arg in enumerate(args): + if idx in argnums: + if isinstance(arg, torch.Tensor): + args_signs.append((args_offset, 1, None)) # start position, None + flat_args.append(arg) + args_offset += 1 + elif isinstance(arg, (tuple, list)) and all(map(torch.is_tensor, arg)): + nargs = len(arg) + args_signs.append( + (args_offset, nargs, type(arg)) # start position, sequence type + ) + flat_args.extend(arg) + args_offset += nargs + else: + raise RuntimeError( + 'custom_root(optimality_fn)(solver_fn)(*args): argument of function ' + 'solver_fn specified with `argnums` should be a torch.Tensor or a tuple of ' + 'torch.Tensor' + ) + else: + args_signs.append((args_offset, 1, None)) # start position, None + flat_args.append(arg) + args_offset += 1 + + args_signs = tuple(args_signs) + flat_args = tuple(flat_args) + + result = make_custom_vjp_solver_fn(solver_fn, keys, args_signs).apply(*flat_args, *vals) + *output, aux, output_is_tensor, output_type = result + if output_is_tensor: + output = output[0] + else: + output = output_type(output) + if has_aux: + return output, aux + return output + + return wrapped_solver_fn + + +def custom_root( + optimality_fn: Callable[..., TensorOrTensors], + argnums: Union[int, Tuple[int, ...]], + has_aux: bool = False, + solve: Callable[..., TensorOrTensors] = linear_solve.solve_normal_cg(), +) -> Callable[ + [Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]]], + Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]], +]: + """Decorator for adding implicit differentiation to a root solver. + + This wrapper should be used as a decorator: + + .. code-block:: python + + def optimality_fn(optimal_params, ...): + ... + return residual + + @custom_root(optimality_fn, argnums=argnums) + def solver_fn(params, arg1, arg2, ...): + ... + return optimal_params + + optimal_params = solver_fn(init_params, ...) + + The first argument to ``optimality_fn`` and ``solver_fn`` is preserved as the parameter input. + The ``argnums`` argument refers to the indices of the variables in ``solver_fn``'s signature. + For example, setting ``argnums=(1, 2)`` will compute the gradient of ``optimal_params`` with + respect to ``arg1`` and ``arg2`` in the above snippet. Note that, except the first argument, the + keyword arguments of the ``optimality_fn`` should be a subset of the ones of ``solver_fn``. + **In best practice, the ``optimality_fn`` should have the same signature as ``solver_fn``.** + + Args: + optimality_fn: (callable) + An equation function, ``optimality_fn(params, *args)``. The invariant is + ``optimality_fn(solution, *args) == 0`` at the solution / root of ``solution``. + argnums: (int or tuple of ints) + Specifies arguments to compute gradients with respect to. The ``argnums`` can be an + integer or a tuple of integers, which respect to the zero-based indices of the arguments + of the ``solver_fn(params, *args)`` function. The argument ``params`` is included + for the counting, while it is indexed as ``argnums=0``. + has_aux: (default: :data:`False`) + Whether the decorated solver function returns auxiliary data. + solve: (callable, optional, default: :func:`linear_solve.solve_normal_cg`) + a linear solver of the form ``solve(matvec, b)``. + + Returns: + A solver function decorator, i.e., ``custom_root(optimality_fn)(solver_fn)``. + """ + if isinstance(argnums, int): + assert argnums != 0 + argnums = (argnums,) + else: + assert 0 not in argnums + + return functools.partial( + _custom_root, + optimality_fn=optimality_fn, + solve=solve, + argnums=argnums, + has_aux=has_aux, + ) diff --git a/torchopt/diff/implicit/nn/__init__.py b/torchopt/diff/implicit/nn/__init__.py new file mode 100644 index 00000000..95a2ea85 --- /dev/null +++ b/torchopt/diff/implicit/nn/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The base class for differentiable implicit meta-gradient models.""" + +# Preload to resolve circular references +import torchopt.nn.module # pylint: disable=unused-import +from torchopt.diff.implicit.nn.module import ImplicitMetaGradientModule + + +__all__ = ['ImplicitMetaGradientModule'] diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py new file mode 100644 index 00000000..ed27b14c --- /dev/null +++ b/torchopt/diff/implicit/nn/module.py @@ -0,0 +1,297 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The base class for differentiable implicit meta-gradient models.""" + +# pylint: disable=redefined-builtin + +import contextlib +import functools +import itertools +from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, Type + +import functorch +import torch + +from torchopt import pytree +from torchopt.diff.implicit.decorator import custom_root +from torchopt.nn.module import MetaGradientModule +from torchopt.typing import LinearSolver, TensorTree, TupleOfTensors +from torchopt.utils import extract_module_containers + + +__all__ = ['ImplicitMetaGradientModule'] + + +def update_containers( + dst_containers: Iterable[Dict[str, Optional[torch.Tensor]]], + src_containers: Iterable[Dict[str, Optional[torch.Tensor]]], +) -> None: + """Update the tensor containers in ``dst_containers`` with the ones in ``src_containers``.""" + for src_container, dst_container in zip(src_containers, dst_containers): + dst_container.update(src_container) + + +@contextlib.contextmanager +def container_context( + orig_containers: Iterable[Dict[str, Optional[torch.Tensor]]], + args_containers: Iterable[Dict[str, Optional[torch.Tensor]]], +) -> Generator[None, None, None]: + # pylint: disable-next=line-too-long + """A context manager that temporarily updates the containers in ``orig_containers`` with the ones in ``args_containers``.""" + if not isinstance(orig_containers, (list, tuple)): + orig_containers = list(orig_containers) + orig_containers_backups = [container.copy() for container in orig_containers] + try: + update_containers(orig_containers, args_containers) + yield + finally: + update_containers(orig_containers, orig_containers_backups) + + +def make_optimality_from_objective( + objective: Callable[..., torch.Tensor] +) -> Callable[..., TupleOfTensors]: + """Make a function that computes the optimality function of the objective function.""" + + def optimality(self: 'ImplicitMetaGradientModule', *input, **kwargs) -> TupleOfTensors: + params_containers = extract_module_containers(self, with_buffers=False)[0] + flat_params: TupleOfTensors + # pylint: disable-next=line-too-long + flat_params, params_containers_treespec = pytree.tree_flatten_as_tuple(params_containers) # type: ignore[arg-type] + + def objective_fn(__flat_params: TupleOfTensors, *input, **kwargs) -> torch.Tensor: + flat_grad_tracking_params = __flat_params + grad_tracking_params_containers: Tuple[ + Dict[str, Optional[torch.Tensor]], ... + ] = pytree.tree_unflatten( # type: ignore[assignment] + params_containers_treespec, flat_grad_tracking_params + ) + + with container_context(params_containers, grad_tracking_params_containers): + return objective(self, *input, **kwargs) + + objective_grad_fn = functorch.grad(objective_fn, argnums=0) + flat_grads = objective_grad_fn(flat_params, *input, **kwargs) + return flat_grads + + return optimality + + +def enable_implicit_gradients( + cls: Type['ImplicitMetaGradientModule'], +) -> Type['ImplicitMetaGradientModule']: + """Enables implicit gradients for the :func:`solve` method.""" + cls_solve = cls.solve + if getattr(cls_solve, '__implicit_gradients_enabled__', False): + raise TypeError('Implicit gradients are already enabled for the `solve` method.') + + if cls.linear_solve is not None: + solve_kwargs = dict(solve=cls.linear_solve) + else: + solve_kwargs = {} + + @functools.wraps(cls_solve) + def wrapped( # pylint: disable=too-many-locals + self: 'ImplicitMetaGradientModule', *input, **kwargs + ) -> Any: + """Solve the optimization problem.""" + params_containers = extract_module_containers(self, with_buffers=False)[0] + meta_params_containers = [self._meta_parameters] # pylint: disable=protected-access + for meta_module in self.meta_children(): + meta_params_containers.extend( + extract_module_containers(meta_module, with_buffers=False)[0] + ) + meta_params_containers = tuple(meta_params_containers) + + flat_params: TupleOfTensors + flat_meta_params: TupleOfTensors + flat_params, params_containers_treespec = pytree.tree_flatten_as_tuple( + params_containers # type: ignore[arg-type] + ) + flat_meta_params, meta_params_containers_treespec = pytree.tree_flatten_as_tuple( + meta_params_containers # type: ignore[arg-type] + ) + + def optimality_fn( + __flat_params: TupleOfTensors, + __flat_meta_params: TupleOfTensors, + *input, + **kwargs, + ) -> TupleOfTensors: + flat_grad_tracking_params = __flat_params + grad_tracking_params_containers: Tuple[ + Dict[str, Optional[torch.Tensor]], ... + ] = pytree.tree_unflatten( # type: ignore[assignment] + params_containers_treespec, flat_grad_tracking_params + ) + flat_grad_tracking_meta_params = __flat_meta_params + grad_tracking_meta_params_containers: Tuple[ + Dict[str, Optional[torch.Tensor]], ... + ] = pytree.tree_unflatten( # type: ignore[assignment] + meta_params_containers_treespec, flat_grad_tracking_meta_params + ) + + with container_context( + itertools.chain( + params_containers, + meta_params_containers, + ), + itertools.chain( + grad_tracking_params_containers, + grad_tracking_meta_params_containers, + ), + ): + return self.optimality(*input, **kwargs) + + @custom_root(optimality_fn, argnums=1, has_aux=True, **solve_kwargs) + def solver_fn( + __flat_params: TupleOfTensors, # pylint: disable=unused-argument + __flat_meta_params: TupleOfTensors, # pylint: disable=unused-argument + *input, + **kwargs, + ) -> Tuple[TupleOfTensors, Any]: + output = cls_solve(self, *input, **kwargs) + flat_optimal_params: TupleOfTensors = tuple(pytree.tree_leaves(params_containers)) # type: ignore[arg-type] + return flat_optimal_params, output + + # pylint: disable-next=unused-variable + flat_optimal_params, output = solver_fn(flat_params, flat_meta_params, *input, **kwargs) + return output + + wrapped.__implicit_gradients_enabled__ = True # type: ignore[attr-defined] + cls.solve = wrapped # type: ignore[assignment] + return cls + + +class ImplicitMetaGradientModule(MetaGradientModule): + """The base class for differentiable implicit meta-gradient models.""" + + _custom_optimality: bool + _custom_objective: bool + linear_solve: Optional[LinearSolver] + + def __init_subclass__(cls, linear_solve: Optional[LinearSolver] = None) -> None: + """Validates and initializes the subclass.""" + super().__init_subclass__() + cls.linear_solve = linear_solve + + optimality = getattr(cls, 'optimality', ImplicitMetaGradientModule.optimality) + objective = getattr(cls, 'objective', ImplicitMetaGradientModule.objective) + cls._custom_optimality = optimality is not ImplicitMetaGradientModule.optimality + cls._custom_objective = objective is not ImplicitMetaGradientModule.objective + + if cls._custom_optimality: + if isinstance(optimality, staticmethod): + raise TypeError('method optimality() must not be a staticmethod.') + if isinstance(optimality, classmethod): + raise TypeError('method optimality() must not be a classmethod.') + if not callable(optimality): + raise TypeError('method optimality() must be callable.') + elif not cls._custom_objective: + raise TypeError( + 'ImplicitMetaGradientModule requires either an optimality() method or an objective() method' + ) + else: + if isinstance(objective, staticmethod): + raise TypeError('method objective() must not be a staticmethod.') + if isinstance(objective, classmethod): + raise TypeError('method objective() must not be a classmethod.') + if not callable(objective): + raise TypeError('method objective() must be callable.') + + cls.optimality = make_optimality_from_objective(objective) # type: ignore[assignment] + + enable_implicit_gradients(cls) + + def solve(self, *input, **kwargs) -> Any: + """Solves the inner optimization problem. + + .. warning:: + + For gradient-based optimization methods, the parameter inputs should be explicitly + specified in the :func:`torch.autograd.backward` function as argument ``inputs``. + Otherwise, if not provided, the gradient is accumulated into all the leaf Tensors + (including the meta-parameters) that were used to compute the objective output. + Alternatively, please use :func:`torch.autograd.grad` instead. + + Example:: + + def solve(self, batch, labels): + parameters = tuple(self.parameters()) + optimizer = torch.optim.Adam(parameters, lr=1e-3) + with torch.enable_grad(): + for _ in range(100): + loss = self.objective(batch, labels) + optimizer.zero_grad() + # Only update the `.grad` attribute for parameters + # and leave the meta-parameters unchanged + loss.backward(inputs=parameters) + optimizer.step() + return self + """ + raise NotImplementedError # update parameters + + def optimality(self, *input, **kwargs) -> TensorTree: + r"""Computes the optimality residual. + + This method stands for the optimality residual to the optimal parameters after solving the + inner optimization problem (:meth:`solve`), i.e.: + + .. code-block:: python + + module.solve(*input, **kwargs) + module.optimality(*input, **kwargs) # -> 0 + + 1. For gradient-based optimization, the :meth:`optimality` function is the KKT condition, + usually it is the gradients of the :meth:`objective` function with respect to the module + parameters (not the meta-parameters). If this method is not implemented, it will be + automatically derived from the gradient of the :meth:`objective` function. + + .. math:: + + \text{optimality residual} = \nabla_{\boldsymbol{x}} f (\boldsymbol{x}, \boldsymbol{\theta}) \to \boldsymbol{0} + + where :math:`\boldsymbol{x}` is the joint vector of the module parameters and + :math:`\boldsymbol{\theta}` is the joint vector of the meta-parameters. + + References: + - Karush-Kuhn-Tucker (KKT) conditions: https://en.wikipedia.org/wiki/Karush-Kuhn-Tucker_conditions + + 2. For fixed point iteration, the :meth:`optimality` function can be the residual of the + parameters between iterations, i.e.: + + .. math:: + + \text{optimality residual} = f (\boldsymbol{x}, \boldsymbol{\theta}) - \boldsymbol{x} \to \boldsymbol{0} + + where :math:`\boldsymbol{x}` is the joint vector of the module parameters and + :math:`\boldsymbol{\theta}` is the joint vector of the meta-parameters. + + Returns: + A tree of tensors, the optimality residual to the optimal parameters after solving the + inner optimization problem. + """ # pylint: disable=line-too-long + raise NotImplementedError + + def objective(self, *input, **kwargs) -> torch.Tensor: + """Computes the objective function value. + + This method is used to calculate the :meth:`optimality` if it is not implemented. + Otherwise, this method is optional. + + Returns: + A scalar tensor (``dim=0``), the objective function value. + """ + raise NotImplementedError diff --git a/torchopt/diff/zero_order/__init__.py b/torchopt/diff/zero_order/__init__.py new file mode 100644 index 00000000..a76dcb9a --- /dev/null +++ b/torchopt/diff/zero_order/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Zero-Order Gradient.""" + +import sys as _sys +from types import ModuleType as _ModuleType + +from torchopt.diff.zero_order.decorator import zero_order + + +__all__ = ['zero_order'] + + +class _CallableModule(_ModuleType): # pylint: disable=too-few-public-methods + def __call__(self, *args, **kwargs): + return self.zero_order(*args, **kwargs) + + +# Replace entry in sys.modules for this module with an instance of _CallableModule +_modself = _sys.modules[__name__] +_modself.__class__ = _CallableModule +del _sys, _ModuleType, _modself, _CallableModule diff --git a/torchopt/diff/zero_order/decorator.py b/torchopt/diff/zero_order/decorator.py new file mode 100644 index 00000000..361da4ff --- /dev/null +++ b/torchopt/diff/zero_order/decorator.py @@ -0,0 +1,407 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Zero-Order Gradient Estimation.""" + +import functools +from typing import Any, Callable, List, Tuple, Union +from typing_extensions import Literal # Python 3.8+ +from typing_extensions import TypeAlias # Python 3.10+ + +import torch +from torch.autograd import Function + +from torchopt import pytree +from torchopt.typing import ( + ListOfTensors, + Numeric, + Samplable, + SampleFunc, + Sequence, + TupleOfOptionalTensors, +) + + +class WrappedSamplable(Samplable): # pylint: disable=too-few-public-methods + """A wrapper that wraps a sample function to a :class:`Samplable` object.""" + + def __init__(self, sample_fn: SampleFunc) -> None: + """Wrap a sample function to make it a :class:`Samplable` object.""" + self.sample_fn = sample_fn + + def sample( + self, sample_shape: torch.Size = torch.Size() + ) -> Union[torch.Tensor, Sequence[Numeric]]: + # pylint: disable-next=line-too-long + """Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.""" + return self.sample_fn(sample_shape) + + +def _zero_order_naive( # pylint: disable=too-many-statements + fn: Callable[..., torch.Tensor], + distribution: Samplable, + argnums: Tuple[int, ...], + num_samples: int, + sigma: Numeric, +) -> Callable[..., torch.Tensor]: + @functools.wraps(fn) + def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements + diff_params = [args[argnum] for argnum in argnums] + flat_diff_params: List[Any] + flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] + + class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-method + @staticmethod + def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: + flat_diff_params = args[:-1] + origin_args = list(args[-1][0]) + flat_args: List[Any] + flat_args, args_treespec = pytree.tree_flatten(origin_args, none_is_leaf=True) # type: ignore[arg-type] + ctx.args_treespec = args_treespec + + is_tensor_mask = [] + tensors = [] + non_tensors = [] + for origin_arg in flat_args: + is_tensor = isinstance(origin_arg, torch.Tensor) + is_tensor_mask.append(is_tensor) + if is_tensor: + tensors.append(origin_arg) + else: + non_tensors.append(origin_arg) + + ctx.non_tensors = non_tensors + ctx.is_tensor_mask = is_tensor_mask + + output = fn(*origin_args) + if not isinstance(output, torch.Tensor): + raise RuntimeError('`output` must be a tensor.') + if output.ndim != 0: + raise RuntimeError('`output` must be a scalar tensor.') + ctx.save_for_backward(*flat_diff_params, *tensors) + ctx.len_args = len(args) + ctx.len_params = len(flat_diff_params) + return output + + @staticmethod + def backward( # pylint: disable=too-many-locals + ctx: Any, *grad_outputs: Any + ) -> TupleOfOptionalTensors: + saved_tensors = ctx.saved_tensors + flat_diff_params = saved_tensors[: ctx.len_params] + tensors = saved_tensors[ctx.len_params :] + non_tensors = ctx.non_tensors + + flat_args = [] + tensors_counter = 0 + non_tensors_counter = 0 + for is_tensor in ctx.is_tensor_mask: + if is_tensor: + flat_args.append(tensors[tensors_counter]) + tensors_counter += 1 + else: + flat_args.append(non_tensors[non_tensors_counter]) + non_tensors_counter += 1 + + args: List[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] + + def add_perturbation(tensor, noises): + return tensor.add(noises, alpha=sigma) + + param_grads: ListOfTensors = [0.0 for _ in range(len(flat_diff_params))] # type: ignore[misc] + + for _ in range(num_samples): + noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params] + flat_noisy_params = [ + add_perturbation(t, n) for t, n in zip(flat_diff_params, noises) + ] + noisy_params: List[Any] = pytree.tree_unflatten( # type: ignore[assignment] + diff_params_treespec, flat_noisy_params + ) + + for argnum, noisy_param in zip(argnums, noisy_params): + args[argnum] = noisy_param + + output = fn(*args) + weighted_grad = grad_outputs[0].mul(output).mul_(1 / sigma) + + for i, noise in enumerate(noises): + param_grads[i] += weighted_grad * noise + + for i in range(len(flat_diff_params)): + param_grads[i] /= num_samples + + return tuple(param_grads + [None] * (ctx.len_args - len(flat_diff_params))) + + return ZeroOrder.apply(*flat_diff_params, (args,)) + + return apply + + +def _zero_order_forward( # pylint: disable=too-many-statements + fn: Callable[..., torch.Tensor], + distribution: Samplable, + argnums: Tuple[int, ...], + num_samples: int, + sigma: Numeric, +) -> Callable[..., torch.Tensor]: + @functools.wraps(fn) + def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements + diff_params = [args[argnum] for argnum in argnums] + flat_diff_params: List[Any] + flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] + + class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-method + @staticmethod + def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: + flat_diff_params = args[:-1] + origin_args = list(args[-1][0]) + flat_args: List[Any] + flat_args, args_treespec = pytree.tree_flatten(origin_args, none_is_leaf=True) # type: ignore[arg-type] + ctx.args_treespec = args_treespec + + is_tensor_mask = [] + tensors = [] + non_tensors = [] + for origin_arg in flat_args: + is_tensor = isinstance(origin_arg, torch.Tensor) + is_tensor_mask.append(is_tensor) + if is_tensor: + tensors.append(origin_arg) + else: + non_tensors.append(origin_arg) + + ctx.non_tensors = non_tensors + ctx.is_tensor_mask = is_tensor_mask + + output = fn(*origin_args) + if not isinstance(output, torch.Tensor): + raise RuntimeError('`output` must be a tensor.') + if output.ndim != 0: + raise RuntimeError('`output` must be a scalar tensor.') + ctx.save_for_backward(*flat_diff_params, *tensors, output) + ctx.len_args = len(args) + ctx.len_params = len(flat_diff_params) + return output + + @staticmethod + def backward( # pylint: disable=too-many-locals + ctx: Any, *grad_outputs: Any + ) -> TupleOfOptionalTensors: + saved_tensors = ctx.saved_tensors + flat_diff_params = saved_tensors[: ctx.len_params] + tensors = saved_tensors[ctx.len_params : -1] + output = saved_tensors[-1] + non_tensors = ctx.non_tensors + + flat_args = [] + tensors_counter = 0 + non_tensors_counter = 0 + for is_tensor in ctx.is_tensor_mask: + if is_tensor: + flat_args.append(tensors[tensors_counter]) + tensors_counter += 1 + else: + flat_args.append(non_tensors[non_tensors_counter]) + non_tensors_counter += 1 + + args: List[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] + + def add_perturbation(tensor, noises): + return tensor.add(noises, alpha=sigma) + + param_grads: ListOfTensors = [0.0 for _ in range(len(flat_diff_params))] # type: ignore[misc] + + for _ in range(num_samples): + noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params] + flat_noisy_params = [ + add_perturbation(t, n) for t, n in zip(flat_diff_params, noises) + ] + noisy_params: List[Any] = pytree.tree_unflatten( # type: ignore[assignment] + diff_params_treespec, flat_noisy_params + ) + + for argnum, noisy_param in zip(argnums, noisy_params): + args[argnum] = noisy_param + + noisy_output = fn(*args) + output = noisy_output - output + weighted_grad = grad_outputs[0].mul(output).div_(1.0 / sigma) + + for i, noise in enumerate(noises): + param_grads[i] += weighted_grad * noise + + for i in range(len(flat_diff_params)): + param_grads[i] /= num_samples + + return tuple(param_grads + [None] * (ctx.len_args - len(flat_diff_params))) + + return ZeroOrder.apply(*flat_diff_params, (args,)) + + return apply + + +def _zero_order_antithetic( # pylint: disable=too-many-statements + fn: Callable[..., torch.Tensor], + distribution: Samplable, + argnums: Tuple[int, ...], + num_samples: int, + sigma: Numeric, +) -> Callable[..., torch.Tensor]: + @functools.wraps(fn) + def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements + diff_params = [args[argnum] for argnum in argnums] + flat_diff_params: List[Any] + flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] + + class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-method + @staticmethod + def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: + flat_diff_params = args[:-1] + origin_args = list(args[-1][0]) + flat_args: List[Any] + flat_args, args_treespec = pytree.tree_flatten(origin_args, none_is_leaf=True) # type: ignore[arg-type] + ctx.args_treespec = args_treespec + + is_tensor_mask = [] + tensors = [] + non_tensors = [] + for origin_arg in flat_args: + is_tensor = isinstance(origin_arg, torch.Tensor) + is_tensor_mask.append(is_tensor) + if is_tensor: + tensors.append(origin_arg) + else: + non_tensors.append(origin_arg) + + ctx.non_tensors = non_tensors + ctx.is_tensor_mask = is_tensor_mask + + output = fn(*origin_args) + if not isinstance(output, torch.Tensor): + raise RuntimeError('`output` must be a tensor.') + if output.ndim != 0: + raise RuntimeError('`output` must be a scalar tensor.') + ctx.save_for_backward(*flat_diff_params, *tensors) + ctx.len_args = len(args) + ctx.len_params = len(flat_diff_params) + return output + + @staticmethod + def backward(ctx: Any, *grad_outputs: Any): # pylint: disable=too-many-locals + saved_tensors = ctx.saved_tensors + flat_diff_params = saved_tensors[: ctx.len_params] + tensors = saved_tensors[ctx.len_params :] + non_tensors = ctx.non_tensors + + flat_args = [] + tensors_counter = 0 + non_tensors_counter = 0 + for is_tensor in ctx.is_tensor_mask: + if is_tensor: + flat_args.append(tensors[tensors_counter]) + tensors_counter += 1 + else: + flat_args.append(non_tensors[non_tensors_counter]) + non_tensors_counter += 1 + + args: List[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] + + param_grads: ListOfTensors = [0.0 for _ in range(len(flat_diff_params))] # type: ignore[misc] + + def get_output(add_perturbation_fn, noises) -> torch.Tensor: + flat_noisy_params = [ + add_perturbation_fn(t, n, alpha=sigma) + for t, n in zip(flat_diff_params, noises) + ] + noisy_params: List[Any] = pytree.tree_unflatten( # type: ignore[assignment] + diff_params_treespec, flat_noisy_params + ) + + for argnum, noisy_param in zip(argnums, noisy_params): + args[argnum] = noisy_param + + return fn(*args) + + for _ in range(num_samples): + noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params] + output = get_output(torch.add, noises) - get_output(torch.sub, noises) + weighted_grad = grad_outputs[0].mul(output).mul_(0.5 / sigma) + + for i, noise in enumerate(noises): + param_grads[i] += weighted_grad * noise + + for i in range(len(flat_diff_params)): + param_grads[i] /= num_samples + + return tuple(param_grads + [None] * (ctx.len_args - len(flat_diff_params))) + + return ZeroOrder.apply(*flat_diff_params, (args,)) + + return apply + + +Method: TypeAlias = Literal['naive', 'forward', 'antithetic'] + + +def zero_order( + distribution: Union[SampleFunc, Samplable], + method: Method = 'naive', + argnums: Union[int, Tuple[int, ...]] = (0,), + num_samples: int = 1, + sigma: Numeric = 1.0, +) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]: + """Decorator for applying zero-order differentiation. + + Args: + distribution: (function or Samplable) + A samplable object that has method ``samplable.sample(sample_shape)`` or a function that + takes the shape as input and returns a shaped batch of samples. This is used to sample + perturbations from the given distribution. The distribution should be sphere symmetric. + method: (str) + The algorithm to use. The currently supported algorithms are :const:`'naive'`, + :const:`'forward'`, and :const:`'antithetic'`. Defaults to :const:`'naive'`. + argnums: (int or tuple of int, default: :const:`0`) + Specifies arguments to compute gradients with respect to. + num_samples: (int, default :const:`1`) + The number of sample to get the averaged estimated gradient. + sigma: (Numeric) + The standard deviation of the perturbation. Defaults to :const:`1.0`. + + Returns: + A function decorator that enables zero-order gradient estimation. + """ + assert method in ('naive', 'forward', 'antithetic') + if method == 'naive': + method_fn = _zero_order_naive + elif method == 'forward': + method_fn = _zero_order_forward + else: + method_fn = _zero_order_antithetic + + if isinstance(argnums, int): + argnums = (argnums,) + + if not isinstance(distribution, Samplable): + if not callable(distribution): + raise TypeError('`distribution` must be a callable or an instance of `Samplable`.') + distribution = WrappedSamplable(distribution) + + return functools.partial( + method_fn, + distribution=distribution, + argnums=argnums, + num_samples=num_samples, + sigma=sigma, + ) diff --git a/torchopt/distributed/__init__.py b/torchopt/distributed/__init__.py new file mode 100644 index 00000000..d966691c --- /dev/null +++ b/torchopt/distributed/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Distributed utilities.""" + +import torch.distributed as dist +import torch.distributed.rpc as rpc + +from torchopt.distributed import api, autograd, world +from torchopt.distributed.api import * +from torchopt.distributed.world import * + + +__all__ = ['is_available', *api.__all__, *world.__all__] + + +def is_available(): + """Check if the distributed module is available.""" + return dist.is_available() and rpc.is_available() and autograd.is_available() diff --git a/torchopt/distributed/api.py b/torchopt/distributed/api.py new file mode 100644 index 00000000..0c06fa91 --- /dev/null +++ b/torchopt/distributed/api.py @@ -0,0 +1,481 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Distributed APIs.""" + +import functools +import sys +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + Tuple, + TypeVar, + Union, + cast, +) + +import torch +import torch.distributed.rpc as rpc + +import torchopt.pytree as pytree +from torchopt.distributed.world import get_worker_id, get_world_rank, get_world_size +from torchopt.typing import Future + + +__all__ = [ + 'TensorDimensionPartitioner', + 'dim_partitioner', + 'batch_partitioner', + 'mean_reducer', + 'sum_reducer', + 'remote_async_call', + 'remote_sync_call', + 'parallelize', + 'parallelize_async', + 'parallelize_sync', +] + + +if rpc.is_available(): + UNSET_RPC_TIMEOUT = rpc.api.UNSET_RPC_TIMEOUT +else: + UNSET_RPC_TIMEOUT = -1.0 + + +T = TypeVar('T') +U = TypeVar('U') +Args = Tuple[Any, ...] +KwArgs = Dict[str, Any] +PartitionFunction = Callable[..., Sequence[Tuple[int, Optional[Args], Optional[KwArgs]]]] +Partitioner = Union[int, str, PartitionFunction] + + +class TensorDimensionPartitioner: + """Partitioner class that partitions a batch of inputs along a given dimension. + + All tensors in the ``args`` and ``kwargs`` will be partitioned along the dimension ``dim``, + while the non-tensor values will be broadcasted to partitions. + + Args: + dim: The dimension to partition. + exclusive: Whether to partition the batch exclusively. + If :data:`True`, the batch will be partitioned into ``batch_size`` partitions, where + ``batch_size`` is the size of the batch along the given dimension. Each batch sample + will be assigned to a separate RPC call. + If :data:`False`, the batch will be partitioned into ``min(batch_size, num_workers)`` + partitions, where ``num_workers`` is the number of workers in the world. When + ``batch_size > num_workers``, there can be multiple batch samples forward in a single + RPC call. + keepdim: Whether to keep the partitioned dimension. Defaults to :data:`True`, i.e., keep the + batch dimension. If :data:`False`, use select instead of slicing. This functionality + should be used with ``exclusive=True``. + workers: The workers to partition the batch to. If :data:`None`, the batch will be + partitioned to all workers in the world. + """ + + def __init__( + self, + dim: int, + *, + exclusive: bool = False, + keepdim: bool = False, + workers: Optional[Sequence[Union[int, str]]] = None, + ) -> None: + """Initialize the partitioner instance.""" + if not keepdim and not exclusive: + raise ValueError('keepdim=False should be used with exclusive=True.') + + self.dim = dim + self.exclusive = exclusive + self.keepdim = keepdim + self.workers = workers + + # pylint: disable-next=too-many-branches,too-many-locals + def __call__( + self, + *args: Any, + **kwargs: Any, + ) -> List[Tuple[int, Optional[Args], Optional[KwArgs]]]: + """Partition the batch of inputs along the given dimension.""" + if self.workers is None: + workers = list(range(get_world_size())) + else: + workers = list(map(get_worker_id, self.workers)) + num_workers = len(workers) + + args_tree = (args, kwargs) + flat_args: List[Any] + flat_args, treespec = pytree.tree_flatten(args_tree) # type: ignore[arg-type] + + batch_size = None + for arg in flat_args: + if isinstance(arg, torch.Tensor): + if batch_size is None: + batch_size = arg.shape[self.dim] + elif batch_size != arg.shape[self.dim]: # type: ignore[unreachable] + raise ValueError( + f'Batch size mismatch on dim={self.dim}. ' + f'Expected {batch_size}, got {arg.shape[self.dim]} (shape: {arg.shape}).' + ) + + if batch_size is None: + return [(get_world_rank(), args, kwargs.copy())] + + dim_slices: List[Union[int, slice]] + batch_slices: List[Tuple[Union[int, slice, Ellipsis.__class__], ...]] # type: ignore[name-defined] + if self.exclusive: + num_replicas = batch_size + if self.keepdim: + dim_slices = [slice(i, i + 1) for i in range(num_replicas)] + else: + dim_slices = list(range(num_replicas)) + else: + if batch_size <= num_workers: + num_replicas = batch_size + dim_slices = [slice(i, i + 1) for i in range(batch_size)] # keepdim=True + else: + num_replicas = num_workers + local_size = batch_size // num_workers + local_batch_indices = [i * local_size for i in range(num_workers)] + [batch_size] + dim_slices = [ + slice(local_batch_indices[i], local_batch_indices[i + 1]) + for i in range(num_workers) + ] + + if self.dim >= 0: + batch_slices = [ + (slice(None, None),) * self.dim + (dim_slice,) for dim_slice in dim_slices + ] + elif self.dim < 0: + batch_slices = [ + ( + ..., + dim_slice, + ) + + (slice(None, None),) * (-self.dim - 1) + for dim_slice in dim_slices + ] + + flat_args_replicas: List[List[Any]] = [[] for _ in range(num_replicas)] + for arg in flat_args: + if isinstance(arg, torch.Tensor): + for i, batch_slice in enumerate(batch_slices): + flat_args_replicas[i].append(arg[batch_slice]) + else: + for i in range(num_replicas): + flat_args_replicas[i].append(arg) + + args_replicas: List[Tuple[Args, KwArgs]] = [ + pytree.tree_unflatten(treespec, args_replica) # type: ignore[misc] + for args_replica in flat_args_replicas + ] + + return [ + (workers[i % num_workers], worker_args, worker_kwargs) + for i, (worker_args, worker_kwargs) in enumerate(args_replicas) + ] + + def __reduce__( + self, + ) -> Tuple[ + Callable[..., 'TensorDimensionPartitioner'], + Tuple[int], + Dict[str, Union[bool, Optional[Sequence[Union[int, str]]]]], + ]: + """Return a tuple that allows the partitioner to be pickled.""" + return ( + TensorDimensionPartitioner, + (self.dim,), + dict(exclusive=self.exclusive, keepdim=self.keepdim, workers=self.workers), + ) + + +def dim_partitioner( + dim: int = 0, + *, + exclusive: bool = False, + keepdim: bool = True, + workers: Optional[Sequence[Union[int, str]]] = None, +) -> PartitionFunction: + """Partition a batch of inputs along a given dimension. + + All tensors in the ``args`` and ``kwargs`` will be partitioned along the dimension ``dim``, + while the non-tensor values will be broadcasted to partitions. + + Args: + dim: The dimension to partition. + exclusive: Whether to partition the batch exclusively. + If :data:`True`, the batch will be partitioned into ``batch_size`` partitions, where + ``batch_size`` is the size of the batch along the given dimension. Each batch sample + will be assigned to a separate RPC call. + If :data:`False`, the batch will be partitioned into ``min(batch_size, num_workers)`` + partitions, where ``num_workers`` is the number of workers in the world. When + ``batch_size > num_workers``, there can be multiple batch samples forward in a single + RPC call. + keepdim: Whether to keep the partitioned dimension. Defaults to :data:`True`, i.e., keep the + batch dimension. If :data:`False`, use select instead of slicing. This functionality + should be used with ``exclusive=True``. + workers: The workers to partition the batch to. If :data:`None`, the batch will be + partitioned to all workers in the world. + + Returns: + A partition function. + """ + return TensorDimensionPartitioner(dim, exclusive=exclusive, keepdim=keepdim, workers=workers) + + +batch_partitioner: PartitionFunction = dim_partitioner(dim=0, keepdim=True, exclusive=False) +"""Partitioner for batch dimension. Divide and replicates the arguments to all workers along the first dimension. + +The batch will be partitioned into ``min(batch_size, num_workers)`` partitions, where +``num_workers`` is the number of workers in the world. +When ``batch_size > num_workers``, there can be multiple batch samples forward in a single RPC call. + +All tensors in the ``args`` and ``kwargs`` will be partitioned along the dimension ``dim``, +while the non-tensor values will be broadcasted to partitions. +""" +exclusive_batch_partitioner: PartitionFunction = dim_partitioner(dim=0, keepdim=True, exclusive=True) # fmt: skip +"""Partitioner for batch dimension. Divide and replicates the arguments to all workers along the first dimension. + +Each batch sample will be assigned to a separate RPC call. + +All tensors in the ``args`` and ``kwargs`` will be partitioned along the dimension ``dim``, +while the non-tensor values will be broadcasted to partitions. +""" + + +def mean_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor: + """Reduce the results by averaging them.""" + return torch.mean(torch.stack(tuple(results), dim=0), dim=0) + + +def sum_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor: + """Reduce the results by summing them.""" + return torch.sum(torch.stack(tuple(results), dim=0), dim=0) + + +def remote_async_call( + func: Callable[..., T], + *, + args: Optional[Args] = None, + kwargs: Optional[KwArgs] = None, + partitioner: Optional[Partitioner] = None, + reducer: Optional[Callable[[Iterable[T]], U]] = None, + timeout: Optional[float] = UNSET_RPC_TIMEOUT, +) -> Union[Future[List[T]], Future[U]]: + """Asynchronously do an RPC on remote workers and return the a :class:`torch.Future` instance at the current worker. + + Args: + func (Callable[..., T]): The function to call. + args (Optional[Args], optional): The arguments to pass to the function. Defaults to + :data:`None`. + kwargs (Optional[KwArgs], optional): The keyword arguments to pass to the function. Defaults + to :data:`None`. + partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple + workers. Defaults to :func:`batch_partitioner`. + reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from + multiple workers. Defaults to :data:`None`. + timeout (float, optional): The timeout for the RPC call. Defaults to + :data:`rpc.api.UNSET_RPC_TIMEOUT`. + + Returns: + A :class:`torch.Future` instance for the result. The result is at the current worker. + """ + if args is None: + args = () + if kwargs is None: + kwargs = {} + if partitioner is None: + partitioner = batch_partitioner + if isinstance(partitioner, (int, str)): + partitions = [(get_worker_id(id=partitioner), args, kwargs)] + elif callable(partitioner): + partitions = partitioner(*args, **kwargs) # type: ignore[assignment] + else: + raise ValueError(f'Invalid partitioner: {partitioner!r}.') + + futures = [] + for rank, worker_args, worker_kwargs in partitions: + fut = rpc.rpc_async(rank, func, args=worker_args, kwargs=worker_kwargs, timeout=timeout) + futures.append(fut) + + future = cast( + Future[List[T]], + torch.futures.collect_all(futures).then(lambda fut: [f.wait() for f in fut.wait()]), + ) + if reducer is not None: + return cast( + Future[U], + future.then(lambda fut: cast(Callable[[Iterable[T]], U], reducer)(fut.wait())), + ) + return future + + +def remote_sync_call( + func: Callable[..., T], + *, + args: Optional[Args] = None, + kwargs: Optional[KwArgs] = None, + partitioner: Optional[Partitioner] = None, + reducer: Optional[Callable[[Iterable[T]], U]] = None, + timeout: Optional[float] = UNSET_RPC_TIMEOUT, +) -> Union[List[T], U]: + """Synchronously do an RPC on remote workers and return the result to the current worker. + + Args: + func (Callable[..., T]): The function to call. + args (Optional[Args], optional): The arguments to pass to the function. Defaults to + :data:`None`. + kwargs (Optional[KwArgs], optional): The keyword arguments to pass to the function. Defaults + to :data:`None`. + partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple + workers. Defaults to :func:`batch_partitioner`. + reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from + multiple workers. Defaults to :data:`None`. + timeout (float, optional): The timeout for the RPC call. Defaults to + :data:`rpc.api.UNSET_RPC_TIMEOUT`. + + Returns: + The result of the RPC call. The result is at the current worker. + """ + return remote_async_call( + func, + args=args, + kwargs=kwargs, + partitioner=partitioner, + timeout=timeout, + reducer=reducer, + ).wait() + + +def parallelize_async( + partitioner: Optional[Partitioner] = None, + reducer: Optional[Callable[[Iterable[T]], U]] = None, + timeout: Optional[float] = UNSET_RPC_TIMEOUT, +) -> Callable[[Callable[..., T]], Callable[..., Union[Future[List[T]], Future[U]]]]: + """Decorator for parallelizing a function. + + This decorator can be used to parallelize a function call across multiple workers. The + function will be called asynchronously on remote workers. The decorated function will + return a :class:`torch.Future` instance of the result. + + Args: + partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple + workers. Defaults to :func:`batch_partitioner`. + reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from + multiple workers. Defaults to :func:`mean_reducer` if the ``partitioner`` is not + specified, i.e., :func:`batch_partitioner`. Otherwise, it defaults to :data:`None`. + timeout (float, optional): The timeout for the RPC call. Defaults to + :data:`rpc.api.UNSET_RPC_TIMEOUT`. + + Returns: + The decorator function. + """ + if partitioner is None: + partitioner = batch_partitioner + if reducer is None: + reducer = mean_reducer # type: ignore[assignment] + + def wrapper(func: Callable[..., T]) -> Callable[..., Union[Future[List[T]], Future[U]]]: + @functools.wraps(func) + def wrapped(*args: Any, **kwargs: Any) -> Union[Future[List[T]], Future[U]]: + return remote_async_call( + func, + args=args, + kwargs=kwargs, + partitioner=partitioner, + reducer=reducer, + timeout=timeout, + ) + + suffix = '__parallelize_async_unwrapped__' + module_name = func.__module__ + try: + name = func.__qualname__ + except AttributeError: + name = func.__name__ + else: + func.__qualname__ = f'{func.__qualname__}{suffix}' + func.__name__ = f'{func.__name__}{suffix}' + __import__(module_name, level=0) + module = sys.modules[module_name] + setattr(module, f'{name}{suffix}', func) + + return wrapped + + return wrapper + + +def parallelize( + partitioner: Optional[Partitioner] = None, + reducer: Optional[Callable[[Iterable[T]], U]] = None, + timeout: Optional[float] = UNSET_RPC_TIMEOUT, +) -> Callable[[Callable[..., T]], Callable[..., Union[List[T], U]]]: + """Decorator for parallelizing a function. + + This decorator can be used to parallelize a function call across multiple workers. + + Args: + partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple + workers. Defaults to :func:`batch_partitioner`. + reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from + multiple workers. Defaults to :func:`mean_reducer` if the ``partitioner`` is not + specified, i.e., :func:`batch_partitioner`. Otherwise, it defaults to :data:`None`. + timeout (float, optional): The timeout for the RPC call. Defaults to + :data:`rpc.api.UNSET_RPC_TIMEOUT`. + + Returns: + The decorator function. + """ + if partitioner is None: + partitioner = batch_partitioner + if reducer is None: + reducer = mean_reducer # type: ignore[assignment] + + def wrapper(func: Callable[..., T]) -> Callable[..., Union[List[T], U]]: + @functools.wraps(func) + def wrapped(*args: Any, **kwargs: Any) -> Union[List[T], U]: + return remote_sync_call( + func, + args=args, + kwargs=kwargs, + partitioner=partitioner, + reducer=reducer, + timeout=timeout, + ) + + suffix = '__parallelize_unwrapped__' + module_name = func.__module__ + try: + name = func.__qualname__ + except AttributeError: + name = func.__name__ + else: + func.__qualname__ = f'{func.__qualname__}{suffix}' + func.__name__ = f'{func.__name__}{suffix}' + __import__(module_name, level=0) + module = sys.modules[module_name] + setattr(module, f'{name}{suffix}', func) + + return wrapped + + return wrapper + + +parallelize_sync = parallelize diff --git a/torchopt/distributed/autograd.py b/torchopt/distributed/autograd.py new file mode 100644 index 00000000..41b6b461 --- /dev/null +++ b/torchopt/distributed/autograd.py @@ -0,0 +1,150 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Distributed Autograd.""" + +from threading import Lock +from typing import Optional, overload + +import torch +import torch.distributed.autograd as autograd +from torch.distributed.autograd import context + +from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors, TupleOfTensors + + +__all__ = ['is_available', 'context'] + + +LOCK = Lock() + + +def is_available(): + """Check if distributed autograd module is available.""" + return autograd.is_available() + + +if is_available(): + # pylint: disable-next=unused-import,ungrouped-imports + from torch.distributed.autograd import DistAutogradContext, get_gradients + + def backward( + autograd_ctx_id: int, + tensors: TensorOrTensors, + retain_graph: bool = False, + inputs: Optional[TensorOrTensors] = None, + ) -> None: + """Perform distributed backward pass for local parameters. + + Computes the sum of gradients of given tensors with respect to graph leaves. + + Args: + autograd_ctx_id: The autograd context id. + tensors (Sequence[Tensor] or Tensor): Tensors of which the derivative will be computed. + retain_graph (bool, optional): If :data:`False`, the graph used to compute the grad will + be freed. Note that in nearly all cases setting this option to :data:`True` is not + needed and often can be worked around in a much more efficient way. + inputs (Sequence[Tensor] or Tensor, optional): Inputs w.r.t. which the gradient be will + accumulated into ``.grad``. All other Tensors will be ignored. If not provided, the + gradient is accumulated into all the leaf Tensors that were used to compute the + attr::tensors. + """ + if inputs is not None: + if isinstance(inputs, torch.Tensor): + inputs = (inputs,) + elif len(inputs) == 0: + raise RuntimeError("'inputs' argument to backward() cannot be empty.") + else: + inputs = tuple(inputs) + if not all(map(lambda t: t.requires_grad, inputs)): + raise RuntimeError('One of the differentiated Tensors does not require grad') + + roots = [tensors] if isinstance(tensors, torch.Tensor) else list(tensors) + autograd.backward(autograd_ctx_id, roots=roots, retain_graph=retain_graph) + + all_local_grads = autograd.get_gradients(autograd_ctx_id) + if inputs is not None: + inputs = set(inputs) # type: ignore[assignment] + all_local_grads = {p: g for p, g in all_local_grads.items() if p in inputs} + + with LOCK: + for p, g in all_local_grads.items(): + if p.grad is not None: + p.grad = p.grad.add(g) + else: + p.grad = g + + @overload + def grad( + autograd_ctx_id: int, + outputs: TensorOrTensors, + inputs: TensorOrTensors, + retain_graph: bool = False, + ) -> TupleOfTensors: + ... + + @overload + def grad( + autograd_ctx_id: int, + outputs: TensorOrTensors, + inputs: TensorOrTensors, + retain_graph: bool = False, + allow_unused: bool = False, + ) -> TupleOfOptionalTensors: + ... + + def grad( + autograd_ctx_id: int, + outputs: TensorOrTensors, + inputs: TensorOrTensors, + retain_graph: bool = False, + allow_unused: bool = False, + ) -> TupleOfOptionalTensors: + """Computes and returns the sum of gradients of outputs with respect to the inputs. + + Args: + autograd_ctx_id: The autograd context id. + outputs (sequence of Tensor): outputs of the differentiated function. + inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be returned (and not + accumulated into ``.grad``). + retain_graph (bool, optional): If :data:`False`, the graph used to compute the grad will + be freed. Note that in nearly all cases setting this option to :data:`True` is not + needed and often can be worked around in a much more efficient way. + allow_unused (bool, optional): If :data:`False`, specifying inputs that were not used + when computing outputs (and therefore their grad is always zero) is an error. + Defaults to :data:`False`. + """ + outputs = [outputs] if isinstance(outputs, torch.Tensor) else list(outputs) + inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs) + if not all(map(lambda t: t.requires_grad, inputs)): + raise RuntimeError('One of the differentiated Tensors does not require grad') + + autograd.backward(autograd_ctx_id, roots=outputs, retain_graph=retain_graph) + + all_local_grads = autograd.get_gradients(autograd_ctx_id) + grads = [] + for p in inputs: + try: + grads.append(all_local_grads[p]) + except KeyError as ex: + if not allow_unused: + raise RuntimeError( + 'One of the differentiated Tensors appears to not have been used in the ' + 'graph. Set allow_unused=True if this is the desired behavior.' + ) from ex + grads.append(None) # type: ignore[arg-type] + + return tuple(grads) + + __all__.extend(['DistAutogradContext', 'get_gradients', 'backward', 'grad']) diff --git a/torchopt/distributed/world.py b/torchopt/distributed/world.py new file mode 100644 index 00000000..4a24f3ef --- /dev/null +++ b/torchopt/distributed/world.py @@ -0,0 +1,228 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for gathering information about the world.""" + +import atexit +import functools +import os +from typing import Any, Callable, Iterable, NamedTuple, Optional, TypeVar, Union + +import torch.distributed.rpc as rpc +from torch.distributed.elastic.multiprocessing.errors import record + + +__all__ = [ + 'get_world_info', + 'get_world_rank', + 'get_rank', + 'get_world_size', + 'get_local_rank', + 'get_local_world_size', + 'get_worker_id', + 'barrier', + 'auto_init_rpc', + 'on_rank', + 'not_on_rank', + 'rank_zero_only', + 'rank_non_zero_only', +] + + +def default_worker_name_format( + world_rank: int, + world_size: int, + local_rank: int, # pylint: disable=unused-argument + local_world_size: int, # pylint: disable=unused-argument +) -> str: + """Default worker name format.""" + return f'worker{world_rank:0{len(str(world_size))}d}' + + +F = TypeVar('F', bound=Callable[..., Any]) +_WORKER_NAME_FORMAT: Callable[..., str] = default_worker_name_format + + +class WorldInfo(NamedTuple): + """Information about the world.""" + + world_rank: int + world_size: int + local_rank: int + local_world_size: int + + @property + def rank(self) -> int: + """The global world rank of the current worker.""" + return self.world_rank + + @property + def worker_name(self) -> str: + """The name of the current worker.""" + return _WORKER_NAME_FORMAT( + world_rank=self.world_rank, + world_size=self.world_size, + local_rank=self.local_rank, + local_world_size=self.local_world_size, + ) + + +def get_world_info() -> WorldInfo: + """Get the world information.""" + world_info = getattr(get_world_info, 'world_info', None) + + if world_info is None: + world_rank = int(os.getenv('RANK', '0')) + world_size = int(os.getenv('WORLD_SIZE', '1')) + local_rank = int(os.getenv('LOCAL_RANK', '0')) + local_world_size = int(os.getenv('LOCAL_WORLD_SIZE', '1')) + world_info = WorldInfo(world_rank, world_size, local_rank, local_world_size) + # pylint: disable=line-too-long + get_world_info.world_info = get_world_info.WORLD_INFO = world_info # type: ignore[attr-defined] + get_world_info.world_rank = get_world_info.WORLD_RANK = world_rank # type: ignore[attr-defined] + get_world_info.rank = get_world_info.RANK = world_rank # type: ignore[attr-defined] + get_world_info.world_size = get_world_info.WORLD_SIZE = world_size # type: ignore[attr-defined] + get_world_info.local_rank = get_world_info.LOCAL_RANK = local_rank # type: ignore[attr-defined] + get_world_info.local_world_size = get_world_info.LOCAL_WORLD_SIZE = local_world_size # type: ignore[attr-defined] + # pylint: enable=line-too-long + + return world_info + + +def get_world_rank() -> int: + """Get the global world rank of the current worker.""" + return get_world_info().world_rank + + +get_rank = get_world_rank + + +def get_world_size() -> int: + """Get the world size.""" + return get_world_info().world_size + + +def get_local_rank() -> int: + """Get the local rank of the current worker on the current node.""" + return get_world_info().local_rank + + +def get_local_world_size() -> int: + """Get the local world size on the current node.""" + return get_world_info().local_world_size + + +get_world_info() + + +# pylint: disable-next=redefined-builtin,invalid-name +def get_worker_id(id: Optional[Union[str, int]] = None) -> int: + """Get the worker id from the given id.""" + if isinstance(id, int): + return id + return rpc.get_worker_info(worker_name=id).id + + +def barrier(worker_names: Optional[Iterable[str]] = None) -> None: + r"""Synchronizes local and remote RPC processes. + + This will block until all local and remote RPC processes specified under worker_names + reach this method to wait for all outstanding work to complete. + + Args: + worker_names: The set of workers to synchronize. If :data:`None`, all workers. + """ + worker_names = {} if worker_names is None else set(worker_names) + rpc.api._barrier(worker_names) # pylint: disable=protected-access + + +def auto_init_rpc( + worker_init_fn: Optional[Callable[[], None]] = None, + worker_name_format: Callable[..., str] = default_worker_name_format, + *, + backend: Optional['rpc.BackendType'] = None, + rpc_backend_options: Optional['rpc.RpcBackendOptions'] = None, +) -> Callable[[F], F]: + """Decorator to automatically initialize RPC on the decorated function.""" + global _WORKER_NAME_FORMAT # pylint: disable=global-statement + _WORKER_NAME_FORMAT = worker_name_format + + def wrapper(func: F) -> F: + world_info = get_world_info() + + @record + @functools.wraps(func) + def wrapped(*args, **kwargs): + rpc.init_rpc( + name=world_info.worker_name, + rank=world_info.rank, + world_size=world_info.world_size, + backend=backend, + rpc_backend_options=rpc_backend_options, + ) + atexit.register(rpc.shutdown, graceful=True) + if worker_init_fn is not None: + barrier() + worker_init_fn() + barrier() + return func(*args, **kwargs) + + return wrapped # type: ignore[return-value] + + return wrapper + + +def __on_ranks(ranks: Iterable[int], inverse: bool = False) -> Callable[[F], F]: + ranks = frozenset(ranks) + + def wrapper(func: F) -> F: + world_rank = get_world_info().world_rank + + @functools.wraps(func) + def wrapped(*args, **kwargs): + if inverse: + if world_rank not in ranks: + return func(*args, **kwargs) + elif world_rank in ranks: + return func(*args, **kwargs) + return None + + return wrapped # type: ignore[return-value] + + return wrapper + + +def on_rank(*ranks: int) -> Callable[[F], F]: + """Decorator to mark a function to be executed only on given ranks.""" + return __on_ranks(ranks=ranks, inverse=False) + + +def not_on_rank(*ranks) -> Callable[[F], F]: + """Decorator to mark a function to be executed only on non given ranks.""" + return __on_ranks(ranks=ranks, inverse=True) + + +def rank_all(func: F) -> F: + """Decorator to mark a function to be executed on all ranks.""" + return func + + +def rank_zero_only(func: F) -> F: + """Decorator to mark a function to be executed only on rank zero.""" + return on_rank(0)(func) + + +def rank_non_zero_only(func: F) -> F: + """Decorator to mark a function to be executed only on non rank zero.""" + return not_on_rank(0)(func) diff --git a/torchopt/_src/hook.py b/torchopt/hook.py similarity index 60% rename from torchopt/_src/hook.py rename to torchopt/hook.py index 305c34ca..612f2177 100644 --- a/torchopt/_src/hook.py +++ b/torchopt/hook.py @@ -12,16 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Hook utilities.""" + +from typing import Callable, Optional import torch -from torchopt._src.base import EmptyState, GradientTransformation -from torchopt._src.utils import pytree +from torchopt import pytree +from torchopt.base import EmptyState, GradientTransformation + + +__all__ = ['zero_nan_hook', 'nan_to_num_hook', 'register_hook'] def zero_nan_hook(g: torch.Tensor) -> torch.Tensor: - """Registers a zero nan hook to replace nan with zero.""" - return torch.where(torch.isnan(g), torch.zeros_like(g), g) + """A zero ``nan`` hook to replace ``nan`` with zero.""" + return g.nan_to_num(nan=0.0) + + +def nan_to_num_hook( + nan: float = 0.0, posinf: Optional[float] = None, neginf: Optional[float] = None +) -> Callable[[torch.Tensor], torch.Tensor]: + """Returns a ``nan`` to num hook to replace ``nan`` / ``+inf`` / ``-inf`` with the given numbers.""" + + def hook(g: torch.Tensor) -> torch.Tensor: + """A hook to replace ``nan`` / ``+inf`` / ``-inf`` with the given numbers.""" + return g.nan_to_num(nan=nan, posinf=posinf, neginf=neginf) + + return hook def register_hook(hook) -> GradientTransformation: @@ -38,9 +56,9 @@ def init_fn(params): # pylint: disable=unused-argument def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument def f(g): - return g.register_hook(hook) if g is not None else None + return g.register_hook(hook) - pytree.tree_map(f, updates) + pytree.tree_map_(f, updates) return updates, state return GradientTransformation(init_fn, update_fn) diff --git a/torchopt/linalg/__init__.py b/torchopt/linalg/__init__.py new file mode 100644 index 00000000..20dc16aa --- /dev/null +++ b/torchopt/linalg/__init__.py @@ -0,0 +1,38 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/google/jax/blob/main/jax/_src/scipy/sparse/linalg.py +# ============================================================================== +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Linear algebra functions.""" + +from torchopt.linalg.cg import cg +from torchopt.linalg.ns import ns, ns_inv + + +__all__ = ['cg', 'ns', 'ns_inv'] diff --git a/torchopt/linalg/cg.py b/torchopt/linalg/cg.py new file mode 100644 index 00000000..94daee53 --- /dev/null +++ b/torchopt/linalg/cg.py @@ -0,0 +1,184 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/google/jax/blob/main/jax/_src/scipy/sparse/linalg.py +# ============================================================================== +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Conjugate Gradient iteration to solve ``Ax = b``.""" + +# pylint: disable=invalid-name + +from functools import partial +from typing import Callable, Optional, Union + +import torch + +from torchopt import pytree +from torchopt.linalg.utils import cat_shapes, normalize_matvec +from torchopt.pytree import tree_vdot_real +from torchopt.typing import TensorTree + + +__all__ = ['cg'] + + +def _identity(x: TensorTree) -> TensorTree: + return x + + +# pylint: disable-next=too-many-locals +def _cg_solve( + A: Callable[[TensorTree], TensorTree], + b: TensorTree, + x0: TensorTree, + *, + maxiter: int, + rtol: float = 1e-5, + atol: float = 0.0, + M: Callable[[TensorTree], TensorTree] = _identity, +) -> TensorTree: + # https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method + + # tolerance handling uses the "non-legacy" behavior of `scipy.sparse.linalg.cg` + b2 = tree_vdot_real(b, b) + atol2 = max(rtol**2 * b2, atol**2) + + def cond_fn(value): + _, r, gamma, _, k = value + rs = gamma if M is _identity else tree_vdot_real(r, r) + return rs > atol2 and k < maxiter + + def body_fn(value): + x, r, gamma, p, k = value + Ap = A(p) + alpha = gamma / tree_vdot_real(p, Ap) + x_ = pytree.tree_map(lambda a, b: a.add(b, alpha=alpha), x, p) + r_ = pytree.tree_map(lambda a, b: a.sub(b, alpha=alpha), r, Ap) + z_ = M(r_) + gamma_ = tree_vdot_real(r_, z_) + beta_ = gamma_ / gamma + p_ = pytree.tree_map(lambda a, b: a.add(b, alpha=beta_), z_, p) + return x_, r_, gamma_, p_, k + 1 + + r0 = pytree.tree_map(torch.sub, b, A(x0)) + p0 = z0 = M(r0) + gamma0 = tree_vdot_real(r0, z0) + + value = (x0, r0, gamma0, p0, 0) + while cond_fn(value): + value = body_fn(value) + + x_final, *_ = value + + return x_final + + +def _isolve( + _isolve_solve: Callable, + A: Union[TensorTree, Callable[[TensorTree], TensorTree]], + b: TensorTree, + x0: Optional[TensorTree] = None, + *, + rtol: float = 1e-5, + atol: float = 0.0, + maxiter: Optional[int] = None, + M: Optional[Union[TensorTree, Callable[[TensorTree], TensorTree]]] = None, +) -> TensorTree: + if x0 is None: + x0 = pytree.tree_map(torch.zeros_like, b) + + if maxiter is None: + size = sum(cat_shapes(b)) + maxiter = 10 * size # copied from SciPy + + if M is None: + M = _identity + A = normalize_matvec(A) + M = normalize_matvec(M) + + if cat_shapes(x0) != cat_shapes(b): + raise ValueError( + f'Tensors in x0 and b must have matching shapes: {cat_shapes(x0)} vs. {cat_shapes(b)}.' + ) + + isolve_solve = partial(_isolve_solve, x0=x0, rtol=rtol, atol=atol, maxiter=maxiter, M=M) + + x = isolve_solve(A, b) + return x + + +def cg( + A: Union[TensorTree, Callable[[TensorTree], TensorTree]], + b: TensorTree, + x0: Optional[TensorTree] = None, + *, + rtol: float = 1e-5, + atol: float = 0.0, + maxiter: Optional[int] = None, + M: Optional[Union[TensorTree, Callable[[TensorTree], TensorTree]]] = None, +) -> TensorTree: + """Use Conjugate Gradient iteration to solve ``Ax = b``. + + The numerics of TorchOpt's ``cg`` should exact match SciPy's ``cg`` (up to numerical precision), + but note that the interface is slightly different: you need to supply the linear operator ``A`` + as a function instead of a sparse matrix or ``LinearOperator``. + + Derivatives of :func:`cg` are implemented via implicit differentiation with another :func:`cg` + solve, rather than by differentiating *through* the solver. They will be accurate only if both + solves converge. + + Args: + A: (tensor or tree of tensors or function) + 2D array or function that calculates the linear map (matrix-vector product) ``Ax`` when + called like ``A(x)``. ``A`` must represent a hermitian, positive definite matrix, and + must return array(s) with the same structure and shape as its argument. + b: (tensor or tree of tensors) + Right hand side of the linear system representing a single vector. Can be stored as an + array or Python container of array(s) with any shape. + x0: (tensor or tree of tensors, optional) + Starting guess for the solution. Must have the same structure as ``b``. + rtol: (float, optional, default: :const:`1e-5`) + Tolerances for convergence, ``norm(residual) <= max(rtol*norm(b), atol)``. We do not + implement SciPy's "legacy" behavior, so TorchOpt's tolerance will differ from SciPy + unless you explicitly pass ``atol`` to SciPy's ``cg``. + atol: (float, optional, default: :const:`0.0`) + Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``. We do not + implement SciPy's "legacy" behavior, so TorchOpt's tolerance will differ from SciPy + unless you explicitly pass ``atol`` to SciPy's ``cg``. + maxiter: (integer, optional) + Maximum number of iterations. Iteration will stop after maxiter steps even if the + specified tolerance has not been achieved. + M: (tensor or tree of tensors or function) + Pre-conditioner for ``A``. The pre-conditioner should approximate the inverse of ``A``. + Effective preconditioning dramatically improves the rate of convergence, which implies + that fewer iterations are needed to reach a given error tolerance. + + Returns: + the Conjugate Gradient (CG) linear solver + """ + return _isolve(_cg_solve, A=A, b=b, x0=x0, rtol=rtol, atol=atol, maxiter=maxiter, M=M) diff --git a/torchopt/linalg/ns.py b/torchopt/linalg/ns.py new file mode 100644 index 00000000..4da8ef9f --- /dev/null +++ b/torchopt/linalg/ns.py @@ -0,0 +1,161 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Neumann Series Matrix Inversion Approximation to solve ``Ax = b``.""" + +# pylint: disable=invalid-name + +import functools +from typing import Callable, Optional, Union + +import torch + +from torchopt import pytree +from torchopt.linalg.utils import cat_shapes, normalize_matvec +from torchopt.typing import TensorTree + + +__all__ = ['ns', 'ns_inv'] + + +def _ns_solve( + A: torch.Tensor, + b: torch.Tensor, + maxiter: int, + alpha: Optional[float] = None, +) -> torch.Tensor: + """Uses Neumann Series Matrix Inversion Approximation to solve ``Ax = b``.""" + if A.ndim != 2 or A.shape[0] != A.shape[1]: + raise ValueError(f'`A` must be a square matrix, but has shape: {A.shape}') + + inv_A_hat_b = b + v = b + if alpha is not None: + # A^{-1} = a [I - (I - a A)]^{-1} = a [I + (I - a A) + (I - a A)^2 + (I - a A)^3 + ...] + for _ in range(maxiter): + v = v - alpha * (A @ v) + inv_A_hat_b = inv_A_hat_b + v + inv_A_hat_b = alpha * inv_A_hat_b + else: + # A^{-1} = [I - (I - A)]^{-1} = I + (I - A) + (I - A)^2 + (I - A)^3 + ... + for _ in range(maxiter): + v = v - A @ v + inv_A_hat_b = inv_A_hat_b + v + + return inv_A_hat_b + + +def ns( + A: Union[TensorTree, Callable[[TensorTree], TensorTree]], + b: TensorTree, + maxiter: Optional[int] = None, + *, + alpha: Optional[float] = None, +) -> TensorTree: + """Uses Neumann Series Matrix Inversion Approximation to solve ``Ax = b``. + + Args: + A: (tensor or tree of tensors or function) + 2D array or function that calculates the linear map (matrix-vector product) ``Ax`` when + called like ``A(x)``. ``A`` must represent a hermitian, positive definite matrix, and + must return array(s) with the same structure and shape as its argument. + b: (tensor or tree of tensors) + Right hand side of the linear system representing a single vector. Can be stored as an + array or Python container of array(s) with any shape. + maxiter: (integer, optional) + Maximum number of iterations. Iteration will stop after maxiter steps even if the + specified tolerance has not been achieved. + alpha: (float, optional) + Decay coefficient. + + Returns: + The Neumann Series (NS) matrix inversion approximation. + """ + if maxiter is None: + maxiter = 10 + + if not callable(A): + return pytree.tree_map(functools.partial(_ns_solve, maxiter=maxiter, alpha=alpha), A, b) + + matvec = normalize_matvec(A) + inv_A_hat_b = b + v = b + if alpha is not None: + # A^{-1} = a [I - (I - a A)]^{-1} = a [I + (I - a A) + (I - a A)^2 + (I - a A)^3 + ...] + for _ in range(maxiter): + # v = v - alpha * (A @ v) + v = pytree.tree_sub_scalar_mul(v, matvec(v), alpha=alpha) + # inv_A_hat_b = inv_A_hat_b + v + inv_A_hat_b = pytree.tree_add(inv_A_hat_b, v) + # inv_A_hat_b = alpha * inv_A_hat_b + inv_A_hat_b = pytree.tree_scalar_mul(alpha, inv_A_hat_b) + else: + # A^{-1} = [I - (I - A)]^{-1} = I + (I - A) + (I - A)^2 + (I - A)^3 + ... + for _ in range(maxiter): + # v = v - A @ v + v = pytree.tree_sub(v, matvec(v)) + # inv_A_hat_b = inv_A_hat_b + v + inv_A_hat_b = pytree.tree_add(inv_A_hat_b, v) + + return inv_A_hat_b + + +def _ns_inv(A: torch.Tensor, maxiter: int, alpha: Optional[float] = None): + """Uses Neumann Series iteration to solve ``A^{-1}``.""" + if A.ndim != 2 or A.shape[0] != A.shape[1]: + raise ValueError(f'`A` must be a square matrix, but has shape: {A.shape}') + + I = torch.eye(*A.shape, out=torch.empty_like(A)) + inv_A_hat = torch.zeros_like(A) + if alpha is not None: + # A^{-1} = a [I - (I - a A)]^{-1} = a [I + (I - a A) + (I - a A)^2 + (I - a A)^3 + ...] + M = I - alpha * A + for rank in range(maxiter): + inv_A_hat = inv_A_hat + torch.linalg.matrix_power(M, rank) + inv_A_hat = alpha * inv_A_hat + else: + # A^{-1} = [I - (I - A)]^{-1} = I + (I - A) + (I - A)^2 + (I - A)^3 + ... + M = I - A + for rank in range(maxiter): + inv_A_hat = inv_A_hat + torch.linalg.matrix_power(M, rank) + return inv_A_hat + + +def ns_inv( + A: TensorTree, + maxiter: Optional[int] = None, + *, + alpha: Optional[float] = None, +) -> TensorTree: + """Uses Neumann Series iteration to solve ``A^{-1}``. + + Args: + A: (tensor or tree of tensors or function) + 2D array or function that calculates the linear map (matrix-vector product) ``Ax`` when + called like ``A(x)``. ``A`` must represent a hermitian, positive definite matrix, and + must return array(s) with the same structure and shape as its argument. + maxiter: (integer, optional) + Maximum number of iterations. Iteration will stop after maxiter steps even if the + specified tolerance has not been achieved. + alpha: (float, optional) + Decay coefficient. + + Returns: + The Neumann Series (NS) matrix inversion approximation. + """ + if maxiter is None: + size = sum(cat_shapes(A)) + maxiter = 10 * size # copied from SciPy + + return pytree.tree_map(functools.partial(_ns_inv, maxiter=maxiter, alpha=alpha), A) diff --git a/torchopt/linalg/utils.py b/torchopt/linalg/utils.py new file mode 100644 index 00000000..f2440b9a --- /dev/null +++ b/torchopt/linalg/utils.py @@ -0,0 +1,55 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for linear algebra.""" + +import itertools +from typing import Callable, Tuple, Union + +import torch + +from torchopt import pytree +from torchopt.typing import TensorTree + + +def cat_shapes(tree: TensorTree) -> Tuple[int, ...]: + """Concatenates the shapes of the leaves of a tree of tensors.""" + leaves = pytree.tree_leaves(tree) + return tuple(itertools.chain.from_iterable(tuple(leaf.shape) for leaf in leaves)) + + +def normalize_matvec( + matvec: Union[TensorTree, Callable[[TensorTree], TensorTree]] +) -> Callable[[TensorTree], TensorTree]: + """Normalizes an argument for computing matrix-vector product.""" + if callable(matvec): + return matvec + + mat_flat, treespec = pytree.tree_flatten(matvec) + for mat in mat_flat: + if not isinstance(mat, torch.Tensor) or mat.ndim != 2 or mat.shape[0] != mat.shape[1]: + raise TypeError(f'Linear operator must be a square matrix, but has shape: {mat.shape}') + + def _matvec(x: TensorTree) -> TensorTree: + x_flat = pytree.tree_leaves(x) + if len(x_flat) != len(mat_flat): + raise ValueError( + f'`x` must have the same number of leaves as `matvec`, ' + f'but has {len(x_flat)} leaves and `matvec` has {len(mat_flat)} leaves' + ) + + y_flat = map(torch.matmul, mat_flat, x_flat) + return pytree.tree_unflatten(treespec, y_flat) + + return _matvec diff --git a/torchopt/linear_solve/__init__.py b/torchopt/linear_solve/__init__.py new file mode 100644 index 00000000..8d9115d3 --- /dev/null +++ b/torchopt/linear_solve/__init__.py @@ -0,0 +1,39 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/google/jaxopt/blob/main/jaxopt/_src/linear_solve.py +# ============================================================================== +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Linear algebra solvers.""" + +from torchopt.linear_solve.cg import solve_cg +from torchopt.linear_solve.inv import solve_inv +from torchopt.linear_solve.normal_cg import solve_normal_cg + + +__all__ = ['solve_cg', 'solve_normal_cg', 'solve_inv'] diff --git a/torchopt/linear_solve/cg.py b/torchopt/linear_solve/cg.py new file mode 100644 index 00000000..2ffc8217 --- /dev/null +++ b/torchopt/linear_solve/cg.py @@ -0,0 +1,107 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/google/jaxopt/blob/main/jaxopt/_src/linear_solve.py +# ============================================================================== +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Linear algebra solver for ``A x = b`` using conjugate gradient.""" + +# pylint: disable=invalid-name + +import functools +from typing import Callable, Optional + +from torchopt import linalg +from torchopt.linear_solve.utils import make_ridge_matvec +from torchopt.typing import TensorTree + + +__all__ = ['solve_cg'] + + +def _solve_cg( + matvec: Callable[[TensorTree], TensorTree], # (x) -> A @ x + b: TensorTree, + ridge: Optional[float] = None, + init: Optional[TensorTree] = None, + **kwargs, +) -> TensorTree: + """Solves ``A x = b`` using conjugate gradient. + + This assumes that ``A`` is a hermitian, positive definite matrix. + + Args: + matvec: A function that returns the product between ``A`` and a vector. + b: A tree of tensors for the right hand side of the equation. + ridge: Optional ridge regularization. + init: Optional initialization to be used by conjugate gradient. + **kwargs: Additional keyword arguments for the conjugate gradient solver. + + Returns: + The solution with the same structure as ``b``. + """ + if ridge is not None: + # (x) -> A @ x + ridge * x + # i.e. (x) -> (A + ridge * I) @ x + matvec = make_ridge_matvec(matvec, ridge=ridge) + + # Returns solution for `(A + ridge * I) @ x = b`. + return linalg.cg(matvec, b, x0=init, **kwargs) + + +def solve_cg(**kwargs): + """A wrapper that returns a solver function to solve ``A x = b`` using conjugate gradient. + + This assumes that ``A`` is a hermitian, positive definite matrix. + + Args: + ridge: Optional ridge regularization. Solves the equation for ``(A + ridge * I) @ x = b``. + init: Optional initialization to be used by conjugate gradient. + **kwargs: Additional keyword arguments for the conjugate gradient solver + :func:`torchopt.linalg.cg`. + + Returns: + A solver function with signature ``(matvec, b) -> x`` that solves ``A x = b`` using + conjugate gradient where ``matvec(v) = A v``. + + See Also: + Conjugate gradient iteration :func:`torchopt.linalg.cg`. + + Example:: + + >>> A = {'a': torch.eye(5, 5), 'b': torch.eye(3, 3)} + >>> x = {'a': torch.randn(5), 'b': torch.randn(3)} + >>> def matvec(x: TensorTree) -> TensorTree: + ... return {'a': A['a'] @ x['a'], 'b': A['b'] @ x['b']} + >>> b = matvec(x) + >>> solver = solve_cg(init={'a': torch.zeros(5), 'b': torch.zeros(3)}) + >>> x_hat = solver(matvec, b) + >>> assert torch.allclose(x_hat['a'], x['a']) and torch.allclose(x_hat['b'], x['b']) + + """ + return functools.partial(_solve_cg, **kwargs) diff --git a/torchopt/linear_solve/inv.py b/torchopt/linear_solve/inv.py new file mode 100644 index 00000000..bf36f40e --- /dev/null +++ b/torchopt/linear_solve/inv.py @@ -0,0 +1,122 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/google/jaxopt/blob/main/jaxopt/_src/linear_solve.py +# ============================================================================== +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Linear algebra solver for ``A x = b`` using matrix inversion.""" + +# pylint: disable=invalid-name + +import functools +from typing import Callable, Optional + +import torch + +from torchopt import linalg, pytree +from torchopt.linear_solve.utils import make_ridge_matvec, materialize_matvec +from torchopt.typing import TensorTree + + +__all__ = ['solve_inv'] + + +def _solve_inv( + matvec: Callable[[TensorTree], TensorTree], # (x) -> A @ x + b: TensorTree, + ridge: Optional[float] = None, + ns: bool = False, + **kwargs, +) -> TensorTree: + """Solves ``A x = b`` using matrix inversion. + + If ``ns = False``, this assumes the matrix ``A`` is a constant matrix and will materialize it + in memory. + + Args: + matvec: A function that returns the product between ``A`` and a vector. + b: A tensor for the right hand side of the equation. + ridge: Optional ridge regularization. Solves the equation for ``(A + ridge * I) @ x = b``. + ns: Whether to use Neumann Series matrix inversion approximation. If :data:`False`, + materialize the matrix ``A`` in memory and use :func:`torch.linalg.solve` instead. + **kwargs: Additional keyword arguments for the Neumann Series matrix inversion approximation + solver :func:`torchopt.linalg.ns`. + + Returns: + The solution with the same shape as ``b``. + """ + if ridge is not None: + # (x) -> A @ x + ridge * x + # i.e. (x) -> (A + ridge * I) @ x + matvec = make_ridge_matvec(matvec, ridge=ridge) + + b_flat = pytree.tree_leaves(b) + if len(b_flat) == 1 and b_flat[0].ndim == 0: + A, *_ = materialize_matvec(matvec, b) + return pytree.tree_truediv(b, A) + + if ns: + return linalg.ns(matvec, b, **kwargs) + + A, _, tree_ravel, tree_unravel = materialize_matvec(matvec, b) + return tree_unravel(pytree.tree_map(torch.linalg.solve, A, tree_ravel(b))) + + +def solve_inv(**kwargs): + """A wrapper that returns a solver function to solve ``A x = b`` using matrix inversion. + + If ``ns = False``, this assumes the matrix ``A`` is a constant matrix and will materialize it + in memory. + + Args: + ridge: Optional ridge regularization. Solves the equation for ``(A + ridge * I) @ x = b``. + ns: Whether to use Neumann Series matrix inversion approximation. If :data:`False`, + materialize the matrix ``A`` in memory and use :func:`torch.linalg.solve` instead. + **kwargs: Additional keyword arguments for the Neumann Series matrix inversion approximation + solver :func:`torchopt.linalg.ns`. + + Returns: + A solver function with signature ``(matvec, b) -> x`` that solves ``A x = b`` using matrix + inversion where ``matvec(v) = A v``. + + See Also: + Neumann Series matrix inversion approximation :func:`torchopt.linalg.ns`. + + Example:: + + >>> A = {'a': torch.eye(5, 5), 'b': torch.eye(3, 3)} + >>> x = {'a': torch.randn(5), 'b': torch.randn(3)} + >>> def matvec(x: TensorTree) -> TensorTree: + ... return {'a': A['a'] @ x['a'], 'b': A['b'] @ x['b']} + >>> b = matvec(x) + >>> solver = solve_inv(ns=True, maxiter=10) + >>> x_hat = solver(matvec, b) + >>> assert torch.allclose(x_hat['a'], x['a']) and torch.allclose(x_hat['b'], x['b']) + + """ + return functools.partial(_solve_inv, **kwargs) diff --git a/torchopt/linear_solve/normal_cg.py b/torchopt/linear_solve/normal_cg.py new file mode 100644 index 00000000..3646d7f4 --- /dev/null +++ b/torchopt/linear_solve/normal_cg.py @@ -0,0 +1,120 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/google/jaxopt/blob/main/jaxopt/_src/linear_solve.py +# ============================================================================== +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Linear algebra solver for ``A^T A x = A^T b`` using conjugate gradient.""" + +# pylint: disable=invalid-name + +import functools +from typing import Callable, Optional + +from torchopt import linalg +from torchopt.linear_solve.utils import make_normal_matvec, make_ridge_matvec, make_rmatvec +from torchopt.typing import TensorTree + + +__all__ = ['solve_normal_cg'] + + +def _solve_normal_cg( + matvec: Callable[[TensorTree], TensorTree], # (x) -> A @ x + b: TensorTree, + ridge: Optional[float] = None, + init: Optional[TensorTree] = None, + **kwargs, +) -> TensorTree: + """Solves the normal equation ``A^T A x = A^T b`` using conjugate gradient. + + This can be used to solve ``A x = b`` using conjugate gradient when ``A`` is not hermitian, + positive definite. + + Args: + matvec: A function that returns the product between ``A`` and a vector. + b: A tree of tensors for the right hand side of the equation. + ridge: Optional ridge regularization. Solves the equation for ``(A.T @ A + ridge * I) @ x = A.T @ b``. + init: Optional initialization to be used by normal conjugate gradient. + **kwargs: Additional keyword arguments for the conjugate gradient solver + :func:`torchopt.linalg.cg`. + + Returns: + The solution with the same structure as ``b``. + """ + if init is None: + example_x = b # This assumes that matvec is a square linear operator. + else: + example_x = init + + rmatvec = make_rmatvec(matvec, example_x) # (x) -> A.T @ x + normal_matvec = make_normal_matvec(matvec) # (x) -> A.T @ A @ x + + if ridge is not None: + # (x) -> A.T @ A @ x + ridge * x + # i.e. (x) -> (A.T @ A + ridge * I) @ x + normal_matvec = make_ridge_matvec(normal_matvec, ridge=ridge) + + rhs = rmatvec(b) # A.T @ b + + # Returns solution for `(A.T @ A + ridge * I) @ x = A.T @ b`. + return linalg.cg(normal_matvec, rhs, x0=init, **kwargs) + + +def solve_normal_cg(**kwargs): + """A wrapper that returns a solver function to solve ``A^T A x = A^T b`` using conjugate gradient. + + This can be used to solve ``A x = b`` using conjugate gradient when ``A`` is not hermitian, + positive definite. + + Args: + ridge: Optional ridge regularization. Solves the equation for ``(A.T @ A + ridge * I) @ x = A.T @ b``. + init: Optional initialization to be used by normal conjugate gradient. + **kwargs: Additional keyword arguments for the conjugate gradient solver + :func:`torchopt.linalg.cg`. + + Returns: + A solver function with signature ``(matvec, b) -> x`` that solves ``A^T A x = A^T b`` using + conjugate gradient where ``matvec(v) = A v``. + + See Also: + Conjugate gradient iteration :func:`torchopt.linalg.cg`. + + Example:: + + >>> A = {'a': torch.randn(5, 5), 'b': torch.randn(3, 3)} + >>> x = {'a': torch.randn(5), 'b': torch.randn(3)} + >>> def matvec(x: TensorTree) -> TensorTree: + ... return {'a': A['a'] @ x['a'], 'b': A['b'] @ x['b']} + >>> b = matvec(x) + >>> solver = solve_normal_cg(init={'a': torch.zeros(5), 'b': torch.zeros(3)}) + >>> x_hat = solver(matvec, b) + >>> assert torch.allclose(x_hat['a'], x['a']) and torch.allclose(x_hat['b'], x['b']) + + """ + return functools.partial(_solve_normal_cg, **kwargs) diff --git a/torchopt/linear_solve/utils.py b/torchopt/linear_solve/utils.py new file mode 100644 index 00000000..a7e93e65 --- /dev/null +++ b/torchopt/linear_solve/utils.py @@ -0,0 +1,114 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/google/jaxopt/blob/main/jaxopt/_src/linear_solve.py +# ============================================================================== +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for linear algebra solvers.""" + +from typing import Callable, Tuple + +import functorch + +from torchopt import pytree +from torchopt.typing import TensorTree + + +def make_rmatvec( + matvec: Callable[[TensorTree], TensorTree], example_x: TensorTree +) -> Callable[[TensorTree], TensorTree]: + """Returns a function that computes ``rmatvec(y) = A.T @ y`` from ``matvec(x) = A @ x``.""" + _, vjp, *_ = functorch.vjp(matvec, example_x) + + return lambda y: vjp(y)[0] + + +def make_normal_matvec( + matvec: Callable[[TensorTree], TensorTree] +) -> Callable[[TensorTree], TensorTree]: + """Returns a function that computes ``normal_matvec(y) = A.T @ A @ y`` from ``matvec(x) = A @ x``.""" + + def normal_matvec(y: TensorTree) -> TensorTree: + """Computes ``A.T @ A @ y`` from ``matvec(x) = A @ x``.""" + matvec_y, vjp, *_ = functorch.vjp(matvec, y) + return vjp(matvec_y)[0] + + return normal_matvec + + +def make_ridge_matvec( + matvec: Callable[[TensorTree], TensorTree], ridge: float = 0.0 +) -> Callable[[TensorTree], TensorTree]: + """Returns a function that computes ``ridge_matvec(y) = A.T @ A @ y + ridge * y`` from ``matvec(x) = A @ x``.""" + + def ridge_matvec(y: TensorTree) -> TensorTree: + """Computes ``A.T @ A @ v + ridge * v`` from ``matvec(x) = A @ x``.""" + return pytree.tree_add_scalar_mul(matvec(y), y, alpha=ridge) + + return ridge_matvec + + +def materialize_matvec( + matvec: Callable[[TensorTree], TensorTree], x: TensorTree +) -> Tuple[ + TensorTree, + Callable[[TensorTree], TensorTree], + Callable[[TensorTree], TensorTree], + Callable[[TensorTree], TensorTree], +]: + """Materializes the matrix ``A`` used in ``matvec(x) = A @ x``.""" + x_flat, treespec = pytree.tree_flatten(x) + shapes = tuple(t.shape for t in x_flat) + + if all(t.ndim == 1 for t in x_flat): + + def tree_ravel(x: TensorTree) -> TensorTree: + return x + + def tree_unravel(y: TensorTree) -> TensorTree: + return y + + matvec_ravel = matvec + + else: + + def tree_ravel(x: TensorTree) -> TensorTree: + return pytree.tree_map(lambda t: t.contiguous().view(-1), x) + + def tree_unravel(y: TensorTree) -> TensorTree: + shapes_iter = iter(shapes) + return pytree.tree_map(lambda t: t.contiguous().view(next(shapes_iter)), y) + + def matvec_ravel(y: TensorTree) -> TensorTree: + return tree_ravel(matvec(tree_unravel(y))) + + nargs = len(x_flat) + jacobian_tree = functorch.jacfwd(matvec_ravel)(tree_ravel(x)) + jacobian_flat = pytree.tree_leaves(jacobian_tree) + jacobian_diag = [jacobian_flat[i + i * nargs] for i in range(nargs)] + return pytree.tree_unflatten(treespec, jacobian_diag), matvec_ravel, tree_ravel, tree_unravel diff --git a/torchopt/nn/__init__.py b/torchopt/nn/__init__.py new file mode 100644 index 00000000..57a8e802 --- /dev/null +++ b/torchopt/nn/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for neural network modules that hold meta-parameters and meta-modules.""" + +from torchopt.diff.implicit.nn.module import ImplicitMetaGradientModule # circular reference +from torchopt.nn.module import MetaGradientModule + + +__all__ = ['MetaGradientModule', 'ImplicitMetaGradientModule'] diff --git a/torchopt/nn/module.py b/torchopt/nn/module.py new file mode 100644 index 00000000..4a1364f1 --- /dev/null +++ b/torchopt/nn/module.py @@ -0,0 +1,456 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for neural network modules that hold meta-parameters and meta-modules.""" + +from collections import OrderedDict +from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Set, Tuple, Union + +import torch +import torch.nn as nn + +from torchopt import pytree + + +class MetaInputsContainer(NamedTuple): + """Container for parameters and modules in the constructor input arguments.""" + + meta_parameters: Set[torch.Tensor] + meta_modules: Set[nn.Module] + + +class MetaGradientModule(nn.Module): # pylint: disable=abstract-method + """Base class for neural network modules that hold meta-parameters and meta-modules.""" + + _meta_inputs: MetaInputsContainer + _meta_parameters: Dict[str, Optional[torch.Tensor]] + _meta_modules: Dict[str, Optional[nn.Module]] + + def __new__(cls, *args, **kwargs) -> 'MetaGradientModule': + """Creates a new module instance.""" + instance = super().__new__(cls) + flat_args: List[Any] + flat_args = pytree.tree_leaves((args, kwargs)) # type: ignore[arg-type] + meta_parameters = {x for x in flat_args if isinstance(x, torch.Tensor) and x.requires_grad} + meta_modules = {x for x in flat_args if isinstance(x, nn.Module) and x.training} + for meta_module in tuple(meta_modules): + meta_parameters.update(meta_module.parameters()) + meta_modules.update(meta_module.modules()) + + instance._meta_inputs = MetaInputsContainer(meta_parameters, meta_modules) + instance._meta_parameters: Dict[str, Optional[torch.Tensor]] = OrderedDict() # type: ignore[misc] + instance._meta_modules: Dict[str, Optional[nn.Module]] = OrderedDict() # type: ignore[misc] + return instance + + def __getattr__(self, name: str) -> Union[torch.Tensor, nn.Module]: + """Gets an attribute of the module.""" + if '_parameters' in self.__dict__: + _parameters = self.__dict__['_parameters'] + if name in _parameters: + return _parameters[name] + if '_buffers' in self.__dict__: + _buffers = self.__dict__['_buffers'] + if name in _buffers: + return _buffers[name] + if '_modules' in self.__dict__: + modules = self.__dict__['_modules'] + if name in modules: + return modules[name] + if '_meta_parameters' in self.__dict__: + _meta_parameters = self.__dict__['_meta_parameters'] + if name in _meta_parameters: + return _meta_parameters[name] + if '_meta_modules' in self.__dict__: + _meta_modules = self.__dict__['_meta_modules'] + if name in _meta_modules: + return _meta_modules[name] + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + # pylint: disable-next=too-many-branches,too-many-statements + def __setattr__(self, name: str, value: Union[torch.Tensor, nn.Module]) -> None: + """Sets an attribute of the module.""" + + def remove_from(*dicts_or_sets): + for dict_or_set in dicts_or_sets: + if name in dict_or_set: + if isinstance(dict_or_set, dict): + del dict_or_set[name] + else: + dict_or_set.discard(name) + + params = self.__dict__.get('_parameters') + meta_params = self.__dict__.get('_meta_parameters') + if isinstance(value, torch.Tensor) and value.requires_grad: + if params is None: + raise AttributeError('cannot assign parameters before Module.__init__() call') + if meta_params is None: + raise AttributeError( + 'cannot assign meta-parameters before MetaGradientModule.__init__() call' + ) + remove_from( + self.__dict__, + self._buffers, + self._modules, + self._non_persistent_buffers_set, + self._meta_parameters, + self._meta_modules, + ) + if value in self._meta_inputs.meta_parameters: + self.register_meta_parameter(name, value) + else: + self.register_parameter(name, value) + elif params is not None and name in params: + if value is not None: + raise TypeError( + f"cannot assign '{torch.typename(value)}' as parameter '{name}' " + f'(torch.Tensor or None expected)' + ) + self.register_parameter(name, value) # type: ignore[unreachable] + elif meta_params is not None and name in meta_params: + if value is not None: + raise TypeError( + f"cannot assign '{torch.typename(value)}' as meta-parameter '{name}' " + f'(torch.Tensor or None expected)' + ) + self.register_meta_parameter(name, value) # type: ignore[unreachable] + else: + modules = self.__dict__.get('_modules') + meta_modules = self.__dict__.get('_meta_modules') + if isinstance(value, nn.Module): + if modules is None: + raise AttributeError('cannot assign module before Module.__init__() call') + if meta_modules is None: + raise AttributeError( + 'cannot assign module before MetaGradientModule.__init__() call' + ) + remove_from( + self.__dict__, + self._parameters, + self._buffers, + self._non_persistent_buffers_set, + self._meta_parameters, + self._meta_modules, + ) + if value in self._meta_inputs.meta_modules: + meta_modules[name] = value + else: + modules[name] = value + elif modules is not None and name in modules: + if value is not None: + raise TypeError( + f"cannot assign '{torch.typename(value)}' as child module '{name}' " + f'(torch.nn.Module or None expected)' + ) + modules[name] = value # type: ignore[unreachable] + else: + buffers = self.__dict__.get('_buffers') + if buffers is not None and name in buffers: + if value is not None and not isinstance(value, torch.Tensor): + raise TypeError( + f"cannot assign '{torch.typename(value)}' as buffer '{name}' " + f'(torch.Tensor or None expected)' + ) + buffers[name] = value + else: + object.__setattr__(self, name, value) + + def __delattr__(self, name: str) -> None: + """Deletes an attribute of the module.""" + if name in self._parameters: + del self._parameters[name] + elif name in self._buffers: + del self._buffers[name] + self._non_persistent_buffers_set.discard(name) + elif name in self._modules: + del self._modules[name] + elif name in self._meta_parameters: + del self._meta_parameters[name] + elif name in self._meta_modules: + del self._meta_modules[name] + else: + object.__delattr__(self, name) + + def register_parameter(self, name: str, param: Optional[torch.Tensor]) -> None: + r"""Adds a parameter to the module. + + The parameter can be accessed as an attribute using given name. + + Args: + name (string): name of the parameter. The parameter can be accessed + from this module using the given name + param (torch.Tensor or None): parameter to be added to the module. If + ``None``, then operations that run on parameters, such as :attr:`cuda`, + are ignored. If ``None``, the parameter is **not** included in the + module's :attr:`state_dict`. + """ + if '_parameters' not in self.__dict__: + raise AttributeError('cannot assign parameter before Module.__init__() call') + if not isinstance(name, str): + raise TypeError(f'parameter name should be a string. Got {torch.typename(name)}') + if '.' in name: + raise KeyError("parameter name can't contain \".\"") + if name == '': + raise KeyError("parameter name can't be empty string \"\"") + if hasattr(self, name) and name not in self._parameters: + raise KeyError(f"attribute '{name}' already exists") + + if param is None: + self._parameters[name] = None + return + + if not isinstance(param, torch.Tensor): + raise TypeError( + f"cannot assign '{torch.typename(param)}' object to parameter '{name}' " + f'(torch.Tensor or None required)' + ) + if not param.requires_grad: + raise ValueError( + f"cannot assign Tensor that `requires_grad=False` to parameter '{name}'" + ) + if param in self._meta_inputs.meta_parameters: + raise ValueError( + f"cannot assign Tensor that is a meta-parameter to parameter '{name}'. " + f'Use self.register_meta_parameter() instead.' + ) + + self._parameters[name] = param # type: ignore + + def register_meta_parameter(self, name: str, param: Optional[torch.Tensor]) -> None: + r"""Adds a meta-parameter to the module. + + The meta-parameter can be accessed as an attribute using given name. + + Args: + name (string): name of the parameter. The parameter can be accessed + from this module using the given name + param (torch.Tensor or None): parameter to be added to the module. If + ``None``, then operations that run on parameters, such as :attr:`cuda`, + are ignored. If ``None``, the parameter is **not** included in the + module's :attr:`state_dict`. + """ + if '_meta_parameters' not in self.__dict__: + raise AttributeError( + 'cannot assign meta-parameter before MetaGradientModule.__init__() call' + ) + if not isinstance(name, str): + raise TypeError(f'meta-parameter name should be a string. Got {torch.typename(name)}') + if '.' in name: + raise KeyError("meta-parameter name can't contain \".\"") + if name == '': + raise KeyError("meta-parameter name can't be empty string \"\"") + if hasattr(self, name) and name not in self._meta_parameters: + raise KeyError(f"attribute '{name}' already exists") + + if param is None: + self._meta_parameters[name] = None + return + + if not isinstance(param, torch.Tensor): + raise TypeError( + f"cannot assign '{torch.typename(param)}' object to meta-parameter '{name}' " + f'(torch.Tensor or None required)' + ) + if not param.requires_grad: + raise ValueError( + f"cannot assign Tensor that `requires_grad=False` to meta-parameter '{name}'" + ) + + self._meta_parameters[name] = param + + def add_module(self, name: str, module: Optional[nn.Module]) -> None: + r"""Adds a child module to the current module. + + The module can be accessed as an attribute using the given name. + + Args: + name (string): name of the child module. The child module can be + accessed from this module using the given name + module (Module): child module to be added to the module. + """ + if not isinstance(module, nn.Module) and module is not None: + raise TypeError(f'{torch.typename(module)} is not a Module subclass') + if not isinstance(name, str): + raise TypeError(f'module name should be a string. Got {torch.typename(name)}') + if hasattr(self, name) and name not in self._modules: + raise KeyError(f"attribute '{name}' already exists") + if '.' in name: + raise KeyError(f"module name can't contain \".\", got: {name}") + if name == '': + raise KeyError("module name can't be empty string \"\"") + if module in self._meta_inputs.meta_modules: + raise ValueError( + f"cannot add module that is a meta-module to module '{name}'. " + f'Use self.add_meta_module() instead.' + ) + + self._modules[name] = module + + def register_module(self, name: str, module: Optional[nn.Module]) -> None: + r"""Alias for :func:`add_module`.""" + self.add_module(name, module) + + def add_meta_module(self, name: str, meta_module: Optional[nn.Module]) -> None: + r"""Adds a child meta-module to the current module. + + The meta-module can be accessed as an attribute using the given name. + + Args: + name (string): name of the child meta-module. The child meta-module can be + accessed from this module using the given name + meta_module (Module): child meta-module to be added to the module. + """ + if not isinstance(meta_module, nn.Module) and meta_module is not None: + raise TypeError(f'{torch.typename(meta_module)} is not a Module subclass') + if not isinstance(name, str): + raise TypeError(f'meta-module name should be a string. Got {torch.typename(name)}') + if hasattr(self, name) and name not in self._meta_modules: + raise KeyError(f"attribute '{name}' already exists") + if '.' in name: + raise KeyError(f"meta-module name can't contain \".\", got: {name}") + if name == '': + raise KeyError("meta-module name can't be empty string \"\"") + + self._meta_modules[name] = meta_module + + def register_meta_module(self, name: str, meta_module: Optional[nn.Module]) -> None: + r"""Alias for :func:`add_meta_module`.""" + self.add_meta_module(name, meta_module) + + def meta_parameters(self, recurse: bool = True) -> Iterator[torch.Tensor]: + r"""Returns an iterator over module meta-parameters. + + This is typically passed to an optimizer. + + Args: + recurse (bool): if True, then yields parameters of this module and + all submodules. Otherwise, yields only meta-parameters that + are direct members of this module. + + Yields: + Parameter: module meta-parameter + + Example:: + + >>> for param in model.meta_parameters(): + >>> print(type(param), param.size()) + (20L,) + (20L, 1L, 5L, 5L) + + """ + for _, meta_param in self.named_meta_parameters(recurse=recurse): + yield meta_param + + def named_meta_parameters( + self, prefix: str = '', recurse: bool = True + ) -> Iterator[Tuple[str, torch.Tensor]]: + r"""Returns an iterator over module meta-parameters, yielding both the name of the meta-parameter as well as the meta-parameter itself. + + Args: + prefix (str): prefix to prepend to all meta-parameter names. + recurse (bool): if True, then yields meta-parameters of this module + and all submodules. Otherwise, yields only meta-parameters that + are direct members of this module. + + Yields: + (string, Parameter): Tuple containing the name and parameter + + Example:: + + >>> for name, meta_param in self.named_meta_parameters(): + >>> if name in ['bias']: + >>> print(meta_param.size()) + + """ # pylint: disable=line-too-long + memo = set() + for name, param in getattr(self, '_meta_parameters', {}).items(): + if param is None or param in memo: + continue + memo.add(param) + yield prefix + name, param + for name, meta_module in getattr(self, '_meta_modules', {}).items(): + if meta_module is None: + continue + submodule_prefix = prefix + name + yield from meta_module.named_parameters(submodule_prefix, recurse) + + def meta_children(self) -> Iterator[nn.Module]: + r"""Returns an iterator over immediate children meta-modules. + + Yields: + Module: a child meta-module + """ + for _, module in self.named_meta_children(): + yield module + + def named_meta_children(self) -> Iterator[Tuple[str, nn.Module]]: + r"""Returns an iterator over immediate children meta-modules, yielding both the name of the meta-module as well as the meta-module itself. + + Yields: + (string, Module): Tuple containing a name and child meta-module + + Example:: + + >>> for name, meta_module in model.named_meta_children(): + >>> if name in ['conv4', 'conv5']: + >>> print(meta_module) + + """ # pylint: disable=line-too-long + memo = set() + for name, meta_module in self._meta_modules.items(): + if meta_module is not None and meta_module not in memo: + memo.add(meta_module) + yield name, meta_module + + def meta_modules(self) -> Iterator[nn.Module]: + r"""Returns an iterator over all meta-modules in the network. + + Yields: + Module: a meta-module in the network + + Note: + Duplicate meta-modules are returned only once. + """ + for _, meta_module in self.named_meta_modules(): + yield meta_module + + def named_meta_modules( + self, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True + ) -> Iterator[Tuple[str, nn.Module]]: + r"""Returns an iterator over all meta-modules in the network, yielding both the name of the meta-module as well as the meta-module itself. + + Args: + memo: a memo to store the set of meta-modules already added to the result + prefix: a prefix that will be added to the name of the meta-module + remove_duplicate: whether to remove the duplicated meta-module instances in the result + or not + + Yields: + (string, Module): Tuple of name and meta-module + + Note: + Duplicate modules are returned only once. + """ # pylint: disable=line-too-long + if memo is None: + memo = set() + if self in memo: + return + + if remove_duplicate: + memo.add(self) + + for name, meta_module in self._meta_modules.items(): + if meta_module is None: + continue + submodule_prefix = prefix + ('.' if prefix else '') + name + yield from meta_module.named_modules(memo, submodule_prefix, remove_duplicate) diff --git a/torchopt/optim/__init__.py b/torchopt/optim/__init__.py new file mode 100644 index 00000000..b75da23c --- /dev/null +++ b/torchopt/optim/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""object oriented optimizer implementations.""" + +from torchopt.optim import meta +from torchopt.optim.adam import Adam +from torchopt.optim.adamw import AdamW +from torchopt.optim.base import Optimizer +from torchopt.optim.func import FuncOptimizer +from torchopt.optim.rmsprop import RMSProp, RMSprop +from torchopt.optim.sgd import SGD diff --git a/torchopt/_src/optimizer/adam.py b/torchopt/optim/adam.py similarity index 93% rename from torchopt/_src/optimizer/adam.py rename to torchopt/optim/adam.py index 6776408e..8fcdff90 100644 --- a/torchopt/_src/optimizer/adam.py +++ b/torchopt/optim/adam.py @@ -12,14 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Adam optimizer.""" from typing import Iterable, Tuple import torch -from torchopt._src.alias import adam -from torchopt._src.optimizer.base import Optimizer -from torchopt._src.typing import ScalarOrSchedule +from torchopt import alias +from torchopt.optim.base import Optimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['Adam'] class Adam(Optimizer): @@ -42,7 +46,7 @@ def __init__( eps_root: float = 0.0, maximize: bool = False, use_accelerated_op: bool = False, - ): + ) -> None: r"""The :meth:`init` function. Args: @@ -68,7 +72,7 @@ def __init__( """ super().__init__( params, - adam( + alias.adam( lr=lr, betas=betas, eps=eps, diff --git a/torchopt/_src/optimizer/adamw.py b/torchopt/optim/adamw.py similarity index 92% rename from torchopt/_src/optimizer/adamw.py rename to torchopt/optim/adamw.py index 886cd77a..24362d59 100644 --- a/torchopt/_src/optimizer/adamw.py +++ b/torchopt/optim/adamw.py @@ -12,15 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""AdamW optimizer.""" from typing import Any, Callable, Iterable, Optional, Tuple, Union import torch -from torchopt._src import base # pylint: disable=unused-import -from torchopt._src.alias import adamw -from torchopt._src.optimizer.base import Optimizer -from torchopt._src.typing import ScalarOrSchedule +from torchopt import alias +from torchopt.optim.base import Optimizer +from torchopt.typing import Params, ScalarOrSchedule + + +__all__ = ['AdamW'] class AdamW(Optimizer): @@ -41,10 +44,10 @@ def __init__( weight_decay: float = 1e-2, *, eps_root: float = 0.0, - mask: Optional[Union[Any, Callable[['base.Params'], Any]]] = None, + mask: Optional[Union[Any, Callable[[Params], Any]]] = None, maximize: bool = False, use_accelerated_op: bool = False, - ): + ) -> None: r"""The :meth:`init` function. Args: @@ -79,7 +82,7 @@ def __init__( """ super().__init__( params, - adamw( + alias.adamw( lr=lr, betas=betas, eps=eps, diff --git a/torchopt/_src/optimizer/base.py b/torchopt/optim/base.py similarity index 55% rename from torchopt/_src/optimizer/base.py rename to torchopt/optim/base.py index 99e18b36..dc933f30 100644 --- a/torchopt/_src/optimizer/base.py +++ b/torchopt/optim/base.py @@ -12,20 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""The base class for optimizers.""" -from typing import Iterable +from typing import Callable, Iterable, List, Optional, Sequence, Tuple import torch -from torchopt._src.base import GradientTransformation -from torchopt._src.update import apply_updates -from torchopt._src.utils import pytree +from torchopt import pytree +from torchopt.base import UninitializedState +from torchopt.typing import GradientTransformation, OptState, Params, TupleOfTensors +from torchopt.update import apply_updates + + +__all__ = ['Optimizer'] class Optimizer: """A base class for classic optimizers that similar to :class:`torch.optim.Optimizer`.""" - def __init__(self, params: Iterable[torch.Tensor], impl: GradientTransformation): + def __init__(self, params: Iterable[torch.Tensor], impl: GradientTransformation) -> None: r"""The :meth:`init` function. Args: @@ -37,16 +42,19 @@ def __init__(self, params: Iterable[torch.Tensor], impl: GradientTransformation) Note that using ``Optimizer(sgd())`` or ``Optimizer(chain(sgd()))`` is equivalent to :class:`torchopt.SGD`. """ - self.impl = impl - self.param_groups = [] # type: ignore - self.param_tree_groups = [] # type: ignore - self.state_groups = [] # type: ignore + if not isinstance(impl, GradientTransformation): + raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation') + + self.impl: GradientTransformation = impl + self.param_groups: List[TupleOfTensors] = [] + self.param_treespecs: List[pytree.PyTreeSpec] = [] + self.state_groups: List[OptState] = [] - if not isinstance(params, list): - params = list(params) + if not isinstance(params, (list, tuple)): + params = tuple(params) self.add_param_group(params) - def zero_grad(self, set_to_none: bool = False): + def zero_grad(self, set_to_none: bool = False) -> None: r"""Sets the gradients of all optimized :class:`torch.Tensor`\s to zero. The behavior is similar to :meth:`torch.optim.Optimizer.zero_grad`. @@ -54,39 +62,38 @@ def zero_grad(self, set_to_none: bool = False): Args: set_to_none (bool): Instead of setting to zero, set the ``grads`` to :data:`None`. """ - for group in self.param_groups: - if set_to_none: + if set_to_none: - def f(p): - p.grad = None + def f(p): + p.grad = None - else: + else: - def f(p): - if p.grad is None: - return - if p.grad.grad_fn is not None: - p.grad.detach_() - else: - p.grad.requires_grad_(False) - p.grad.zero_() + def f(p): + if p.grad is None: + return + if p.grad.grad_fn is not None: + p.grad.detach_() + else: + p.grad.requires_grad_(False) + p.grad.zero_() - pytree.tree_map(f, group) + pytree.tree_map_(f, self.param_groups) # type: ignore[arg-type] - def state_dict(self): + def state_dict(self) -> Tuple[OptState, ...]: """Returns the state of the optimizer.""" - return self.state_groups + return tuple(self.state_groups) - def load_state_dict(self, state_dict): + def load_state_dict(self, state_dict: Sequence[OptState]) -> None: """Loads the optimizer state. Args: - state_dict (dict): Optimizer state. Should be an object returned from a call to + state_dict: Optimizer state. Should be an object returned from a call to :meth:`state_dict`. """ - self.state_groups = state_dict + self.state_groups[:] = list(state_dict) - def step(self, closure=None): + def step(self, closure: Optional[Callable[[], torch.Tensor]] = None) -> Optional[torch.Tensor]: """Performs a single optimization step. The behavior is similar to :meth:`torch.optim.Optimizer.step`. @@ -103,17 +110,19 @@ def f(p): return p.grad for i, (params, state) in enumerate(zip(self.param_groups, self.state_groups)): - grads = pytree.tree_map(f, params) + if isinstance(state, UninitializedState): + state = self.impl.init(params) + grads = pytree.tree_map(f, params) # type: ignore[arg-type] updates, new_state = self.impl.update(grads, state, params=params, inplace=True) self.param_groups[i] = apply_updates(params, updates, inplace=True) self.state_groups[i] = new_state return loss - def add_param_group(self, params): + def add_param_group(self, params: Params) -> None: """Add a param group to the optimizer's :attr:`param_groups`.""" - params, params_tree = pytree.tree_flatten(params) - params = tuple(params) - self.param_groups.append(params) - self.param_tree_groups.append(params_tree) - self.state_groups.append(self.impl.init(params)) + flat_params: TupleOfTensors + flat_params, params_treespec = pytree.tree_flatten_as_tuple(params) + self.param_groups.append(flat_params) + self.param_treespecs.append(params_treespec) + self.state_groups.append(UninitializedState()) diff --git a/torchopt/optim/func/__init__.py b/torchopt/optim/func/__init__.py new file mode 100644 index 00000000..f14fc6ae --- /dev/null +++ b/torchopt/optim/func/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional optimizer wrappers.""" + +from torchopt.optim.func.base import FuncOptimizer diff --git a/torchopt/optim/func/base.py b/torchopt/optim/func/base.py new file mode 100644 index 00000000..b3125d19 --- /dev/null +++ b/torchopt/optim/func/base.py @@ -0,0 +1,104 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional optimizer wrappers.""" + +from typing import Optional + +import torch + +from torchopt.base import GradientTransformation, UninitializedState +from torchopt.typing import OptState, Params +from torchopt.update import apply_updates + + +__all__ = ['FuncOptimizer'] + + +class FuncOptimizer: # pylint: disable=too-few-public-methods + """A wrapper class to hold the functional optimizer. + + This wrapper makes it easier to maintain the optimizer states. The optimizer states are held by + the wrapper internally. The wrapper provides a :meth:`step` function to compute the gradients + and update the parameters. + + See Also: + - The functional Adam optimizer: :func:`torchopt.adam`. + - The functional AdamW optimizer: :func:`torchopt.adamw`. + - The functional RMSprop optimizer: :func:`torchopt.rmsprop`. + - The functional SGD optimizer: :func:`torchopt.sgd`. + """ + + def __init__(self, impl: GradientTransformation, *, inplace: bool = False) -> None: + """The :meth:`init` function. + + Args: + impl (GradientTransformation): A low level optimizer function, it could be a optimizer + function provided by `alias.py` or a customized `chain` provided by `combine.py`. + inplace (optional): (default: :data:`False`) + The default value of ``inplace`` for each optimization update. + """ + if not isinstance(impl, GradientTransformation): + raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation') + + self.impl: GradientTransformation = impl + self.optim_state: Optional[OptState] = UninitializedState() + self.inplace: bool = bool(inplace) + + def step( + self, + loss: torch.Tensor, + params: Params, + inplace: Optional[bool] = None, + ) -> Params: + r"""Compute the gradients of loss to the network parameters and update network parameters. + + Graph of the derivative will be constructed, allowing to compute higher order derivative + products. We use the differentiable optimizer (pass argument inplace=False) to scale the + gradients and update the network parameters without modifying tensors in-place. + + Args: + loss: (torch.Tensor) + loss that is used to compute the gradients to network parameters. + params: (tree of torch.Tensor) + An tree of :class:`torch.Tensor`\s. Specifies what tensors should be optimized. + inplace (optional): (default: :data:`None`) + Whether to update the parameters in-place. If :data:`None`, use the default value + specified in the constructor. + """ + if isinstance(self.optim_state, UninitializedState): + self.optim_state = self.impl.init(params) + + if inplace is None: + inplace = self.inplace + + # Step parameter only + grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True) + updates, self.optim_state = self.impl.update( + grads, self.optim_state, params=params, inplace=inplace + ) + new_params = apply_updates(params, updates, inplace=inplace) + return new_params + + def state_dict(self) -> OptState: + """Extract the references of the optimizer states. + + Note that the states are references, so any in-place operations will change the states + inside :class:`FuncOptimizer` at the same time. + """ + return self.optim_state + + def load_state_dict(self, state_dict: OptState) -> None: + """Load the references of the optimizer states.""" + self.optim_state = state_dict diff --git a/torchopt/optim/meta/__init__.py b/torchopt/optim/meta/__init__.py new file mode 100644 index 00000000..ba486d6d --- /dev/null +++ b/torchopt/optim/meta/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Differentiable Meta-Optimizers.""" + +from torchopt.optim.meta.adam import MetaAdam +from torchopt.optim.meta.adamw import MetaAdamW +from torchopt.optim.meta.base import MetaOptimizer +from torchopt.optim.meta.rmsprop import MetaRMSProp, MetaRMSprop +from torchopt.optim.meta.sgd import MetaSGD diff --git a/torchopt/_src/optimizer/meta/adam.py b/torchopt/optim/meta/adam.py similarity index 91% rename from torchopt/_src/optimizer/meta/adam.py rename to torchopt/optim/meta/adam.py index 6b76f959..9340b513 100644 --- a/torchopt/_src/optimizer/meta/adam.py +++ b/torchopt/optim/meta/adam.py @@ -12,14 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Differentiable Adam optimizer.""" from typing import Tuple import torch.nn as nn -from torchopt._src.alias import adam -from torchopt._src.optimizer.meta.base import MetaOptimizer -from torchopt._src.typing import ScalarOrSchedule +from torchopt import alias +from torchopt.optim.meta.base import MetaOptimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['MetaAdam'] class MetaAdam(MetaOptimizer): @@ -33,7 +37,7 @@ class MetaAdam(MetaOptimizer): # pylint: disable-next=too-many-arguments def __init__( self, - net: nn.Module, + module: nn.Module, lr: ScalarOrSchedule = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, @@ -43,11 +47,11 @@ def __init__( moment_requires_grad: bool = True, maximize: bool = False, use_accelerated_op: bool = False, - ): + ) -> None: """The :meth:`init` function. Args: - net: (nn.Module) + module: (nn.Module) A network whose parameters should be optimized. lr: (default: :const:`1e-3`) This is a fixed global scaling factor. @@ -71,8 +75,8 @@ def __init__( If :data:`True` use our implemented fused operator. """ super().__init__( - net, - adam( + module, + alias.adam( lr=lr, betas=betas, eps=eps, diff --git a/torchopt/_src/optimizer/meta/adamw.py b/torchopt/optim/meta/adamw.py similarity index 91% rename from torchopt/_src/optimizer/meta/adamw.py rename to torchopt/optim/meta/adamw.py index c38f3c5c..70f3a80a 100644 --- a/torchopt/_src/optimizer/meta/adamw.py +++ b/torchopt/optim/meta/adamw.py @@ -12,15 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Differentiable AdamW optimizer.""" from typing import Any, Callable, Optional, Tuple, Union import torch.nn as nn -from torchopt._src import base # pylint: disable=unused-import -from torchopt._src.alias import adamw -from torchopt._src.optimizer.meta.base import MetaOptimizer -from torchopt._src.typing import ScalarOrSchedule +from torchopt import alias +from torchopt.optim.meta.base import MetaOptimizer +from torchopt.typing import Params, ScalarOrSchedule + + +__all__ = ['MetaAdamW'] class MetaAdamW(MetaOptimizer): @@ -34,22 +37,22 @@ class MetaAdamW(MetaOptimizer): # pylint: disable-next=too-many-arguments def __init__( self, - net: nn.Module, + module: nn.Module, lr: ScalarOrSchedule = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 1e-2, *, eps_root: float = 0.0, - mask: Optional[Union[Any, Callable[['base.Params'], Any]]] = None, + mask: Optional[Union[Any, Callable[[Params], Any]]] = None, moment_requires_grad: bool = False, maximize: bool = False, use_accelerated_op: bool = False, - ): + ) -> None: """The :meth:`init` function. Args: - net: (nn.Module) + module: (nn.Module) A network whose parameters should be optimized. lr: (default: :const:`1e-3`) This is a fixed global scaling factor. @@ -82,8 +85,8 @@ def __init__( If :data:`True` use our implemented fused operator. """ super().__init__( - net, - adamw( + module, + alias.adamw( lr=lr, betas=betas, eps=eps, diff --git a/torchopt/_src/optimizer/meta/base.py b/torchopt/optim/meta/base.py similarity index 56% rename from torchopt/_src/optimizer/meta/base.py rename to torchopt/optim/meta/base.py index eb5a70b1..5993ecc1 100644 --- a/torchopt/_src/optimizer/meta/base.py +++ b/torchopt/optim/meta/base.py @@ -12,23 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""The base class for differentiable meta-optimizers.""" + +from typing import Dict, List, Optional, Sequence, Tuple import torch import torch.nn as nn -from torchopt._src.base import GradientTransformation -from torchopt._src.update import apply_updates -from torchopt._src.utils import pytree +from torchopt import pytree +from torchopt.base import UninitializedState +from torchopt.typing import GradientTransformation, OptState, TupleOfTensors +from torchopt.update import apply_updates +from torchopt.utils import extract_module_containers + + +__all__ = ['MetaOptimizer'] class MetaOptimizer: """The base class for high-level differentiable optimizers.""" - def __init__(self, net: nn.Module, impl: GradientTransformation): + def __init__(self, module: nn.Module, impl: GradientTransformation) -> None: """The :meth:`init` function. Args: - net: (nn.Module) + module: (nn.Module) A network whose parameters should be optimized. impl: (GradientTransformation) A low level optimizer function, it could be a optimizer function provided by @@ -37,13 +45,16 @@ def __init__(self, net: nn.Module, impl: GradientTransformation): ``MetaOptimizer(chain(sgd(moment_requires_grad=True)))`` is equivalent to :class:`torchopt.MetaSGD`. """ - self.impl = impl - self.param_containers_groups = [] # type: ignore - self.state_groups = [] # type: ignore + if not isinstance(impl, GradientTransformation): + raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation') + + self.impl: GradientTransformation = impl + self.param_containers_groups: List[Tuple[Dict[str, Optional[torch.Tensor]], ...]] = [] + self.state_groups: List[OptState] = [] - self.add_param_group(net) + self.add_param_group(module) - def step(self, loss: torch.Tensor): + def step(self, loss: torch.Tensor) -> None: # pylint: disable=too-many-locals """Compute the gradients of the loss to the network parameters and update network parameters. Graph of the derivative will be constructed, allowing to compute higher order derivative @@ -53,40 +64,44 @@ def step(self, loss: torch.Tensor): Args: loss: (torch.Tensor) The loss that is used to compute the gradients to the network parameters. - """ # pylint: disable=line-too-long + """ # Step parameter only - for i, (param_container, new_state) in enumerate( + for i, (param_container, state) in enumerate( zip(self.param_containers_groups, self.state_groups) ): - flattened_params, container_treedef = pytree.tree_flatten(param_container) - flattened_params = tuple(flattened_params) + flat_params: TupleOfTensors + flat_params, container_treespec = pytree.tree_flatten_as_tuple(param_container) # type: ignore[arg-type] + if isinstance(state, UninitializedState): + state = self.impl.init(flat_params) grads = torch.autograd.grad( - loss, flattened_params, create_graph=True, allow_unused=True + loss, + flat_params, + create_graph=True, + allow_unused=True, ) updates, new_state = self.impl.update( grads, - new_state, - params=flattened_params, + state, + params=flat_params, inplace=False, ) self.state_groups[i] = new_state - flattened_new_params = apply_updates(flattened_params, updates, inplace=False) - new_params = pytree.tree_unflatten(container_treedef, flattened_new_params) + flat_new_params = apply_updates(flat_params, updates, inplace=False) + new_params: Tuple[ + Dict[str, Optional[torch.Tensor]], ... + ] = pytree.tree_unflatten( # type: ignore[assignment] + container_treespec, flat_new_params + ) for container, new_param in zip(param_container, new_params): container.update(new_param) - def add_param_group(self, net): + def add_param_group(self, module: nn.Module) -> None: """Add a param group to the optimizer's :attr:`state_groups`.""" - # pylint: disable-next=import-outside-toplevel,cyclic-import - from torchopt._src.utils import _extract_container - - net_container = _extract_container(net, with_buffer=False) - flattened_params = tuple(pytree.tree_leaves(net_container)) - optimizer_state = self.impl.init(flattened_params) - self.param_containers_groups.append(net_container) - self.state_groups.append(optimizer_state) + params_container = extract_module_containers(module, with_buffers=False)[0] + self.param_containers_groups.append(params_container) + self.state_groups.append(UninitializedState()) - def state_dict(self): + def state_dict(self) -> Tuple[OptState, ...]: """Extract the references of the optimizer states. Note that the states are references, so any in-place operations will change the states @@ -94,6 +109,6 @@ def state_dict(self): """ return tuple(self.state_groups) - def load_state_dict(self, state_dict): + def load_state_dict(self, state_dict: Sequence[OptState]) -> None: """Load the references of the optimizer states.""" self.state_groups[:] = list(state_dict) diff --git a/torchopt/_src/optimizer/meta/rmsprop.py b/torchopt/optim/meta/rmsprop.py similarity index 90% rename from torchopt/_src/optimizer/meta/rmsprop.py rename to torchopt/optim/meta/rmsprop.py index 20183236..47c3e983 100644 --- a/torchopt/_src/optimizer/meta/rmsprop.py +++ b/torchopt/optim/meta/rmsprop.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Differentiable RMSProp optimizer.""" import torch.nn as nn -from torchopt._src.alias import rmsprop -from torchopt._src.optimizer.meta.base import MetaOptimizer -from torchopt._src.typing import ScalarOrSchedule +from torchopt import alias +from torchopt.optim.meta.base import MetaOptimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['MetaRMSProp', 'MetaRMSprop'] class MetaRMSProp(MetaOptimizer): @@ -31,7 +35,7 @@ class MetaRMSProp(MetaOptimizer): # pylint: disable-next=too-many-arguments def __init__( self, - net: nn.Module, + module: nn.Module, lr: ScalarOrSchedule = 1e-2, alpha: float = 0.99, eps: float = 1e-8, @@ -42,11 +46,11 @@ def __init__( initial_scale: float = 0.0, nesterov: bool = False, maximize: bool = False, - ): + ) -> None: """The :meth:`init` function. Args: - net: (nn.Module) + module: (nn.Module) A network whose parameters should be optimized. lr: (default: :const:`1e-2`) This is a fixed global scaling factor. @@ -72,8 +76,8 @@ def __init__( Maximize the params based on the objective, instead of minimizing. """ super().__init__( - net, - rmsprop( + module, + alias.rmsprop( lr=lr, alpha=alpha, eps=eps, diff --git a/torchopt/_src/optimizer/meta/sgd.py b/torchopt/optim/meta/sgd.py similarity index 89% rename from torchopt/_src/optimizer/meta/sgd.py rename to torchopt/optim/meta/sgd.py index b8ae5d24..f46158a6 100644 --- a/torchopt/_src/optimizer/meta/sgd.py +++ b/torchopt/optim/meta/sgd.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Differentiable SGD optimizer.""" import torch.nn as nn -from torchopt._src.alias import sgd -from torchopt._src.optimizer.meta.base import MetaOptimizer -from torchopt._src.typing import ScalarOrSchedule +from torchopt import alias +from torchopt.optim.meta.base import MetaOptimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['MetaSGD'] class MetaSGD(MetaOptimizer): @@ -31,7 +35,7 @@ class MetaSGD(MetaOptimizer): # pylint: disable-next=too-many-arguments def __init__( self, - net: nn.Module, + module: nn.Module, lr: ScalarOrSchedule, momentum: float = 0.0, weight_decay: float = 0.0, @@ -39,11 +43,11 @@ def __init__( nesterov: bool = False, moment_requires_grad: bool = True, maximize: bool = False, - ): + ) -> None: """The :meth:`init` function. Args: - net: (nn.Module) + module: (nn.Module) A network whose parameters should be optimized. lr: This is a fixed global scaling factor. momentum: (default: :const:`0.0`) @@ -62,8 +66,8 @@ def __init__( Maximize the params based on the objective, instead of minimizing. """ super().__init__( - net, - sgd( + module, + alias.sgd( lr=lr, momentum=momentum, weight_decay=weight_decay, diff --git a/torchopt/_src/optimizer/rmsprop.py b/torchopt/optim/rmsprop.py similarity index 93% rename from torchopt/_src/optimizer/rmsprop.py rename to torchopt/optim/rmsprop.py index 3b8634f3..dc649722 100644 --- a/torchopt/_src/optimizer/rmsprop.py +++ b/torchopt/optim/rmsprop.py @@ -12,14 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""RMSProp optimizer.""" from typing import Iterable import torch -from torchopt._src.alias import rmsprop -from torchopt._src.optimizer.base import Optimizer -from torchopt._src.typing import ScalarOrSchedule +from torchopt import alias +from torchopt.optim.base import Optimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['RMSProp', 'RMSprop'] class RMSProp(Optimizer): @@ -44,7 +48,7 @@ def __init__( initial_scale: float = 0.0, nesterov: bool = False, maximize: bool = False, - ): + ) -> None: r"""The `init` function. Args: @@ -75,7 +79,7 @@ def __init__( """ super().__init__( params, - rmsprop( + alias.rmsprop( lr=lr, alpha=alpha, eps=eps, diff --git a/torchopt/_src/optimizer/sgd.py b/torchopt/optim/sgd.py similarity index 92% rename from torchopt/_src/optimizer/sgd.py rename to torchopt/optim/sgd.py index a7f415f6..d83786ae 100644 --- a/torchopt/_src/optimizer/sgd.py +++ b/torchopt/optim/sgd.py @@ -12,14 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""SGD optimizer.""" from typing import Iterable import torch -from torchopt._src.alias import sgd -from torchopt._src.optimizer.base import Optimizer -from torchopt._src.typing import ScalarOrSchedule +from torchopt import alias +from torchopt.optim.base import Optimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['SGD'] class SGD(Optimizer): @@ -40,7 +44,7 @@ def __init__( dampening: float = 0.0, nesterov: bool = False, maximize: bool = False, - ): + ) -> None: r"""The :meth:`init` function. Args: @@ -61,7 +65,7 @@ def __init__( """ super().__init__( params, - sgd( + alias.sgd( lr=lr, momentum=momentum, weight_decay=weight_decay, diff --git a/torchopt/py.typed b/torchopt/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/torchopt/pytree.py b/torchopt/pytree.py new file mode 100644 index 00000000..0308b825 --- /dev/null +++ b/torchopt/pytree.py @@ -0,0 +1,193 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The PyTree utilities.""" + +import functools +import operator +from typing import Callable, List, Optional, Tuple + +import optree +import optree.typing as typing # pylint: disable=unused-import +import torch +import torch.distributed.rpc as rpc +from optree import * # pylint: disable=wildcard-import,unused-wildcard-import + +from torchopt.typing import Future, RRef, Scalar, T, TensorTree + + +__all__ = [ + *optree.__all__, + 'tree_flatten_as_tuple', + 'tree_pos', + 'tree_neg', + 'tree_add', + 'tree_add_scalar_mul', + 'tree_sub', + 'tree_sub_scalar_mul', + 'tree_mul', + 'tree_matmul', + 'tree_scalar_mul', + 'tree_truediv', + 'tree_vdot_real', + 'tree_wait', +] + + +def tree_flatten_as_tuple( + tree: PyTree[T], + is_leaf: Optional[Callable[[T], bool]] = None, + *, + none_is_leaf: bool = False, + namespace: str = '', +) -> Tuple[Tuple[T, ...], PyTreeSpec]: + """Flatten a pytree to a tuple of leaves and a PyTreeSpec. + + Args: + tree: The pytree to flatten. + is_leaf: A function that returns :data:`True` if a given node is a leaf. + none_is_leaf: If :data:`True`, None is considered a leaf rather than a internal node with no + children. + namespace: The namespace of custom tree node types. + + Returns: + A tuple of (leaves, treespec). + """ + leaves, treespec = tree_flatten(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) + return tuple(leaves), treespec + + +def acc_add(*args: T) -> T: + """Accumulate addition.""" + return functools.reduce(operator.add, args) + + +def acc_mul(*args: T) -> T: + """Accumulate multiplication.""" + return functools.reduce(operator.mul, args) + + +def acc_matmul(*args: T) -> T: + """Accumulate matrix multiplication.""" + return functools.reduce(operator.matmul, args) + + +def tree_pos(tree: PyTree[T]) -> PyTree[T]: + """Applies `operator.pos` over leaves.""" + return tree_map(operator.pos, tree) + + +def tree_neg(tree: PyTree[T]) -> PyTree[T]: + """Applies `operator.neg` over leaves.""" + return tree_map(operator.neg, tree) + + +def tree_add(*trees: PyTree[T]) -> PyTree[T]: + """Tree addition over leaves.""" + return tree_map(acc_add, *trees) + + +def tree_add_scalar_mul( + tree_x: TensorTree, tree_y: TensorTree, alpha: Optional[Scalar] = None +) -> TensorTree: + """Computes tree_x + alpha * tree_y.""" + if alpha is None: + return tree_map(lambda x, y: x.add(y), tree_x, tree_y) + return tree_map(lambda x, y: x.add(y, alpha=alpha), tree_x, tree_y) + + +def tree_sub(minuend_tree: PyTree[T], subtrahend_tree: PyTree[T]) -> PyTree[T]: + """Tree subtraction over leaves.""" + return tree_map(operator.sub, minuend_tree, subtrahend_tree) + + +def tree_sub_scalar_mul( + tree_x: TensorTree, tree_y: TensorTree, alpha: Optional[Scalar] = None +) -> TensorTree: + """Computes tree_x - alpha * tree_y.""" + if alpha is None: + return tree_map(lambda x, y: x.sub(y), tree_x, tree_y) + return tree_map(lambda x, y: x.sub(y, alpha=alpha), tree_x, tree_y) + + +def tree_mul(*trees: PyTree[T]) -> PyTree[T]: + """Tree multiplication over leaves.""" + return tree_map(acc_mul, *trees) + + +def tree_matmul(*trees: PyTree[T]) -> PyTree[T]: + """Tree matrix multiplication over leaves.""" + return tree_map(acc_matmul, *trees) + + +def tree_scalar_mul(scalar: Scalar, multiplicand_tree: PyTree[T]) -> PyTree[T]: + """Tree scalar multiplication over leaves.""" + return tree_map(lambda x: scalar * x, multiplicand_tree) + + +def tree_truediv(dividend_tree: PyTree[T], divisor_tree: PyTree[T]) -> PyTree[T]: + """Tree division over leaves.""" + return tree_map(operator.truediv, dividend_tree, divisor_tree) + + +def _vdot_real_kernel(x: torch.Tensor, y: torch.Tensor) -> float: + """Computes dot(x.conj(), y).real.""" + x = x.contiguous().view(-1) + y = y.contiguous().view(-1) + vdot = torch.dot(x.real, y.real).item() + if x.is_complex() and y.is_complex(): + vdot += torch.dot(x.imag, y.imag).item() + return vdot + + +def tree_vdot_real(tree_x: TensorTree, tree_y: TensorTree) -> float: + """Computes dot(tree_x.conj(), tree_y).real.sum().""" + leaves_x, treespec = tree_flatten(tree_x) + leaves_y = treespec.flatten_up_to(tree_y) + return sum(map(_vdot_real_kernel, leaves_x, leaves_y)) # type: ignore[arg-type] + + +def tree_wait(future_tree: PyTree[Future[T]]) -> PyTree[T]: + r"""Convert a tree of :class:`Future`\s to a tree of results.""" + futures, treespec = tree_flatten(future_tree) + + results = torch.futures.wait_all(futures) + + return tree_unflatten(treespec, results) + + +if rpc.is_available(): + + def tree_as_rref(tree: PyTree[T]) -> PyTree[RRef[T]]: + r"""Convert a tree of local objects to a tree of :class:`RRef`\s.""" + # pylint: disable-next=import-outside-toplevel,redefined-outer-name,reimported + from torch.distributed.rpc import RRef + + return tree_map(RRef, tree) + + def tree_to_here( + rref_tree: PyTree[RRef[T]], + timeout: float = rpc.api.UNSET_RPC_TIMEOUT, + ) -> PyTree[T]: + r"""Convert a tree of :class:`RRef`\s to a tree of local objects.""" + return tree_map(lambda x: x.to_here(timeout=timeout), rref_tree) + + def tree_local_value(rref_tree: PyTree[RRef[T]]) -> PyTree[T]: + r"""Return the local value of a tree of :class:`RRef`\s.""" + return tree_map(lambda x: x.local_value(), rref_tree) + + __all__.extend(['tree_as_rref', 'tree_to_here']) + + +del Callable, List, Optional, Tuple, optree, rpc, Scalar, T, RRef diff --git a/torchopt/schedule/__init__.py b/torchopt/schedule/__init__.py new file mode 100644 index 00000000..46f59550 --- /dev/null +++ b/torchopt/schedule/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/schedule.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Learning rate schedules.""" + +from torchopt.schedule.polynomial import linear_schedule, polynomial_schedule + + +__all__ = ['polynomial_schedule', 'linear_schedule'] diff --git a/torchopt/_src/schedule.py b/torchopt/schedule/polynomial.py similarity index 87% rename from torchopt/_src/schedule.py rename to torchopt/schedule/polynomial.py index d7367c2b..8d2c2056 100644 --- a/torchopt/_src/schedule.py +++ b/torchopt/schedule/polynomial.py @@ -29,14 +29,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Polynomial learning rate schedules.""" import logging import numpy as np +import torch -from torchopt._src import base -from torchopt._src.typing import Scalar -from torchopt._src.utils import pytree +from torchopt.typing import Numeric, Scalar, Schedule + + +__all__ = ['polynomial_schedule', 'linear_schedule'] def polynomial_schedule( @@ -45,7 +48,7 @@ def polynomial_schedule( power: Scalar, transition_steps: int, transition_begin: int = 0, -) -> base.Schedule: +) -> Schedule: """Constructs a schedule with polynomial transition from init to end value. Args: @@ -80,13 +83,11 @@ def polynomial_schedule( ) transition_begin = 0 - def schedule(count): - def impl(count): - count = np.clip(count - transition_begin, 0, transition_steps) - frac = 1 - count / transition_steps - return (init_value - end_value) * (frac**power) + end_value - - return pytree.tree_map(impl, count) + def schedule(count: Numeric) -> Numeric: + clip = torch.clamp if isinstance(count, torch.Tensor) else np.clip + count = clip(count - transition_begin, 0, transition_steps) # type: ignore[operator] + frac = 1.0 - count / transition_steps + return (init_value - end_value) * (frac**power) + end_value return schedule @@ -97,7 +98,7 @@ def linear_schedule( end_value: Scalar, transition_steps: int, transition_begin: int = 0, -) -> base.Schedule: +) -> Schedule: """Alias polynomial schedule to linear schedule for convenience.""" return polynomial_schedule( init_value=init_value, diff --git a/torchopt/transform/__init__.py b/torchopt/transform/__init__.py new file mode 100644 index 00000000..07c1a8e9 --- /dev/null +++ b/torchopt/transform/__init__.py @@ -0,0 +1,54 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset transformations.""" + +from torchopt.transform.add_decayed_weights import add_decayed_weights +from torchopt.transform.nan_to_num import nan_to_num +from torchopt.transform.scale import scale +from torchopt.transform.scale_by_adam import scale_by_accelerated_adam, scale_by_adam +from torchopt.transform.scale_by_rms import scale_by_rms +from torchopt.transform.scale_by_schedule import scale_by_schedule +from torchopt.transform.scale_by_stddev import scale_by_stddev +from torchopt.transform.trace import trace + + +__all__ = [ + 'trace', + 'scale', + 'scale_by_schedule', + 'add_decayed_weights', + 'scale_by_adam', + 'scale_by_accelerated_adam', + 'scale_by_rms', + 'scale_by_stddev', + 'nan_to_num', +] diff --git a/torchopt/transform/add_decayed_weights.py b/torchopt/transform/add_decayed_weights.py new file mode 100644 index 00000000..700e9c7b --- /dev/null +++ b/torchopt/transform/add_decayed_weights.py @@ -0,0 +1,228 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py +# https://github.com/deepmind/optax/blob/master/optax/_src/wrappers.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset transformations for adding weight decay to updates.""" + +from typing import Any, Callable, NamedTuple, Optional, Union + +from torchopt import pytree +from torchopt.base import EmptyState, GradientTransformation, identity +from torchopt.transform.utils import tree_map_flat +from torchopt.typing import Params + + +__all__ = ['masked', 'add_decayed_weights'] + + +class MaskedState(NamedTuple): + """Maintains inner transform state for masked transformations.""" + + inner_state: Any + + +class MaskedNode(NamedTuple): + """A node used to mask out unspecified parts of a tree. + + This node is ignored when mapping functions across the tree e.g. using :func:`pytree.tree_map` + since it is a container without children. It can therefore be used to mask out parts of a tree. + """ + + +def masked( + inner: GradientTransformation, + mask: Union[Any, Callable[[Params], Any]], +) -> GradientTransformation: + """Mask updates so only some are transformed, the rest are passed through. + + For example, it is common to skip weight decay for BatchNorm scale and all bias parameters. In + many networks, these are the only parameters with only one dimension. So, you may create a mask + function to mask these out as follows:: + mask_fn = lambda p: pytree.tree_map(lambda x: x.ndim != 1, p) + weight_decay = torchopt.masked(torchopt.add_decayed_weights(0.001), mask_fn) + You may alternatively create the mask pytree upfront:: + mask = pytree.tree_map(lambda x: x.ndim != 1, params) + weight_decay = torchopt.masked(torchopt.add_decayed_weights(0.001), mask) + For the ``inner`` transform, state will only be stored for the parameters that have a mask value + of :data:`True`. + + Args: + inner: Inner transformation to mask. + mask: A tree with same structure as (or a prefix of) the params tree, or a Callable that + returns such a tree given the params/updates. The leaves should be booleans, :data:`True` + for leaves/subtrees you want to apply the transformation to, and :data:`False` for those + you want to skip. The mask must be static for the gradient transformation to be jit-compilable. + + Returns: + A :class:`GradientTransformation` wrapping ``inner``. + """ + return _masked(inner=inner, mask=mask, already_flattened=False) + + +def _masked_flat( + inner: GradientTransformation, + mask: Union[Any, Callable[[Params], Any]], +) -> GradientTransformation: + return _masked(inner, mask, already_flattened=True) + + +def _masked( + inner: GradientTransformation, + mask: Union[Any, Callable[[Params], Any]], + *, + already_flattened: bool = False, +) -> GradientTransformation: + + if already_flattened: + tree_map = tree_map_flat + else: + tree_map = pytree.tree_map # type: ignore[assignment] + + def tree_mask(params, mask_tree): + return tree_map(lambda p, m: p if m else MaskedNode(), params, mask_tree) + + def init_fn(params): + mask_tree = mask(params) if callable(mask) else mask + masked_params = tree_mask(params, mask_tree) + return MaskedState(inner_state=inner.init(masked_params)) + + def update_fn(updates, state, params=None, inplace=True): # pylint: disable=unused-argument + mask_tree = mask(updates) if callable(mask) else mask + masked_updates = tree_mask(updates, mask_tree) + masked_params = None if params is None else tree_mask(params, mask_tree) + + new_masked_updates, new_inner_state = inner.update( + masked_updates, state.inner_state, params=masked_params, inplace=inplace + ) + + new_updates = tree_map( + lambda new_u, old_u, m: new_u if m else old_u, new_masked_updates, updates, mask_tree + ) + return new_updates, MaskedState(inner_state=new_inner_state) + + return GradientTransformation(init_fn, update_fn) + + +masked.flat = _masked_flat # type: ignore[attr-defined] +masked.impl = _masked # type: ignore[attr-defined] + + +AddDecayedWeightsState = EmptyState + + +def add_decayed_weights( + weight_decay: float = 0.0, + mask: Optional[Union[Any, Callable[[Params], Any]]] = None, +) -> GradientTransformation: + """Add parameter scaled by `weight_decay`. + + Args: + weight_decay: a scalar weight decay rate. + mask: a tree with same structure as (or a prefix of) the params tree, or a Callable that + returns such a pytree given the params/updates. The leaves should be booleans, + :data:`True` for leaves/subtrees you want to apply the transformation to, and + :data:`False` for those you want to skip. + + Returns: + An (init_fn, update_fn) tuple. + """ + return _add_decayed_weights( + weight_decay=weight_decay, + mask=mask, + already_flattened=False, + ) + + +def _add_decayed_weights_flat( + weight_decay: float = 0.0, + mask: Optional[Union[Any, Callable[[Params], Any]]] = None, +) -> GradientTransformation: + return _add_decayed_weights( + weight_decay=weight_decay, + mask=mask, + already_flattened=True, + ) + + +def _add_decayed_weights( + weight_decay: float = 0.0, + mask: Optional[Union[Any, Callable[[Params], Any]]] = None, + *, + already_flattened: bool = False, +) -> GradientTransformation: + if not 0.0 <= weight_decay: # pylint: disable=unneeded-not + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + + if weight_decay == 0.0 and mask is None: + return identity() + + if already_flattened: + tree_map = tree_map_flat + else: + tree_map = pytree.tree_map # type: ignore[assignment] + + def init_fn(params): # pylint: disable=unused-argument + return AddDecayedWeightsState() + + def update_fn(updates, state, params=None, inplace=True): # pylint: disable=unused-argument + assert params is not None, ( + 'Parameters are required for weight decay. ' + 'Call `update(updates, state, params=params)` instead.' + ) + + if inplace: + + def f(g, p): + if g.requires_grad: + return g.add_(p, alpha=weight_decay) + return g.add_(p.data, alpha=weight_decay) + + else: + + def f(g, p): + return g.add(p, alpha=weight_decay) + + updates = tree_map(f, updates, params) + return updates, state + + # If mask is not `None`, apply mask to the gradient transformation. + # E.g. it is common to skip weight decay on bias units and batch stats. + if mask is not None: + return masked.impl( # type: ignore[attr-defined] + inner=GradientTransformation(init_fn, update_fn), + mask=mask, + already_flattened=already_flattened, + ) + return GradientTransformation(init_fn, update_fn) + + +add_decayed_weights.flat = _add_decayed_weights_flat # type: ignore[attr-defined] +add_decayed_weights.impl = _add_decayed_weights # type: ignore[attr-defined] diff --git a/torchopt/transform/nan_to_num.py b/torchopt/transform/nan_to_num.py new file mode 100644 index 00000000..11890c1b --- /dev/null +++ b/torchopt/transform/nan_to_num.py @@ -0,0 +1,49 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset transformations that replaces updates with non-finite values to the given numbers.""" + +from typing import Optional + +from torchopt import pytree +from torchopt.base import EmptyState, GradientTransformation + + +def nan_to_num( + nan: float = 0.0, posinf: Optional[float] = None, neginf: Optional[float] = None +) -> GradientTransformation: + """Replaces updates with values ``nan`` / ``+inf`` / ``-inf`` to the given numbers. + + Returns: + An ``(init_fn, update_fn)`` tuple. + """ + + def init_fn(params): # pylint: disable=unused-argument + return EmptyState() + + def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument + if inplace: + + def f(g): + return g.nan_to_num_(nan=nan, posinf=posinf, neginf=neginf) + + else: + + def f(g): + return g.nan_to_num(nan=nan, posinf=posinf, neginf=neginf) + + new_updates = pytree.tree_map(f, updates) + return new_updates, state + + return GradientTransformation(init_fn, update_fn) diff --git a/torchopt/transform/scale.py b/torchopt/transform/scale.py new file mode 100644 index 00000000..828b4b2f --- /dev/null +++ b/torchopt/transform/scale.py @@ -0,0 +1,88 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset transformation for scaling updates by learning rate.""" + +from torchopt import pytree +from torchopt.base import EmptyState, GradientTransformation +from torchopt.transform.utils import tree_map_flat + + +__all__ = ['scale'] + + +ScaleState = EmptyState + + +def scale(step_size: float) -> GradientTransformation: + """Scale updates by some fixed scalar ``step_size``. + + Args: + step_size: A scalar corresponding to a fixed scaling factor for updates. + + Returns: + An ``(init_fn, update_fn)`` tuple. + """ + return _scale(step_size=step_size, already_flattened=False) + + +def _scale_flat(step_size: float) -> GradientTransformation: + return _scale(step_size=step_size, already_flattened=True) + + +def _scale(step_size: float, *, already_flattened: bool = False) -> GradientTransformation: + if already_flattened: + tree_map = tree_map_flat + else: + tree_map = pytree.tree_map # type: ignore[assignment] + + def init_fn(params): # pylint: disable=unused-argument + return ScaleState() + + def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument + if inplace: + + def f(g): + return g.mul_(step_size) + + else: + + def f(g): + return g.mul(step_size) + + updates = tree_map(f, updates) + return updates, state + + return GradientTransformation(init_fn, update_fn) + + +scale.flat = _scale_flat # type: ignore[attr-defined] +scale.impl = _scale # type: ignore[attr-defined] diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py new file mode 100644 index 00000000..f0065712 --- /dev/null +++ b/torchopt/transform/scale_by_adam.py @@ -0,0 +1,316 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset transformations for scaling updates by Adam.""" + +# pylint: disable=invalid-name + +from typing import NamedTuple + +import torch + +from torchopt import pytree +from torchopt.accelerated_op import AdamOp +from torchopt.base import GradientTransformation +from torchopt.transform.utils import inc_count, tree_map_flat, update_moment +from torchopt.typing import SequenceOfTensors, Updates + + +__all__ = ['scale_by_adam', 'scale_by_accelerated_adam'] + + +TRIPLE_PYTREE_SPEC = pytree.tree_structure((0, 1, 2)) # type: ignore[arg-type] + + +class ScaleByAdamState(NamedTuple): + """State for the Adam algorithm.""" + + mu: Updates + nu: Updates + count: SequenceOfTensors # type: ignore + + +def _bias_correction(moment, decay, count, *, already_flattened=False): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + + def f(t, c): # pylint: disable=invalid-name + return t.div(1 - decay**c) + + if already_flattened: + return tree_map_flat(f, moment, count) + return pytree.tree_map(f, moment, count) + + +def scale_by_adam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """Rescale updates according to the Adam algorithm. + + References: + [Kingma et al, 2014](https://arxiv.org/abs/1412.6980) + + Args: + b1: (default: :const:`0.9`) + Decay rate for the exponentially weighted average of grads. + b2: (default: :const:`0.999`) + Decay rate for the exponentially weighted average of squared grads. + eps: (default: :const:`1e-8`) + Term added to the denominator to improve numerical stability. + eps_root: (default: :const:`0.0`) + Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + moment_requires_grad: (default: :data:`False`) + If :data:`True`, states will be created with flag `requires_grad = True`. + + Returns: + An (init_fn, update_fn) tuple. + """ + return _scale_by_adam( + b1=b1, + b2=b2, + eps=eps, + eps_root=eps_root, + moment_requires_grad=moment_requires_grad, + already_flattened=False, + ) + + +def _scale_by_adam_flat( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + moment_requires_grad: bool = False, +) -> GradientTransformation: + return _scale_by_adam( + b1=b1, + b2=b2, + eps=eps, + eps_root=eps_root, + moment_requires_grad=moment_requires_grad, + already_flattened=True, + ) + + +def _scale_by_adam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + moment_requires_grad: bool = False, + *, + already_flattened: bool = False, +) -> GradientTransformation: + # pylint: disable=unneeded-not + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= b1 < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {b1}') + if not 0.0 <= b2 < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {b2}') + # pylint: enable=unneeded-not + + if already_flattened: + tree_map = tree_map_flat + else: + tree_map = pytree.tree_map # type: ignore[assignment] + + def init_fn(params): + zero = tree_map( # count init + lambda t: torch.zeros(1, dtype=torch.int64, device=t.device).squeeze_(), params + ) + mu = tree_map( # first moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params + ) + nu = tree_map( # second moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params + ) + return ScaleByAdamState(mu=mu, nu=nu, count=zero) + + def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument + mu = update_moment.impl( # type: ignore[attr-defined] + updates, state.mu, b1, order=1, inplace=inplace, already_flattened=already_flattened + ) + nu = update_moment.impl( # type: ignore[attr-defined] + updates, state.nu, b2, order=2, inplace=inplace, already_flattened=already_flattened + ) + # pylint: disable=line-too-long + count_inc = inc_count.impl(updates, state.count, already_flattened=already_flattened) # type: ignore[attr-defined] + mu_hat = _bias_correction(mu, b1, count_inc, already_flattened=already_flattened) + nu_hat = _bias_correction(nu, b2, count_inc, already_flattened=already_flattened) + + if inplace: + + def f(g, m, v): # pylint: disable=unused-argument + return m.div_(v.add_(eps_root).sqrt_().add(eps)) + + else: + + def f(g, m, v): # pylint: disable=unused-argument + return m.div(v.add(eps_root).sqrt_().add(eps)) + + updates = tree_map(f, updates, mu_hat, nu_hat) + return updates, ScaleByAdamState(mu=mu, nu=nu, count=count_inc) + + return GradientTransformation(init_fn, update_fn) + + +scale_by_adam.flat = _scale_by_adam_flat # type: ignore[attr-defined] +scale_by_adam.impl = _scale_by_adam # type: ignore[attr-defined] + + +def scale_by_accelerated_adam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """Rescale updates according to the Adam algorithm. + + This function is accelerated by using some fused accelerated operators. + + References: + [Kingma et al, 2014](https://arxiv.org/abs/1412.6980) + + Args: + b1: (default: :const:`0.9`) + Decay rate for the exponentially weighted average of grads. + b2: (default: :const:`0.999`) + Decay rate for the exponentially weighted average of squared grads. + eps: (default: :const:`1e-8`) + Term added to the denominator to improve numerical stability. + eps_root: (default: :const:`0.0`) + Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + moment_requires_grad: (default: :data:`False`) + If :data:`True`, states will be created with flag `requires_grad = True`. + + Returns: + An (init_fn, update_fn) tuple. + """ + return _scale_by_accelerated_adam( + b1=b1, + b2=b2, + eps=eps, + eps_root=eps_root, + moment_requires_grad=moment_requires_grad, + already_flattened=False, + ) + + +def _scale_by_accelerated_adam_flat( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + moment_requires_grad: bool = False, +) -> GradientTransformation: + return _scale_by_accelerated_adam( + b1=b1, + b2=b2, + eps=eps, + eps_root=eps_root, + moment_requires_grad=moment_requires_grad, + already_flattened=True, + ) + + +def _scale_by_accelerated_adam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + moment_requires_grad: bool = False, + *, + already_flattened: bool = False, +) -> GradientTransformation: + # pylint: disable=unneeded-not + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= b1 < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {b1}') + if not 0.0 <= b2 < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {b2}') + # pylint: enable=unneeded-not + + if already_flattened: + tree_map = tree_map_flat + + # pylint: disable-next=unused-argument + def update_fn(updates, state, *, params=None, inplace=True): + count_inc = inc_count.impl(updates, state.count, already_flattened=True) # type: ignore[attr-defined] + + op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace) + out = tree_map_flat(op, state.mu, state.nu, updates, count_inc) + + new_mu, new_nu, new_updates = tuple(zip(*out)) # transpose + return new_updates, ScaleByAdamState(mu=new_mu, nu=new_nu, count=count_inc) + + else: + tree_map = pytree.tree_map # type: ignore[assignment] + + # pylint: disable-next=unused-argument + def update_fn(updates, state, *, params=None, inplace=True): + count_inc = inc_count.impl(updates, state.count, already_flattened=False) # type: ignore[attr-defined] + + treespec = pytree.tree_structure(updates) + + op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace) + out = pytree.tree_map(op, state.mu, state.nu, updates, count_inc) + + new_mu: Updates + new_nu: Updates + new_updates: Updates + new_mu, new_nu, new_updates = pytree.tree_transpose(treespec, TRIPLE_PYTREE_SPEC, out) # type: ignore[misc] + return new_updates, ScaleByAdamState(mu=new_mu, nu=new_nu, count=count_inc) + + def init_fn(params): + zero = tree_map( # count init + lambda t: torch.zeros(1, dtype=torch.int64, device=t.device).squeeze_(), params + ) + mu = tree_map( # first moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params + ) + nu = tree_map( # second moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params + ) + return ScaleByAdamState(mu=mu, nu=nu, count=zero) + + return GradientTransformation(init_fn, update_fn) + + +scale_by_accelerated_adam.flat = _scale_by_accelerated_adam_flat # type: ignore[attr-defined] +scale_by_accelerated_adam.impl = _scale_by_accelerated_adam # type: ignore[attr-defined] diff --git a/torchopt/transform/scale_by_rms.py b/torchopt/transform/scale_by_rms.py new file mode 100644 index 00000000..3451fafe --- /dev/null +++ b/torchopt/transform/scale_by_rms.py @@ -0,0 +1,136 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset transformations for scaling updates by exponential root mean-squared (RMS).""" + +from typing import NamedTuple + +import torch + +from torchopt import pytree +from torchopt.base import GradientTransformation +from torchopt.transform.utils import tree_map_flat, update_moment +from torchopt.typing import Updates + + +__all__ = ['scale_by_rms'] + + +class ScaleByRmsState(NamedTuple): + """State for exponential root mean-squared (RMS)-normalized updates.""" + + nu: Updates + + +def scale_by_rms( + alpha: float = 0.9, eps: float = 1e-8, initial_scale: float = 0.0 +) -> GradientTransformation: + """Rescale updates by the root of the exp. moving avg of the square. + + References: + [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) + + Args: + alpha: (default: :const:`0.9`) + Decay rate for the exponentially weighted average of squared grads. + eps: (default: :const:`1e-8`) + Term added to the denominator to improve numerical stability. + initial_scale: (default: :const:`0.0`) + Initial value for second moment + + Returns: + An (init_fn, update_fn) tuple. + """ + return _scale_by_rms( + alpha=alpha, + eps=eps, + initial_scale=initial_scale, + already_flattened=False, + ) + + +def _scale_by_rms_flat( + alpha: float = 0.9, eps: float = 1e-8, initial_scale: float = 0.0 +) -> GradientTransformation: + return _scale_by_rms( + alpha=alpha, + eps=eps, + initial_scale=initial_scale, + already_flattened=True, + ) + + +def _scale_by_rms( + alpha: float = 0.9, + eps: float = 1e-8, + initial_scale: float = 0.0, + *, + already_flattened: bool = False, +) -> GradientTransformation: + # pylint: disable=unneeded-not + if not 0.0 <= alpha: + raise ValueError(f'Invalid alpha value: {alpha}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + # pylint: enable=unneeded-not + + if already_flattened: + tree_map = tree_map_flat + else: + tree_map = pytree.tree_map # type: ignore[assignment] + + def init_fn(params): + nu = tree_map(lambda n: torch.full_like(n, initial_scale), params) # second moment + return ScaleByRmsState(nu=nu) + + def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument + nu = update_moment.impl( # type: ignore[attr-defined] + updates, state.nu, alpha, order=2, inplace=inplace, already_flattened=already_flattened + ) + + if inplace: + + def f(g, n): # pylint: disable=invalid-name + return g.div_(n.sqrt().add_(eps)) + + else: + + def f(g, n): # pylint: disable=invalid-name + return g.div(n.sqrt().add(eps)) + + updates = tree_map(f, updates, nu) + return updates, ScaleByRmsState(nu=nu) + + return GradientTransformation(init_fn, update_fn) + + +scale_by_rms.flat = _scale_by_rms_flat # type: ignore[attr-defined] +scale_by_rms.impl = _scale_by_rms # type: ignore[attr-defined] diff --git a/torchopt/transform/scale_by_schedule.py b/torchopt/transform/scale_by_schedule.py new file mode 100644 index 00000000..49b6abb7 --- /dev/null +++ b/torchopt/transform/scale_by_schedule.py @@ -0,0 +1,114 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset transformation for scaling updates by learning rate schedules.""" + +from typing import NamedTuple + +import torch + +from torchopt import pytree +from torchopt.base import GradientTransformation +from torchopt.transform.utils import inc_count, tree_map_flat +from torchopt.typing import Schedule, SequenceOfTensors + + +__all__ = ['scale_by_schedule'] + + +class ScaleByScheduleState(NamedTuple): + """Maintains count for scale scheduling.""" + + count: SequenceOfTensors # type: ignore + + +def scale_by_schedule(step_size_fn: Schedule) -> GradientTransformation: + """Scale updates using a custom schedule for the ``step_size``. + + Args: + step_size_fn: + A function that takes an update count as input and proposes the ``step_size`` to + multiply the updates by. + + Returns: + An ``(init_fn, update_fn)`` tuple. + """ + return _scale_by_schedule(step_size_fn=step_size_fn, already_flattened=False) + + +def _scale_by_schedule_flat(step_size_fn: Schedule) -> GradientTransformation: + return _scale_by_schedule(step_size_fn=step_size_fn, already_flattened=True) + + +def _scale_by_schedule( + step_size_fn: Schedule, *, already_flattened: bool = False +) -> GradientTransformation: + if already_flattened: + tree_map = tree_map_flat + else: + tree_map = pytree.tree_map # type: ignore[assignment] + + def init_fn(params): + zero = tree_map( # count init + lambda t: torch.zeros(1, dtype=torch.int64, device=t.device).squeeze_(), params + ) + return ScaleByScheduleState(count=zero) + + def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument + if inplace: + + def f(g, c): # pylint: disable=invalid-name + step_size = step_size_fn(c) + return g.mul_(step_size) + + else: + + def f(g, c): # pylint: disable=invalid-name + step_size = step_size_fn(c) + return g.mul(step_size) + + updates = tree_map(f, updates, state.count) + return ( + updates, + ScaleByScheduleState( + count=inc_count.impl( # type: ignore[attr-defined] + updates, + state.count, + already_flattened=already_flattened, + ) + ), + ) + + return GradientTransformation(init_fn, update_fn) + + +scale_by_schedule.flat = _scale_by_schedule_flat # type: ignore[attr-defined] +scale_by_schedule.impl = _scale_by_schedule # type: ignore[attr-defined] diff --git a/torchopt/transform/scale_by_stddev.py b/torchopt/transform/scale_by_stddev.py new file mode 100644 index 00000000..37138566 --- /dev/null +++ b/torchopt/transform/scale_by_stddev.py @@ -0,0 +1,143 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset transformations for scaling updates by the root of the centered exponential moving average.""" + +# pylint: disable=invalid-name + +from typing import NamedTuple + +import torch + +from torchopt import pytree +from torchopt.base import GradientTransformation +from torchopt.transform.utils import tree_map_flat, update_moment +from torchopt.typing import Updates + + +__all__ = ['scale_by_stddev'] + + +class ScaleByRStdDevState(NamedTuple): + """State for centered exponential moving average of squares of updates.""" + + mu: Updates + nu: Updates + + +def scale_by_stddev( + alpha: float = 0.9, eps: float = 1e-8, initial_scale: float = 0.0 +) -> GradientTransformation: + """Rescale updates by the root of the centered exponential moving average of squares. + + References: + [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) + + Args: + alpha: (default: :const:`0.9`) + Decay rate for the exponentially weighted average of squared grads. + eps: (default: :const:`1e-8`) + Term added to the denominator to improve numerical stability. + initial_scale: (default: :const:`0.0`) + Initial value for second moment + + Returns: + An (init_fn, update_fn) tuple. + """ + return _scale_by_stddev( + alpha=alpha, + eps=eps, + initial_scale=initial_scale, + already_flattened=False, + ) + + +def _scale_by_stddev_flat( + alpha: float = 0.9, eps: float = 1e-8, initial_scale: float = 0.0 +) -> GradientTransformation: + return _scale_by_stddev( + alpha=alpha, + eps=eps, + initial_scale=initial_scale, + already_flattened=True, + ) + + +def _scale_by_stddev( + alpha: float = 0.9, + eps: float = 1e-8, + initial_scale: float = 0.0, + *, + already_flattened: bool = False, +) -> GradientTransformation: + # pylint: disable=unneeded-not + if not 0.0 <= alpha: + raise ValueError(f'Invalid alpha value: {alpha}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + # pylint: enable=unneeded-not + + if already_flattened: + tree_map = tree_map_flat + else: + tree_map = pytree.tree_map # type: ignore[assignment] + + def init_fn(params): + mu = tree_map(torch.zeros_like, params) # first moment + nu = tree_map(lambda n: torch.full_like(n, initial_scale), params) # second moment + return ScaleByRStdDevState(mu=mu, nu=nu) + + def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument + mu = update_moment.impl( # type: ignore[attr-defined] + updates, state.mu, alpha, order=1, inplace=inplace, already_flattened=already_flattened + ) + nu = update_moment.impl( # type: ignore[attr-defined] + updates, state.nu, alpha, order=2, inplace=inplace, already_flattened=already_flattened + ) + + if inplace: + + def f(g, m, n): + return g.div_(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) + + else: + + def f(g, m, n): + return g.div(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) + + updates = tree_map(f, updates, mu, nu) + return updates, ScaleByRStdDevState(mu=mu, nu=nu) + + return GradientTransformation(init_fn, update_fn) + + +scale_by_stddev.flat = _scale_by_stddev_flat # type: ignore[attr-defined] +scale_by_stddev.impl = _scale_by_stddev # type: ignore[attr-defined] diff --git a/torchopt/transform/trace.py b/torchopt/transform/trace.py new file mode 100644 index 00000000..1d741d04 --- /dev/null +++ b/torchopt/transform/trace.py @@ -0,0 +1,194 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset transformations for scaling updates by Adam.""" + +# pylint: disable=invalid-name + +from typing import NamedTuple + +import torch + +from torchopt import pytree +from torchopt.base import GradientTransformation, identity +from torchopt.transform.utils import tree_map_flat +from torchopt.typing import Params + + +__all__ = ['trace'] + + +class TraceState(NamedTuple): + """Holds an aggregation of past updates.""" + + trace: Params + + +def trace( + momentum: float = 0.9, + dampening: float = 0.0, + nesterov: bool = False, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """Compute a trace of past updates. + + Note: `trace` and `ema` have very similar but distinct updates; + `trace = decay * trace + t`, while `ema = decay * ema + (1 - decay) * t`. + Both are frequently found in the optimization literature. + + Args: + momentum: (default: :const:`0.9`) + The decay rate for the trace of past updates. + dampening: (default: :const:`0.0`) + Dampening for momentum. + nesterov: (default: :data:`False`) + Whether to use Nesterov momentum. + moment_requires_grad: (default: :data:`False`) + If :data:`True`, states will be created with flag `requires_grad = True`. + + Returns: + An (init_fn, update_fn) tuple. + """ + return _trace( + momentum=momentum, + dampening=dampening, + nesterov=nesterov, + moment_requires_grad=moment_requires_grad, + already_flattened=False, + ) + + +def _trace_flat( + momentum: float = 0.9, + dampening: float = 0.0, + nesterov: bool = False, + moment_requires_grad: bool = False, +) -> GradientTransformation: + return _trace( + momentum=momentum, + dampening=dampening, + nesterov=nesterov, + moment_requires_grad=moment_requires_grad, + already_flattened=True, + ) + + +def _trace( + momentum: float = 0.9, + dampening: float = 0.0, + nesterov: bool = False, + moment_requires_grad: bool = False, + *, + already_flattened: bool = False, +) -> GradientTransformation: + # pylint: disable=unneeded-not + if not 0.0 <= momentum: + raise ValueError(f'Invalid momentum value: {momentum}') + if nesterov and (momentum <= 0.0 or dampening != 0.0): + raise ValueError('Nesterov momentum requires a momentum and zero dampening') + # pylint: enable=unneeded-not + + if momentum == 0.0: + return identity() + + if already_flattened: + tree_map = tree_map_flat + else: + tree_map = pytree.tree_map # type: ignore[assignment] + + def init_fn(params): + return TraceState( + trace=tree_map( + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params + ) + ) + + first_call = True + + def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument + nonlocal first_call + + if nesterov: + if inplace: + + def f1(g, t): + if first_call: + return t.add_(g) + return t.mul_(momentum).add_(g) + + def f2(g, t): + return g.add_(t, alpha=momentum) + + new_trace = tree_map(f1, updates, state.trace) + updates = tree_map(f2, updates, new_trace) + else: + + def f1(g, t): + if first_call: + return t.add(g) + return t.mul(momentum).add_(g) + + def f2(g, t): + return g.add(t, alpha=momentum) + + new_trace = tree_map(f1, updates, state.trace) + updates = tree_map(f2, updates, new_trace) + else: + if inplace: + + def f(g, t): + if first_call: + return t.add(g) + return t.mul_(momentum).add_(g, alpha=1.0 - dampening) + + def copy_(g, t): + return g.copy_(t) + + new_trace = tree_map(f, updates, state.trace) + updates = tree_map(copy_, updates, new_trace) + else: + + def f(g, t): + if first_call: + return t.add(g) + return t.mul(momentum).add_(g, alpha=1.0 - dampening) + + new_trace = tree_map(f, updates, state.trace) + updates = tree_map(torch.clone, new_trace) + + first_call = False + return updates, TraceState(trace=new_trace) + + return GradientTransformation(init_fn, update_fn) + + +trace.flat = _trace_flat # type: ignore[attr-defined] +trace.impl = _trace # type: ignore[attr-defined] diff --git a/torchopt/transform/utils.py b/torchopt/transform/utils.py new file mode 100644 index 00000000..497df44e --- /dev/null +++ b/torchopt/transform/utils.py @@ -0,0 +1,151 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for the preset transformations.""" + +from collections import deque +from typing import Any, Callable, Iterable, List + +import torch + +from torchopt import pytree +from torchopt.typing import TensorTree, Updates + + +__all__ = ['tree_map_flat', 'tree_map_flat_', 'inc_count', 'update_moment'] + + +INT64_MAX = torch.iinfo(torch.int64).max + + +def tree_map_flat(func: Callable, *flat_args: Any, none_is_leaf: bool = False) -> List[Any]: + """Apply a function to each element of a flattened list.""" + if none_is_leaf: + fn = func + else: + + def fn(x, *xs): + return func(x, *xs) if x is not None else None + + return list(map(fn, *flat_args)) + + +def tree_map_flat_( + func: Callable, flat_arg: Iterable[Any], *flat_args: Any, none_is_leaf: bool = False +) -> Iterable[Any]: + """Apply a function to each element of a flattened list.""" + if none_is_leaf: + fn = func + else: + + def fn(x, *xs): + return func(x, *xs) if x is not None else None + + flat_results = map(fn, flat_arg, *flat_args) + deque(flat_results, maxlen=0) # consume and exhaust the iterable + return flat_arg + + +def inc_count(updates: Updates, count: TensorTree) -> TensorTree: + """Increments int counter by one. + + Returns: + A counter incremented by one, or :data:`INT64_MAX` if the maximum precision is reached. + """ + return _inc_count(updates=updates, count=count, already_flattened=False) + + +def _inc_count_flat(updates: Updates, count: TensorTree) -> TensorTree: + return _inc_count(updates=updates, count=count, already_flattened=True) + + +def _inc_count( + updates: Updates, count: TensorTree, *, already_flattened: bool = False +) -> TensorTree: + def f(c, g): # pylint: disable=invalid-name + return c + (c != INT64_MAX).to(torch.int64) if g is not None else c + + if already_flattened: + return tree_map_flat(f, count, updates) + return pytree.tree_map(f, count, updates) + + +inc_count.flat = _inc_count_flat # type: ignore[attr-defined] +inc_count.impl = _inc_count # type: ignore[attr-defined] + + +def update_moment(updates, moments, decay, *, order, inplace=True): + """Compute the exponential moving average of the ``order``-th moment.""" + return _update_moment( + updates, moments, decay, order=order, inplace=inplace, already_flattened=False + ) + + +def _update_moment_flat(updates, moments, decay, *order, inplace=True): + return _update_moment( + updates, moments, decay, order=order, inplace=inplace, already_flattened=True + ) + + +def _update_moment(updates, moments, decay, *, order, inplace=True, already_flattened=False): + assert order in (1, 2) + + if inplace: + + if order == 2: + + def f(g, t): + return t.mul_(decay).addcmul_(g, g, value=1 - decay) if g is not None else t + + else: + + def f(g, t): + return t.mul_(decay).add_(g, alpha=1 - decay) if g is not None else t + + else: + + if order == 2: + + def f(g, t): + return t.mul(decay).addcmul_(g, g, value=1 - decay) if g is not None else t + + else: + + def f(g, t): + return t.mul(decay).add_(g, alpha=1 - decay) if g is not None else t + + if already_flattened: + return tree_map_flat(f, updates, moments) + return pytree.tree_map(f, updates, moments, none_is_leaf=True) + + +update_moment.flat = _update_moment_flat # type: ignore[attr-defined] +update_moment.impl = _update_moment # type: ignore[attr-defined] diff --git a/torchopt/typing.py b/torchopt/typing.py new file mode 100644 index 00000000..a7499a99 --- /dev/null +++ b/torchopt/typing.py @@ -0,0 +1,127 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Typing utilities.""" + +from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, Union +from typing_extensions import TypeAlias # Python 3.10+ +from typing_extensions import Protocol, runtime_checkable # Python 3.8+ + +import torch +import torch.distributed.rpc as rpc +from optree.typing import PyTree, PyTreeTypeVar +from torch import Tensor +from torch.distributions import Distribution +from torch.futures import Future +from torch.types import Device + +from torchopt.base import ( + ChainedGradientTransformation, + EmptyState, + GradientTransformation, + UninitializedState, +) + + +__all__ = [ + 'GradientTransformation', + 'ChainedGradientTransformation', + 'EmptyState', + 'UninitializedState', + 'Params', + 'Updates', + 'OptState', + 'Scalar', + 'Numeric', + 'Schedule', + 'ScalarOrSchedule', + 'PyTree', + 'Tensor', + 'OptionalTensor', + 'ListOfTensors', + 'TupleOfTensors', + 'SequenceOfTensors', + 'TensorOrTensors', + 'TensorTree', + 'ListOfOptionalTensors', + 'TupleOfOptionalTensors', + 'SequenceOfOptionalTensors', + 'OptionalTensorOrOptionalTensors', + 'OptionalTensorTree', + 'Future', + 'LinearSolver', + 'Device', + 'Size', + 'Distribution', + 'SampleFunc', + 'Samplable', +] + +T = TypeVar('T') + +Scalar: TypeAlias = Union[float, int, bool] +Numeric: TypeAlias = Union[Tensor, Scalar] + +Schedule: TypeAlias = Callable[[Numeric], Numeric] +ScalarOrSchedule: TypeAlias = Union[float, Schedule] + +OptionalTensor = Optional[Tensor] + +ListOfTensors = List[Tensor] +TupleOfTensors = Tuple[Tensor, ...] +SequenceOfTensors = Sequence[Tensor] +TensorOrTensors = Union[Tensor, SequenceOfTensors] +TensorTree: TypeAlias = PyTreeTypeVar('TensorTree', Tensor) # type: ignore[valid-type] + +ListOfOptionalTensors = List[OptionalTensor] +TupleOfOptionalTensors = Tuple[OptionalTensor, ...] +SequenceOfOptionalTensors = Sequence[OptionalTensor] +OptionalTensorOrOptionalTensors = Union[OptionalTensor, SequenceOfOptionalTensors] +OptionalTensorTree: TypeAlias = PyTreeTypeVar('OptionalTensorTree', OptionalTensor) # type: ignore[valid-type] + +# Parameters are arbitrary nests of `torch.Tensor`. +Params: TypeAlias = TensorTree +Updates: TypeAlias = Params # Gradient updates are of the same type as parameters. +OptState: TypeAlias = TensorTree # States are arbitrary nests of `torch.Tensor`. + +if rpc.is_available(): + from torch.distributed.rpc import RRef # pylint: disable=ungrouped-imports,unused-import + + __all__.extend(['RRef']) +else: + RRef = None # type: ignore[misc,assignment] # pylint: disable=invalid-name + +# solver(matvec, b) -> solution +LinearSolver: TypeAlias = Callable[[Callable[[TensorTree], TensorTree], TensorTree], TensorTree] + + +Size = torch.Size + +# sample(sample_shape) -> Tensor +SampleFunc: TypeAlias = Callable[[Size], Union[Tensor, Sequence[Numeric]]] + + +@runtime_checkable +class Samplable(Protocol): # pylint: disable=too-few-public-methods + """Abstract protocol class that supports sampling.""" + + def sample( + self, sample_shape: Size = Size() # pylint: disable=unused-argument + ) -> Union[Tensor, Sequence[Numeric]]: + # pylint: disable-next=line-too-long + """Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.""" + raise NotImplementedError + + +Samplable.register(Distribution) diff --git a/torchopt/_src/update.py b/torchopt/update.py similarity index 92% rename from torchopt/_src/update.py rename to torchopt/update.py index 753292d7..85e93673 100644 --- a/torchopt/_src/update.py +++ b/torchopt/update.py @@ -29,14 +29,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Helper functions for applying updates.""" -from torchopt._src import base # pylint: disable=unused-import -from torchopt._src.utils import pytree +from torchopt import pytree +from torchopt.typing import Params, Updates -def apply_updates( - params: 'base.Params', updates: 'base.Updates', *, inplace: bool = True -) -> 'base.Params': +__all__ = ['apply_updates'] + + +def apply_updates(params: Params, updates: Updates, *, inplace: bool = True) -> Params: """Applies an update to the corresponding parameters. This is a utility functions that applies an update to a set of parameters, and then returns the diff --git a/torchopt/utils.py b/torchopt/utils.py new file mode 100644 index 00000000..f60bc6d6 --- /dev/null +++ b/torchopt/utils.py @@ -0,0 +1,506 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for TorchOpt.""" + +import copy +import itertools +from typing import ( + TYPE_CHECKING, + Dict, + List, + NamedTuple, + Optional, + Sequence, + Set, + Tuple, + Union, + cast, + overload, +) +from typing_extensions import Literal # Python 3.8+ +from typing_extensions import TypeAlias # Python 3.10+ + +import torch +import torch.nn as nn + +from torchopt import pytree +from torchopt.typing import Device, OptState, TensorTree + + +if TYPE_CHECKING: + from torchopt.optim.meta.base import MetaOptimizer + + +__all__ = [ + 'ModuleState', + 'stop_gradient', + 'extract_state_dict', + 'recover_state_dict', + 'module_clone', + 'module_detach_', +] + + +class ModuleState(NamedTuple): + """Container for module state.""" + + params: Tuple[Dict[str, torch.Tensor], ...] + buffers: Tuple[Dict[str, torch.Tensor], ...] + visual_contents: Optional[Dict] = None + detach_buffers: bool = False + + +CopyMode: TypeAlias = Literal['reference', 'copy', 'deepcopy', 'ref', 'clone', 'deepclone'] + + +def stop_gradient(target: Union[TensorTree, ModuleState, nn.Module, 'MetaOptimizer']) -> None: + """Stop the gradient for the input object. + + Since a tensor use :attr:`grad_fn` to connect itself with the previous computation graph, the + backpropagated gradient will flow over the tensor and continue flow to the tensors that is + connected by :attr:`grad_fn`. Some algorithms requires manually detaching tensors from the + computation graph. + + Note that the :func:`stop_gradient` operation is in-place. + + Args: + target: The target that to be detached from the computation graph, it could be a + :class:`nn.Module`, :class:`torchopt.MetaOptimizer`, state of the + :class:`torchopt.MetaOptimizer`, or just a plain list of tensors. + inplace: If :data:`True`, the target will be detached in-place. if :data:`Frue`, this + function will return a detached copy of the target. The in-place operation is fast and + memory efficient but may raise backpropagation error. + """ + # pylint: disable-next=import-outside-toplevel + from torchopt.optim.meta.base import MetaOptimizer + + def fn_(obj): + if isinstance(obj, torch.Tensor): + requires_grad = obj.requires_grad + obj.detach_().requires_grad_(requires_grad) + + if isinstance(target, ModuleState): + true_target = cast(TensorTree, (target.params, target.buffers)) + elif isinstance(target, nn.Module): + true_target = cast(TensorTree, tuple(target.parameters())) + elif isinstance(target, MetaOptimizer): + true_target = cast(TensorTree, target.state_dict()) + else: + true_target = cast(TensorTree, target) # tree of tensors + + pytree.tree_map_(fn_, true_target) + + +@overload +def extract_state_dict( + target: nn.Module, + *, + by: CopyMode = 'reference', + device: Device = None, + with_buffers: bool = True, + enable_visual: bool = False, + visual_prefix: str = '', +) -> ModuleState: + ... + + +@overload +def extract_state_dict( + target: 'MetaOptimizer', + *, + by: CopyMode = 'reference', + device: Device = None, + with_buffers: bool = True, + enable_visual: bool = False, + visual_prefix: str = '', +) -> Tuple[OptState, ...]: + ... + + +# pylint: disable-next=too-many-branches,too-many-locals +def extract_state_dict( + target: Union[nn.Module, 'MetaOptimizer'], + *, + by: CopyMode = 'reference', + device: Device = None, + with_buffers: bool = True, + detach_buffers: bool = False, + enable_visual: bool = False, + visual_prefix: str = '', +) -> Union[ModuleState, Tuple[OptState, ...]]: + """Extract target state. + + Since a tensor use :attr:`grad_fn` to connect itself with the previous computation graph, the + backpropagated gradient will flow over the tensor and continue flow to the tensors that is + connected by :attr:`grad_fn`. Some algorithms requires manually detaching tensors from the + computation graph. + + Note that the extracted state is a reference, which means any in-place operator will affect the + target that the state is extracted from. + + Args: + target: It could be a :class:`nn.Module` or :class:`torchopt.MetaOptimizer`. + by: The extract policy of tensors in the target. + - :const:`'reference'`: The extracted tensors will be references to the original + tensors. + - :const:`'copy'`: The extracted tensors will be clones of the original tensors. This + makes the copied tensors have :attr:`grad_fn` to be a ```` function + points to the original tensors. + - :const:`'deepcopy'`: The extracted tensors will be deep-copied from the original + tensors. The deep-copied tensors will detach from the original computation graph. + device: If specified, move the extracted state to the specified device. + with_buffers: Extract buffer together with parameters, this argument is only used if the + input target is :class:`nn.Module`. + detach_buffers: Whether to detach the reference to the buffers, this argument is only used + if the input target is :class:`nn.Module` and ``by='reference'``. + enable_visual: Add additional annotations, which could be used in computation graph + visualization. Currently, this flag only has effect on :class:`nn.Module` but we will + support :class:`torchopt.MetaOptimizer` later. + visual_prefix: Prefix for the visualization annotations. + + Returns: + State extracted of the input object. + """ + assert by in ('reference', 'copy', 'deepcopy', 'ref', 'clone', 'deepclone') + by = by.replace('clone', 'copy') + by = 'reference' if by == 'ref' else by + + # pylint: disable=import-outside-toplevel + from torchopt.optim.meta.base import MetaOptimizer + + if device is not None: + target_device = torch.device(device) + + def reference(t: torch.Tensor) -> torch.Tensor: + return t.to(device=target_device) + + def clone(t: torch.Tensor) -> torch.Tensor: + return t.clone().to(device=target_device) + + def clone_detach_(t: torch.Tensor) -> torch.Tensor: + if isinstance(t, nn.Parameter): + return nn.Parameter(t.clone().detach_(), requires_grad=t.requires_grad).to( + device=target_device + ) + return t.clone().detach_().to(device=target_device).requires_grad_(t.requires_grad) + + else: + + def reference(t: torch.Tensor) -> torch.Tensor: + return t + + def clone(t: torch.Tensor) -> torch.Tensor: + return t.clone() + + def clone_detach_(t: torch.Tensor) -> torch.Tensor: + if isinstance(t, nn.Parameter): + return nn.Parameter(t.clone().detach_(), requires_grad=t.requires_grad) + return t.clone().detach_().requires_grad_(t.requires_grad) + + if by == 'reference': + replicate = reference + elif by == 'copy': + replicate = clone + else: + replicate = clone_detach_ + + if isinstance(target, nn.Module): # pylint: disable=no-else-return + if enable_visual: + visual_contents = {} + + for k, v in target.named_parameters(): # pylint: disable=invalid-name + if v.grad_fn is not None: + visual_contents.update({v.grad_fn: (visual_prefix + k, v)}) + else: + visual_contents.update({v: visual_prefix + k}) # type: ignore[dict-item] + else: + visual_contents = None + + params: List[Dict[str, torch.Tensor]] = [] + buffers: List[Dict[str, torch.Tensor]] = [] + memo: Set[nn.Module] = set() + + def update_params(container): + if len(container) > 0: + params.append( + type(container)( + (k, replicate(v)) + for k, v in container.items() + if isinstance(v, torch.Tensor) + ) + ) + + def update_buffers(container): + if len(container) > 0: + fn = clone_detach_ if detach_buffers else replicate + buffers.append( + type(container)( + (k, fn(v)) for k, v in container.items() if isinstance(v, torch.Tensor) + ) + ) + + # pylint: disable=protected-access + update_params(target._parameters) + if with_buffers: + update_buffers(target._buffers) + memo.add(target) + for submodule in target.modules(): + if submodule in memo: + continue + update_params(submodule._parameters) + if with_buffers: + update_buffers(submodule._buffers) + memo.add(submodule) + + return ModuleState( + params=tuple(params), + buffers=tuple(buffers), + visual_contents=visual_contents, + detach_buffers=detach_buffers, + ) + + elif isinstance(target, MetaOptimizer): + state = target.state_dict() + + def get_variable(t): + if isinstance(t, torch.Tensor): + return replicate(t) + return t + + state = pytree.tree_map(get_variable, state) # type: ignore[arg-type,assignment] + return state + + raise RuntimeError(f'Unexpected class of {target}') + + +def extract_module_containers( + module: nn.Module, with_buffers: bool = True +) -> Tuple[ + Tuple[Dict[str, Optional[torch.Tensor]], ...], + Tuple[Dict[str, Optional[torch.Tensor]], ...], +]: + """Extract the references to the containers of parameters and buffers from a module.""" + if isinstance(module, nn.Module): + params: List[Dict[str, Optional[torch.Tensor]]] = [] + buffers: List[Dict[str, Optional[torch.Tensor]]] = [] + memo: Set[nn.Module] = set() + + def update_container(container, items): + if len(items) > 0: + container.append(items) # we need references to original dictionaries + + # pylint: disable=protected-access + update_container(params, module._parameters) + if with_buffers: + update_container(buffers, module._buffers) + memo.add(module) + for submodule in module.modules(): + if submodule in memo: + continue + update_container(params, submodule._parameters) + if with_buffers: + update_container(buffers, submodule._buffers) + memo.add(submodule) + return tuple(params), tuple(buffers) + + raise RuntimeError(f'Unexpected class of {module}') + + +def recover_state_dict( + target: Union[nn.Module, 'MetaOptimizer'], + state: Union[ModuleState, Sequence[OptState]], +) -> None: + """Recover state. + + This function is compatible for the ``extract_state``. + + Note that the recovering process is not in-place, so the tensors of the object will not be + modified. + + Args: + target: Target that need to recover. + state: The recovering state. + """ + # pylint: disable-next=import-outside-toplevel + from torchopt.optim.meta.base import MetaOptimizer + + if isinstance(target, nn.Module): + params, buffers, *_ = state = cast(ModuleState, state) + params_containers, buffers_containers = extract_module_containers(target, with_buffers=True) + + if state.detach_buffers: + + def clone_detach_(t: torch.Tensor) -> torch.Tensor: + if isinstance(t, nn.Parameter): + return nn.Parameter(t.clone().detach_(), requires_grad=t.requires_grad) + return t.clone().detach_().requires_grad_(t.requires_grad) + + buffers = cast( + Tuple[Dict[str, torch.Tensor], ...], + pytree.tree_map(clone_detach_, buffers), # type: ignore[arg-type] + ) + + for tgt, src in itertools.chain( + zip(params_containers, params), + zip(buffers_containers, buffers), + ): + tgt.update(src) + elif isinstance(target, MetaOptimizer): + state = cast(Sequence[OptState], state) + target.load_state_dict(state) + else: + raise RuntimeError(f'Unexpected class of {target}') + + +@overload +def module_clone( + target: nn.Module, + *, + by: CopyMode = 'reference', + detach_buffers: bool = False, + device: Device = None, +) -> nn.Module: + ... + + +@overload +def module_clone( + target: 'MetaOptimizer', + *, + by: CopyMode = 'reference', + detach_buffers: bool = False, + device: Device = None, +) -> 'MetaOptimizer': + ... + + +@overload +def module_clone( + target: TensorTree, + *, + by: CopyMode = 'reference', + detach_buffers: bool = False, + device: Device = None, +) -> TensorTree: + ... + + +# pylint: disable-next=too-many-locals +def module_clone( + target: Union[nn.Module, 'MetaOptimizer', TensorTree], + *, + by: CopyMode = 'reference', + detach_buffers: bool = False, + device: Device = None, +) -> Union[nn.Module, 'MetaOptimizer', TensorTree]: + """Clone a module. + + Args: + target: The target to be cloned. + by: The extract policy of tensors in the target. + - :const:`'reference'`: The extracted tensors will be references to the original + tensors. + - :const:`'copy'`: The extracted tensors will be clones of the original tensors. This + makes the copied tensors have :attr:`grad_fn` to be a ```` function + points to the original tensors. + - :const:`'deepcopy'`: The extracted tensors will be deep-copied from the original + tensors. The deep-copied tensors will detach from the original computation graph. + detach_buffers: Whether to detach the reference to the buffers, this argument is only used + if the input target is :class:`nn.Module` and ``by='reference'``. + device: If specified, move the cloned module to the specified device. + + Returns: + The cloned module. + """ + assert by in ('reference', 'copy', 'deepcopy', 'ref', 'clone', 'deepclone') + by = by.replace('clone', 'copy') + by = 'reference' if by == 'ref' else by + if device is not None: + device = torch.device(device) + + # pylint: disable-next=import-outside-toplevel + from torchopt.optim.meta.base import MetaOptimizer + + if isinstance(target, (nn.Module, MetaOptimizer)): + if isinstance(target, nn.Module): + containers = cast(TensorTree, extract_module_containers(target, with_buffers=True)) + else: + containers = cast(TensorTree, target.state_dict()) + tensors = pytree.tree_leaves(containers) + memo = {id(t): t for t in tensors} + cloned = copy.deepcopy(target, memo=memo) + state = extract_state_dict( # type: ignore[call-overload] + target, + by=by, + with_buffers=True, + detach_buffers=detach_buffers, + device=device, + ) + recover_state_dict(cloned, state) + return cloned + + # Tree of tensors + if device is not None: + target_device = torch.device(device) + + def reference(t: torch.Tensor) -> torch.Tensor: + return t.to(device=target_device) + + def clone(t: torch.Tensor) -> torch.Tensor: + return t.clone().to(device=target_device) + + def clone_detach_(t: torch.Tensor) -> torch.Tensor: + if isinstance(t, nn.Parameter): + return nn.Parameter(t.clone().detach_(), requires_grad=t.requires_grad).to( + device=target_device + ) + return t.clone().detach_().to(device=target_device).requires_grad_(t.requires_grad) + + else: + + def reference(t: torch.Tensor) -> torch.Tensor: + return t + + def clone(t: torch.Tensor) -> torch.Tensor: + return t.clone() + + def clone_detach_(t: torch.Tensor) -> torch.Tensor: + if isinstance(t, nn.Parameter): + return nn.Parameter(t.clone().detach_(), requires_grad=t.requires_grad) + return t.clone().detach_().requires_grad_(t.requires_grad) + + if by == 'reference': + replicate = reference + elif by == 'copy': + replicate = clone + else: + replicate = clone_detach_ + + return pytree.tree_map(replicate, cast(TensorTree, target)) + + +def module_detach_( + target: Union[TensorTree, ModuleState, nn.Module, 'MetaOptimizer'] +) -> Union[TensorTree, ModuleState, nn.Module, 'MetaOptimizer']: + """Detach a module from the computation graph. + + Args: + target: The target to be detached. + + Returns: + The detached module. + """ + stop_gradient(target) + return target diff --git a/torchopt/version.py b/torchopt/version.py index b79568e7..6d66f945 100644 --- a/torchopt/version.py +++ b/torchopt/version.py @@ -14,4 +14,38 @@ # ============================================================================== """TorchOpt: a high-performance optimizer library built upon PyTorch.""" -__version__ = '0.5.0' +__version__ = '0.6.0' +__license__ = 'Apache License, Version 2.0' +__author__ = 'TorchOpt Contributors' +__release__ = False + +if not __release__: + import os + import subprocess + + try: + prefix, sep, suffix = ( + subprocess.check_output( + ['git', 'describe', '--abbrev=7'], + cwd=os.path.dirname(os.path.abspath(__file__)), + stderr=subprocess.DEVNULL, + text=True, + ) + .strip() + .lstrip('v') + .replace('-', '.dev', 1) + .replace('-', '+', 1) + .partition('.dev') + ) + if sep: + version_prefix, dot, version_tail = prefix.rpartition('.') + prefix = f'{version_prefix}{dot}{int(version_tail) + 1}' + __version__ = sep.join((prefix, suffix)) + del version_prefix, dot, version_tail + else: + __version__ = prefix + del prefix, sep, suffix + except (OSError, subprocess.CalledProcessError): + pass + + del os, subprocess diff --git a/torchopt/_src/visual.py b/torchopt/visual.py similarity index 81% rename from torchopt/_src/visual.py rename to torchopt/visual.py index edf052bc..25a66ada 100644 --- a/torchopt/_src/visual.py +++ b/torchopt/visual.py @@ -15,15 +15,22 @@ # This file is modified from: # https://github.com/szagoruyko/pytorchviz/blob/master/torchviz/dot.py # ============================================================================== +"""Computation graph visualization.""" import warnings from collections import namedtuple -from typing import Dict, Generator +from typing import Generator, Iterable, Mapping, Optional, Union, cast import torch from graphviz import Digraph from pkg_resources import parse_version +from torchopt.typing import TensorOrTensors +from torchopt.utils import ModuleState + + +__all__ = ['make_dot', 'resize_graph'] + Node = namedtuple('Node', ('name', 'inputs', 'attr', 'op')) @@ -42,9 +49,9 @@ def get_fn_name(fn, show_attrs, max_attr_chars): continue val = getattr(fn, attr) attr = attr[len(SAVED_PREFIX) :] - if torch.is_tensor(val): + if isinstance(val, torch.Tensor): attrs[attr] = '[saved tensor]' - elif isinstance(val, tuple) and any(torch.is_tensor(t) for t in val): + elif isinstance(val, tuple) and any(isinstance(t, torch.Tensor) for t in val): attrs[attr] = '[saved tensors]' else: attrs[attr] = str(val) @@ -63,10 +70,20 @@ def truncate(s): # pylint: disable=invalid-name return name + '\n' + sep + '\n' + params -# mypy: ignore-errors # pylint: disable-next=too-many-branches,too-many-statements,too-many-locals def make_dot( - var: torch.Tensor, params=None, show_attrs=False, show_saved=False, max_attr_chars=50 + var: TensorOrTensors, + params: Optional[ + Union[ + Mapping[str, torch.Tensor], + ModuleState, + Generator, + Iterable[Union[Mapping[str, torch.Tensor], ModuleState, Generator]], + ] + ] = None, + show_attrs: bool = False, + show_saved: bool = False, + max_attr_chars: int = 50, ) -> Digraph: """Produces Graphviz representation of PyTorch autograd graph. @@ -106,22 +123,20 @@ def make_dot( param_map = {} if params is not None: - from torchopt._src.utils import _ModuleState # pylint: disable=import-outside-toplevel - - if isinstance(params, _ModuleState): + if isinstance(params, ModuleState) and params.visual_contents is not None: param_map.update(params.visual_contents) - elif isinstance(params, Dict): + elif isinstance(params, Mapping): param_map.update({v: k for k, v in params.items()}) elif isinstance(params, Generator): param_map.update({v: k for k, v in params}) else: for param in params: - if isinstance(param, _ModuleState): + if isinstance(param, ModuleState) and param.visual_contents is not None: param_map.update(param.visual_contents) elif isinstance(param, Generator): param_map.update({v: k for k, v in param}) else: - param_map.update({v: k for k, v in param.items()}) + param_map.update({v: k for k, v in cast(Mapping, param).items()}) node_attr = dict( style='filled', @@ -148,8 +163,8 @@ def get_var_name_with_flag(var): return f'{param_map[var][0]}\n{size_to_str(param_map[var][1].size())}' return None - def add_nodes(fn): - assert not torch.is_tensor(fn) + def add_nodes(fn): # pylint: disable=too-many-branches + assert not isinstance(fn, torch.Tensor) if fn in seen: return seen.add(fn) @@ -161,12 +176,12 @@ def add_nodes(fn): val = getattr(fn, attr) seen.add(val) attr = attr[len(SAVED_PREFIX) :] - if torch.is_tensor(val): + if isinstance(val, torch.Tensor): dot.edge(str(id(fn)), str(id(val)), dir='none') dot.node(str(id(val)), get_var_name(val, attr), fillcolor='orange') if isinstance(val, tuple): for i, t in enumerate(val): - if torch.is_tensor(t): + if isinstance(t, torch.Tensor): name = f'{attr}[{i}]' dot.edge(str(id(fn)), str(id(t)), dir='none') dot.node(str(id(t)), get_var_name(t, name), fillcolor='orange') @@ -203,21 +218,21 @@ def add_nodes(fn): dot.edge(str(id(t)), str(id(fn))) dot.node(str(id(t)), get_var_name(t), fillcolor='orange') - def add_base_tensor(var, color='darkolivegreen1'): - if var in seen: + def add_base_tensor(v, color='darkolivegreen1'): # pylint: disable=invalid-name + if v in seen: return - seen.add(var) - dot.node(str(id(var)), get_var_name(var), fillcolor=color) - if var.grad_fn: - add_nodes(var.grad_fn) - dot.edge(str(id(var.grad_fn)), str(id(var))) + seen.add(v) + dot.node(str(id(v)), get_var_name(v), fillcolor=color) + if v.grad_fn: + add_nodes(v.grad_fn) + dot.edge(str(id(v.grad_fn)), str(id(v))) # pylint: disable=protected-access - if var._is_view(): - add_base_tensor(var._base, color='darkolivegreen3') - dot.edge(str(id(var._base)), str(id(var)), style='dotted') + if v._is_view(): + add_base_tensor(v._base, color='darkolivegreen3') + dot.edge(str(id(v._base)), str(id(v)), style='dotted') # handle multiple outputs - if isinstance(var, tuple): + if isinstance(var, (tuple, list)): for v in var: # pylint: disable=invalid-name add_base_tensor(v) else: @@ -228,7 +243,7 @@ def add_base_tensor(var, color='darkolivegreen1'): return dot -def resize_graph(dot, size_per_element=0.5, min_size=12): +def resize_graph(dot: Digraph, size_per_element: float = 0.5, min_size: float = 12.0) -> None: """Resize the graph according to how much content it contains. Modify the graph in place. diff --git a/tutorials/1_Functional_Optimizer.ipynb b/tutorials/1_Functional_Optimizer.ipynb index f4194835..3d70eb62 100644 --- a/tutorials/1_Functional_Optimizer.ipynb +++ b/tutorials/1_Functional_Optimizer.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "[](https://colab.research.google.com/drive/1yfi-ETyIptlIM7WFYWF_IFhX4WF3LldP?usp=sharing)" + "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/1_Functional_Optimizer.ipynb)" ] }, { @@ -88,7 +88,7 @@ " return jnp.matmul(x, params['weight']) + params['bias']\n", "\n", " # Obtain the `opt_state` that contains statistics for the optimizer\n", - " learning_rate = 1.\n", + " learning_rate = 1.0\n", " optimizer = optax.adam(learning_rate)\n", " opt_state = optimizer.init(params)\n", "\n", @@ -116,14 +116,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "Parameters before update: {\n", - " 'weight': DeviceArray([[1.]], dtype=float32)),\n", - " 'bias': DeviceArray([0.], dtype=float32)\n", - "}\n", - "Parameters after update: {\n", - " 'weight': DeviceArray([[6.735325e-06]], dtype=float32),\n", - " 'bias': DeviceArray([-0.99999326], dtype=float32)\n", - "}" + "Parameters before update:\n", + "OrderedDict([\n", + " ('weight', DeviceArray([[1.]], dtype=float32)),\n", + " ('bias', DeviceArray([0.], dtype=float32))\n", + "])\n", + "Parameters after update:\n", + "OrderedDict([\n", + " ('weight', DeviceArray([[6.735325e-06]], dtype=float32)),\n", + " ('bias', DeviceArray([-0.99999326], dtype=float32))\n", + "])\n" ] } ], @@ -153,7 +155,7 @@ " model, params = functorch.make_functional(net) # get the functional version of the model\n", "\n", " # Obtain the `opt_state` that contains statistics for the optimizer\n", - " learning_rate = 1.\n", + " learning_rate = 1.0\n", " optimizer = torchopt.adam(learning_rate)\n", " opt_state = optimizer.init(params)\n", "\n", @@ -165,7 +167,7 @@ "\n", " grads = torch.autograd.grad(loss, params)\n", " updates, opt_state = optimizer.update(grads, opt_state)\n", - " \n", + "\n", " print('Parameters before update:', params)\n", " params = torchopt.apply_updates(params, updates)\n", " print('Parameters after update:', params)" @@ -180,14 +182,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "Parameters before update: (\n", + "Parameters before update:\n", + "(\n", " Parameter containing: tensor([[1.]], requires_grad=True),\n", " Parameter containing: tensor([0.], requires_grad=True)\n", ")\n", - "Parameters after update: (\n", - " Parameter containing: tensor([[0.]], requires_grad=True),\n", - " Parameter containing: tensor([-1.], requires_grad=True)\n", - ")" + "Parameters after update:\n", + "(\n", + " Parameter containing: tensor([[6.6757e-06]], requires_grad=True),\n", + " Parameter containing: tensor([-1.0000], requires_grad=True)\n", + ")\n" ] } ], @@ -195,18 +199,77 @@ "interact_with_functorch()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "TorchOpt also offers a wrapper `torchopt.FuncOptimizer` to make it easier to maintain the optimizer states." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def interact_with_functorch_with_wrapper():\n", + " batch_size = 1\n", + " dim = 1\n", + " net = Net(dim)\n", + " model, params = functorch.make_functional(net) # get the functional version of the model\n", + "\n", + " learning_rate = 1.0\n", + " optimizer = torchopt.FuncOptimizer(torchopt.adam(learning_rate))\n", + "\n", + " xs = 2 * torch.ones((batch_size, dim))\n", + " ys = torch.ones((batch_size, 1))\n", + "\n", + " pred = model(params, xs)\n", + " loss = mse(pred, ys)\n", + "\n", + " print('Parameters before update:', params)\n", + " params = optimizer.step(loss, params)\n", + " print('Parameters after update:', params)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Parameters before update:\n", + "(\n", + " Parameter containing: tensor([[1.]], requires_grad=True),\n", + " Parameter containing: tensor([0.], requires_grad=True)\n", + ")\n", + "Parameters after update:\n", + "(\n", + " tensor([[6.6757e-06]], grad_fn=),\n", + " tensor([-1.0000], grad_fn=)\n", + ")\n" + ] + } + ], + "source": [ + "interact_with_functorch_with_wrapper()" + ] + }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.3 Full TorchOpt\n", "\n", - "The third example is to illustrate that TorchOpt can also directly replace `torch.optim` with exactly the same usage. Note the API difference happens between `torchopt.adam()` and `torchopt.Adam()`." + "`torchopt.Optimizer` is the base class for our PyTorch-like optimizer. Combined with the functional optimizer `torchopt.sgd` and `torchopt.adam`, we can define our high-level API `torchopt.SGD` and `torchopt.Adam`. The third example is to illustrate that TorchOpt can also directly replace `torch.optim` with exactly the same usage. Note the API difference happens between `torchopt.adam()` and `torchopt.Adam()`." ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -215,8 +278,11 @@ " dim = 1\n", " net = Net(dim)\n", "\n", - " learning_rate = 1.\n", + " learning_rate = 1.0\n", + " # High-level API\n", " optim = torchopt.Adam(net.parameters(), lr=learning_rate)\n", + " # Low-level API\n", + " optim = torchopt.Optimizer(net.parameters(), torchopt.adam(lr=learning_rate))\n", "\n", " xs = 2 * torch.ones((batch_size, dim))\n", " ys = torch.ones((batch_size, 1))\n", @@ -233,21 +299,23 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Parameters before update: {\n", + "Parameters before update:\n", + "{\n", " 'fc.weight': Parameter containing: tensor([[1.]], requires_grad=True),\n", " 'fc.bias': Parameter containing: tensor([0.], requires_grad=True)\n", "}\n", - "Parameters after update: {\n", - " 'fc.weight': Parameter containing: tensor([[0.]], requires_grad=True),\n", - " 'fc.bias': Parameter containing: tensor([-1.], requires_grad=True)\n", - "}" + "Parameters after update:\n", + "{\n", + " 'fc.weight': Parameter containing: tensor([[6.6757e-06]], requires_grad=True),\n", + " 'fc.bias': Parameter containing: tensor([-1.0000], requires_grad=True)\n", + "}\n" ] } ], @@ -266,7 +334,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -275,7 +343,7 @@ " dim = 1\n", " net = Net(dim)\n", "\n", - " learning_rate = 1.\n", + " learning_rate = 1.0\n", " optim = torch.optim.Adam(net.parameters(), lr=learning_rate)\n", "\n", " xs = 2 * torch.ones((batch_size, dim))\n", @@ -293,21 +361,23 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Parameters before update: {\n", + "Parameters before update:\n", + "{\n", " 'fc.weight': Parameter containing: tensor([[1.]], requires_grad=True),\n", " 'fc.bias': Parameter containing: tensor([0.], requires_grad=True)\n", "}\n", - "Parameters after update: {\n", + "Parameters after update:\n", + "{\n", " 'fc.weight': Parameter containing: tensor([[1.1921e-07]], requires_grad=True),\n", " 'fc.bias': Parameter containing: tensor([-1.0000], requires_grad=True)\n", - "}" + "}\n" ] } ], @@ -328,7 +398,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -342,7 +412,7 @@ " meta_param = nn.Parameter(torch.ones(1))\n", "\n", " # SGD example\n", - " learning_rate = 1.\n", + " learning_rate = 1.0\n", " optimizer = torchopt.sgd(learning_rate)\n", " opt_state = optimizer.init(params)\n", "\n", @@ -356,7 +426,8 @@ "\n", " grads = torch.autograd.grad(loss, params, create_graph=True)\n", " updates, opt_state = optimizer.update(grads, opt_state, inplace=False)\n", - " params = torchopt.apply_updates(params, updates, inplace=False) # update parameters with single step SGD update\n", + " # Update parameters with single step SGD update\n", + " params = torchopt.apply_updates(params, updates, inplace=False)\n", "\n", " pred = model(params, xs)\n", " loss = mse(pred, ys)\n", @@ -367,7 +438,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -393,29 +464,29 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ - "optim = torchopt.adam(lr=1., moment_requires_grad=False)" + "optim = torchopt.adam(lr=1.0, moment_requires_grad=False)" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ - "optim = torchopt.adam(lr=1., moment_requires_grad=True)" + "optim = torchopt.adam(lr=1.0, moment_requires_grad=True)" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ - "optim = torchopt.sgd(lr=1., momentum=0.8, moment_requires_grad=True)" + "optim = torchopt.sgd(lr=1.0, momentum=0.8, moment_requires_grad=True)" ] }, { @@ -436,7 +507,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -453,7 +524,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -470,27 +541,27 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "net = Net(1).cuda()\n", - "optim = torchopt.Adam(net.parameters(), lr=1., use_accelerated_op=True)" + "optim = torchopt.Adam(net.parameters(), lr=1.0, use_accelerated_op=True)" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ - "optim = torchopt.adam(lr=1., use_accelerated_op=True)" + "optim = torchopt.adam(lr=1.0, use_accelerated_op=True)" ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.13 64-bit", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -504,7 +575,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.9.15" }, "vscode": { "interpreter": { diff --git a/tutorials/2_Visualization.ipynb b/tutorials/2_Visualization.ipynb index f1af008f..3141f522 100644 --- a/tutorials/2_Visualization.ipynb +++ b/tutorials/2_Visualization.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "[](https://colab.research.google.com/drive/1Uoo2epqZKmJNQOiO0EU8DGd33AVKBlAq?usp=sharing)" + "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/2_Visualization.ipynb)" ] }, { @@ -37,12 +37,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "\n" + "\n" ] }, { "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n139996637621680\n\ny\n ()\n\n\n\n139993377217744\n\nMulBackward0\n\n\n\n139993377217744->139996637621680\n\n\n\n\n\n139993377217840\n\nAccumulateGrad\n\n\n\n139993377217840->139993377217744\n\n\n\n\n\n139996637619360\n\nx\n ()\n\n\n\n139996637619360->139993377217840\n\n\n\n\n\n" + "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140534064715952\n\ny\n()\n\n\n\n140534064838304\n\nMulBackward0\n\n\n\n140534064838304->140534064715952\n\n\n\n\n\n140534064837776\n\nAccumulateGrad\n\n\n\n140534064837776->140534064838304\n\n\n\n\n\n140534064714832\n\nx\n()\n\n\n\n140534064714832->140534064837776\n\n\n\n\n\n" }, "metadata": {}, "output_type": "display_data" @@ -58,7 +58,7 @@ "import torchopt\n", "\n", "\n", - "x = torch.tensor(1., requires_grad=True)\n", + "x = torch.tensor(1.0, requires_grad=True)\n", "y = 2 * x\n", "display(torchopt.visual.make_dot(y, params={'x': x, 'y': y}))" ] @@ -86,12 +86,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "\n" + "\n" ] }, { "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n139993376880096\n\nloss\n ()\n\n\n\n139996875678480\n\nMseLossBackward0\n\n\n\n139996875678480->139993376880096\n\n\n\n\n\n139996875677952\n\nAddmmBackward0\n\n\n\n139996875677952->139996875678480\n\n\n\n\n\n139996875678336\n\nAccumulateGrad\n\n\n\n139996875678336->139996875677952\n\n\n\n\n\n139993376879696\n\nfc.bias\n (1)\n\n\n\n139993376879696->139996875678336\n\n\n\n\n\n139996875678912\n\nTBackward0\n\n\n\n139996875678912->139996875677952\n\n\n\n\n\n139996875679152\n\nAccumulateGrad\n\n\n\n139996875679152->139996875678912\n\n\n\n\n\n139993376879616\n\nfc.weight\n (1, 5)\n\n\n\n139993376879616->139996875679152\n\n\n\n\n\n" + "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140534659780336\n\nloss\n()\n\n\n\n140531595570768\n\nMseLossBackward0\n\n\n\n140531595570768->140534659780336\n\n\n\n\n\n140531595570576\n\nAddmmBackward0\n\n\n\n140531595570576->140531595570768\n\n\n\n\n\n140531595570528\n\nAccumulateGrad\n\n\n\n140531595570528->140531595570576\n\n\n\n\n\n140531595583632\n\nfc.bias\n(1)\n\n\n\n140531595583632->140531595570528\n\n\n\n\n\n140531595571104\n\nTBackward0\n\n\n\n140531595571104->140531595570576\n\n\n\n\n\n140531595570432\n\nAccumulateGrad\n\n\n\n140531595570432->140531595571104\n\n\n\n\n\n140531595582816\n\nfc.weight\n(1, 5)\n\n\n\n140531595582816->140531595570432\n\n\n\n\n\n" }, "metadata": {}, "output_type": "display_data" @@ -122,7 +122,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The computation graph of meta learning algorithms will be much more complex. Our visualization tool allows users take as input the extracted network state for better visualization." + "The computation graph of meta-learning algorithms will be much more complex. Our visualization tool allows users take as input the extracted network state for better visualization." ] }, { @@ -134,12 +134,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "\n" + "\n" ] }, { "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n139993376892384\n\nloss\n ()\n\n\n\n139993376862752\n\nMseLossBackward0\n\n\n\n139993376862752->139993376892384\n\n\n\n\n\n139993376862800\n\nAddBackward0\n\n\n\n139993376862800->139993376862752\n\n\n\n\n\n139993376862896\n\nAddmmBackward0\n\n\n\n139993376862896->139993376862800\n\n\n\n\n\n139993377217840\n\nAddBackward0\n step1.fc.bias\n (1)\n\n\n\n139993377217840->139993376862896\n\n\n\n\n\n139993376863136\n\nAccumulateGrad\n\n\n\n139993376863136->139993377217840\n\n\n\n\n\n139993376863664\n\nAddmmBackward0\n\n\n\n139993376863136->139993376863664\n\n\n\n\n\n139993376891904\n\nstep0.fc.bias\n (1)\n\n\n\n139993376891904->139993376863136\n\n\n\n\n\n139993376863088\n\nMulBackward0\n\n\n\n139993376863088->139993377217840\n\n\n\n\n\n139993376863184\n\nViewBackward0\n\n\n\n139993376863184->139993376863088\n\n\n\n\n\n139993376863376\n\nSumBackward1\n\n\n\n139993376863376->139993376863184\n\n\n\n\n\n139993376863472\n\nMseLossBackwardBackward0\n\n\n\n139993376863472->139993376863376\n\n\n\n\n\n139993376864000\n\nTBackward0\n\n\n\n139993376863472->139993376864000\n\n\n\n\n\n139993376863568\n\nAddBackward0\n\n\n\n139993376863568->139993376863472\n\n\n\n\n\n139993376863664->139993376863568\n\n\n\n\n\n139993376863760\n\nTBackward0\n\n\n\n139993376863760->139993376863664\n\n\n\n\n\n139993376863856\n\nAccumulateGrad\n\n\n\n139993376863856->139993376863760\n\n\n\n\n\n139993377218464\n\nAddBackward0\n step1.fc.weight\n (1, 5)\n\n\n\n139993376863856->139993377218464\n\n\n\n\n\n139993376891664\n\nstep0.fc.weight\n (1, 5)\n\n\n\n139993376891664->139993376863856\n\n\n\n\n\n139993376862848\n\nAccumulateGrad\n\n\n\n139993376862848->139993376862800\n\n\n\n\n\n139993376862848->139993376863568\n\n\n\n\n\n139996637619600\n\nmeta_param\n ()\n\n\n\n139996637619600->139993376862848\n\n\n\n\n\n139993376863040\n\nTBackward0\n\n\n\n139993376863040->139993376862896\n\n\n\n\n\n139993377218464->139993376863040\n\n\n\n\n\n139993376863424\n\nMulBackward0\n\n\n\n139993376863424->139993377218464\n\n\n\n\n\n139993376863616\n\nTBackward0\n\n\n\n139993376863616->139993376863424\n\n\n\n\n\n139993376863808\n\nTBackward0\n\n\n\n139993376863808->139993376863616\n\n\n\n\n\n139993376863904\n\nMmBackward0\n\n\n\n139993376863904->139993376863808\n\n\n\n\n\n139993376864000->139993376863904\n\n\n\n\n\n" + "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140531595614064\n\nloss\n()\n\n\n\n140531595567168\n\nMseLossBackward0\n\n\n\n140531595567168->140531595614064\n\n\n\n\n\n140531595569232\n\nAddBackward0\n\n\n\n140531595569232->140531595567168\n\n\n\n\n\n140531595568800\n\nAddmmBackward0\n\n\n\n140531595568800->140531595569232\n\n\n\n\n\n140534660247264\n\nAddBackward0\nstep1.fc.bias\n(1)\n\n\n\n140534660247264->140531595568800\n\n\n\n\n\n140534553595376\n\nAccumulateGrad\n\n\n\n140534553595376->140534660247264\n\n\n\n\n\n140534553592832\n\nAddmmBackward0\n\n\n\n140534553595376->140534553592832\n\n\n\n\n\n140534064448352\n\nstep0.fc.bias\n(1)\n\n\n\n140534064448352->140534553595376\n\n\n\n\n\n140534553595616\n\nMulBackward0\n\n\n\n140534553595616->140534660247264\n\n\n\n\n\n140534553594848\n\nViewBackward0\n\n\n\n140534553594848->140534553595616\n\n\n\n\n\n140534553594992\n\nSumBackward1\n\n\n\n140534553594992->140534553594848\n\n\n\n\n\n140534553594800\n\nMseLossBackwardBackward0\n\n\n\n140534553594800->140534553594992\n\n\n\n\n\n140531595617904\n\nTBackward0\n\n\n\n140534553594800->140531595617904\n\n\n\n\n\n140534553593072\n\nAddBackward0\n\n\n\n140534553593072->140534553594800\n\n\n\n\n\n140534553592832->140534553593072\n\n\n\n\n\n140534553593456\n\nTBackward0\n\n\n\n140534553593456->140534553592832\n\n\n\n\n\n140534553593888\n\nAccumulateGrad\n\n\n\n140534553593888->140534553593456\n\n\n\n\n\n140531595572368\n\nAddBackward0\nstep1.fc.weight\n(1, 5)\n\n\n\n140534553593888->140531595572368\n\n\n\n\n\n140531595612944\n\nstep0.fc.weight\n(1, 5)\n\n\n\n140531595612944->140534553593888\n\n\n\n\n\n140531595567888\n\nAccumulateGrad\n\n\n\n140531595567888->140531595569232\n\n\n\n\n\n140531595567888->140534553593072\n\n\n\n\n\n140531595613184\n\nmeta_param\n()\n\n\n\n140531595613184->140531595567888\n\n\n\n\n\n140534553594272\n\nTBackward0\n\n\n\n140534553594272->140531595568800\n\n\n\n\n\n140531595572368->140534553594272\n\n\n\n\n\n140534553593504\n\nMulBackward0\n\n\n\n140534553593504->140531595572368\n\n\n\n\n\n140534553592976\n\nTBackward0\n\n\n\n140534553592976->140534553593504\n\n\n\n\n\n140534553593216\n\nTBackward0\n\n\n\n140534553593216->140534553592976\n\n\n\n\n\n140534553593552\n\nMmBackward0\n\n\n\n140534553593552->140534553593216\n\n\n\n\n\n140531595617904->140534553593552\n\n\n\n\n\n" }, "metadata": {}, "output_type": "display_data" @@ -163,7 +163,7 @@ "ys = torch.ones((batch_size, 1))\n", "\n", "optimizer = torchopt.MetaSGD(net, lr=1e-3)\n", - "meta_param = torch.tensor(1., requires_grad=True)\n", + "meta_param = torch.tensor(1.0, requires_grad=True)\n", "\n", "# Set enable_visual\n", "net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')\n", @@ -179,13 +179,17 @@ "loss = F.mse_loss(pred, torch.ones_like(pred))\n", "\n", "# Draw computation graph\n", - "display(torchopt.visual.make_dot(loss, [net_state_0, net_state_1, {'meta_param': meta_param, 'loss': loss}]))" + "display(\n", + " torchopt.visual.make_dot(\n", + " loss, [net_state_0, net_state_1, {'meta_param': meta_param, 'loss': loss}]\n", + " )\n", + ")" ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.13 ('torchopt')", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -199,7 +203,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.9.15" }, "vscode": { "interpreter": { diff --git a/tutorials/3_Meta_Optimizer.ipynb b/tutorials/3_Meta_Optimizer.ipynb index aaca9e3f..d50ace2d 100644 --- a/tutorials/3_Meta_Optimizer.ipynb +++ b/tutorials/3_Meta_Optimizer.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "[](https://colab.research.google.com/drive/1lo9q2gQz073urYln-4Yub5s8APUoHvQJ?usp=sharing)" + "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/3_Meta_Optimizer.ipynb)" ] }, { @@ -34,7 +34,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Assume a tensor $x$ is a meta parameter and $a$ is a normal parameters (such as network parameters). We have inner loss $\\mathcal{L}^{\\textrm{in}} = a_0 \\cdot x^2$ and we update $a$ use the gradient $\\frac{\\partial \\mathcal{L}^{\\textrm{in}}}{\\partial a_0} = x^2$ and $a_1 = a_0 - \\eta \\, \\frac{\\partial \\mathcal{L}^{\\textrm{in}}}{\\partial a_0} = a_0 - \\eta \\, x^2$. Then we compute the outer loss $\\mathcal{L}^{\\textrm{out}} = a_1 \\cdot x^2$. So the gradient of outer loss to $x$ would be:\n", + "Assume a tensor $x$ is a meta-parameter and $a$ is a normal parameters (such as network parameters). We have inner loss $\\mathcal{L}^{\\textrm{in}} = a_0 \\cdot x^2$ and we update $a$ use the gradient $\\frac{\\partial \\mathcal{L}^{\\textrm{in}}}{\\partial a_0} = x^2$ and $a_1 = a_0 - \\eta \\, \\frac{\\partial \\mathcal{L}^{\\textrm{in}}}{\\partial a_0} = a_0 - \\eta \\, x^2$. Then we compute the outer loss $\\mathcal{L}^{\\textrm{out}} = a_1 \\cdot x^2$. So the gradient of outer loss to $x$ would be:\n", "\n", "$$\n", "\\begin{split}\n", @@ -73,17 +73,17 @@ "class Net(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", - " self.a = nn.Parameter(torch.tensor(1.), requires_grad=True)\n", - " \n", + " self.a = nn.Parameter(torch.tensor(1.0), requires_grad=True)\n", + "\n", " def forward(self, x):\n", - " return self.a * (x ** 2)" + " return self.a * (x**2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Then we declare the network (parameterized by `a`) and the meta parameter `x`. Do not forget to set flag `requires_grad=True` for `x`." + "Then we declare the network (parameterized by `a`) and the meta-parameter `x`. Do not forget to set flag `requires_grad=True` for `x`." ] }, { @@ -93,20 +93,40 @@ "outputs": [], "source": [ "net = Net()\n", - "x = nn.Parameter(torch.tensor(2.), requires_grad=True)" + "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Next we declare the meta optimizer. The meta optimizer takes as input the network and use method `step` to update the network (parameterized by `a`)." + "Next we declare the meta-optimizer. Here we show two equivalent ways of defining the meta-optimizer. " ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, + "outputs": [], + "source": [ + "# Low-level API\n", + "optim = torchopt.MetaOptimizer(net, torchopt.sgd(lr=1.0))\n", + "\n", + "# High level API\n", + "optim = torchopt.MetaSGD(net, lr=1.0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The meta-optimizer takes the network as input and use method `step` to update the network (parameterized by `a`). Finally, we show how a bi-level process works." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -117,8 +137,6 @@ } ], "source": [ - "optim = torchopt.MetaSGD(net, lr=1.)\n", - "\n", "inner_loss = net(x)\n", "optim.step(inner_loss)\n", "\n", @@ -137,7 +155,7 @@ "source": [ "### 1.1 Track the Gradient of Momentum\n", "\n", - "Note that most modern optimizers involve moment term in the gradient update (basically only SGD with `momentum = 0` does not involve). We provide an option for user to choose whether to also track the meta-gradient through moment term. The default option is `moment_requires_grad=True`." + "Note that most modern optimizers involve moment term in the gradient update (basically only SGD with `momentum=0` does not involve). We provide an option for user to choose whether to also track the meta-gradient through moment term. The default option is `moment_requires_grad=True`." ] }, { @@ -149,19 +167,19 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\n" + "\n" ] }, { "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140393111569088\n\nouter_loss\n ()\n\n\n\n140393111544592\n\nMseLossBackward0\n\n\n\n140393111544592->140393111569088\n\n\n\n\n\n140393111544736\n\nMulBackward0\n\n\n\n140393111544736->140393111544592\n\n\n\n\n\n140396237940576\n\nAddBackward0\n step1.a\n ()\n\n\n\n140396237940576->140393111544736\n\n\n\n\n\n140393111545216\n\nAccumulateGrad\n\n\n\n140393111545216->140396237940576\n\n\n\n\n\n140393111545984\n\nMulBackward0\n\n\n\n140393111545216->140393111545984\n\n\n\n\n\n140393111534464\n\nstep0.a\n ()\n\n\n\n140393111534464->140393111545216\n\n\n\n\n\n140393111544112\n\nMulBackward0\n\n\n\n140393111544112->140396237940576\n\n\n\n\n\n140393111545168\n\nDivBackward0\n\n\n\n140393111545168->140393111544112\n\n\n\n\n\n140393111545408\n\nDivBackward0\n\n\n\n140393111545408->140393111545168\n\n\n\n\n\n140393111545552\n\nAddBackward0\n\n\n\n140393111545552->140393111545408\n\n\n\n\n\n140393111545648\n\nPowBackward0\n\n\n\n140393111545648->140393111545552\n\n\n\n\n\n140393111545744\n\nMulBackward0\n\n\n\n140393111545744->140393111545648\n\n\n\n\n\n140393111546272\n\nPowBackward0\n\n\n\n140393111545744->140393111546272\n\n\n\n\n\n140393111545840\n\nMseLossBackwardBackward0\n\n\n\n140393111545840->140393111545744\n\n\n\n\n\n140393111545984->140393111545840\n\n\n\n\n\n140393111545792\n\nPowBackward0\n\n\n\n140393111545792->140393111545744\n\n\n\n\n\n140393111545792->140393111545984\n\n\n\n\n\n140393111546128\n\nAccumulateGrad\n\n\n\n140393111546128->140393111545792\n\n\n\n\n\n140393111545024\n\nPowBackward0\n\n\n\n140393111546128->140393111545024\n\n\n\n\n\n140393111534624\n\nx\n ()\n\n\n\n140393111534624->140393111546128\n\n\n\n\n\n140393111545360\n\nAddBackward0\n\n\n\n140393111545360->140393111545168\n\n\n\n\n\n140393111545696\n\nSqrtBackward0\n\n\n\n140393111545696->140393111545360\n\n\n\n\n\n140393111545936\n\nAddBackward0\n\n\n\n140393111545936->140393111545696\n\n\n\n\n\n140393111545888\n\nDivBackward0\n\n\n\n140393111545888->140393111545936\n\n\n\n\n\n140393111546176\n\nAddBackward0\n\n\n\n140393111546176->140393111545888\n\n\n\n\n\n140393111546272->140393111546176\n\n\n\n\n\n140393111545024->140393111544736\n\n\n\n\n\n" + "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140447553047184\n\nouter_loss\n()\n\n\n\n140447553041216\n\nMseLossBackward0\n\n\n\n140447553041216->140447553047184\n\n\n\n\n\n140447553042896\n\nMulBackward0\n\n\n\n140447553042896->140447553041216\n\n\n\n\n\n140447553019088\n\nAddBackward0\nstep1.a\n()\n\n\n\n140447553019088->140447553042896\n\n\n\n\n\n140447553041072\n\nAccumulateGrad\n\n\n\n140447553041072->140447553019088\n\n\n\n\n\n140447553043664\n\nMulBackward0\n\n\n\n140447553041072->140447553043664\n\n\n\n\n\n140447553045344\n\nstep0.a\n()\n\n\n\n140447553045344->140447553041072\n\n\n\n\n\n140447553041120\n\nMulBackward0\n\n\n\n140447553041120->140447553019088\n\n\n\n\n\n140447553043040\n\nDivBackward0\n\n\n\n140447553043040->140447553041120\n\n\n\n\n\n140447553043184\n\nDivBackward0\n\n\n\n140447553043184->140447553043040\n\n\n\n\n\n140447553043328\n\nAddBackward0\n\n\n\n140447553043328->140447553043184\n\n\n\n\n\n140447553043424\n\nMulBackward0\n\n\n\n140447553043424->140447553043328\n\n\n\n\n\n140447553043856\n\nAddcmulBackward0\n\n\n\n140447553043424->140447553043856\n\n\n\n\n\n140447553043424->140447553043856\n\n\n\n\n\n140447553043520\n\nMseLossBackwardBackward0\n\n\n\n140447553043520->140447553043424\n\n\n\n\n\n140447553043664->140447553043520\n\n\n\n\n\n140447553043472\n\nPowBackward0\n\n\n\n140447553043472->140447553043424\n\n\n\n\n\n140447553043472->140447553043664\n\n\n\n\n\n140447553043808\n\nAccumulateGrad\n\n\n\n140447553043808->140447553043472\n\n\n\n\n\n140447553041264\n\nPowBackward0\n\n\n\n140447553043808->140447553041264\n\n\n\n\n\n140447553045584\n\nx\n()\n\n\n\n140447553045584->140447553043808\n\n\n\n\n\n140447553043136\n\nAddBackward0\n\n\n\n140447553043136->140447553043040\n\n\n\n\n\n140447553043232\n\nSqrtBackward0\n\n\n\n140447553043232->140447553043136\n\n\n\n\n\n140447553043760\n\nAddBackward0\n\n\n\n140447553043760->140447553043232\n\n\n\n\n\n140447553043904\n\nDivBackward0\n\n\n\n140447553043904->140447553043760\n\n\n\n\n\n140447553043856->140447553043904\n\n\n\n\n\n140447553041264->140447553042896\n\n\n\n\n\n" }, "metadata": {}, "output_type": "display_data" @@ -169,10 +187,10 @@ ], "source": [ "net = Net()\n", - "x = nn.Parameter(torch.tensor(2.), requires_grad=True)\n", - "y = torch.tensor(1.)\n", + "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", + "y = torch.tensor(1.0)\n", "\n", - "optim = torchopt.MetaAdam(net, lr=1., moment_requires_grad=False)\n", + "optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=False)\n", "\n", "net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')\n", "inner_loss = F.mse_loss(net(x), y)\n", @@ -180,7 +198,11 @@ "net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.')\n", "\n", "outer_loss = F.mse_loss(net(x), y)\n", - "display(torchopt.visual.make_dot(outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]))" + "display(\n", + " torchopt.visual.make_dot(\n", + " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", + " )\n", + ")" ] }, { @@ -192,19 +214,19 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\n" + "\n" ] }, { "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140393102737552\n\nouter_loss\n ()\n\n\n\n140393111544400\n\nMseLossBackward0\n\n\n\n140393111544400->140393102737552\n\n\n\n\n\n140393111544304\n\nMulBackward0\n\n\n\n140393111544304->140393111544400\n\n\n\n\n\n140396584753232\n\nAddBackward0\n step1.a\n ()\n\n\n\n140396584753232->140393111544304\n\n\n\n\n\n140393111544016\n\nAccumulateGrad\n\n\n\n140393111544016->140396584753232\n\n\n\n\n\n140393111547280\n\nMulBackward0\n\n\n\n140393111544016->140393111547280\n\n\n\n\n\n140393111570848\n\nstep0.a\n ()\n\n\n\n140393111570848->140393111544016\n\n\n\n\n\n140393111544256\n\nMulBackward0\n\n\n\n140393111544256->140396584753232\n\n\n\n\n\n140393111544160\n\nDivBackward0\n\n\n\n140393111544160->140393111544256\n\n\n\n\n\n140393111546512\n\nDivBackward0\n\n\n\n140393111546512->140393111544160\n\n\n\n\n\n140393111544112\n\nAddBackward0\n\n\n\n140393111544112->140393111546512\n\n\n\n\n\n140393111546368\n\nMulBackward0\n\n\n\n140393111546368->140393111544112\n\n\n\n\n\n140393111547040\n\nAccumulateGrad\n\n\n\n140393111547040->140393111546368\n\n\n\n\n\n140393111569408\n\n ()\n\n\n\n140393111569408->140393111547040\n\n\n\n\n\n140393111546272\n\nPowBackward0\n\n\n\n140393111546272->140393111544112\n\n\n\n\n\n140393111547088\n\nMulBackward0\n\n\n\n140393111547088->140393111546272\n\n\n\n\n\n140393111547328\n\nPowBackward0\n\n\n\n140393111547088->140393111547328\n\n\n\n\n\n140393111547184\n\nMseLossBackwardBackward0\n\n\n\n140393111547184->140393111547088\n\n\n\n\n\n140393111547280->140393111547184\n\n\n\n\n\n140393111546944\n\nPowBackward0\n\n\n\n140393111546944->140393111547088\n\n\n\n\n\n140393111546944->140393111547280\n\n\n\n\n\n140393111546320\n\nAccumulateGrad\n\n\n\n140393111546320->140393111546944\n\n\n\n\n\n140393111544208\n\nPowBackward0\n\n\n\n140393111546320->140393111544208\n\n\n\n\n\n140393111571168\n\nx\n ()\n\n\n\n140393111571168->140393111546320\n\n\n\n\n\n140393111546848\n\nAddBackward0\n\n\n\n140393111546848->140393111544160\n\n\n\n\n\n140393111547136\n\nSqrtBackward0\n\n\n\n140393111547136->140393111546848\n\n\n\n\n\n140393111547232\n\nAddBackward0\n\n\n\n140393111547232->140393111547136\n\n\n\n\n\n140393111545360\n\nDivBackward0\n\n\n\n140393111545360->140393111547232\n\n\n\n\n\n140393111547424\n\nAddBackward0\n\n\n\n140393111547424->140393111545360\n\n\n\n\n\n140393111547520\n\nMulBackward0\n\n\n\n140393111547520->140393111547424\n\n\n\n\n\n140393111547616\n\nAccumulateGrad\n\n\n\n140393111547616->140393111547520\n\n\n\n\n\n140393111570288\n\n ()\n\n\n\n140393111570288->140393111547616\n\n\n\n\n\n140393111547328->140393111547424\n\n\n\n\n\n140393111544208->140393111544304\n\n\n\n\n\n" + "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140447553148704\n\nouter_loss\n()\n\n\n\n140447553041024\n\nMseLossBackward0\n\n\n\n140447553041024->140447553148704\n\n\n\n\n\n140447553043424\n\nMulBackward0\n\n\n\n140447553043424->140447553041024\n\n\n\n\n\n140450536407152\n\nAddBackward0\nstep1.a\n()\n\n\n\n140450536407152->140447553043424\n\n\n\n\n\n140447553041264\n\nAccumulateGrad\n\n\n\n140447553041264->140450536407152\n\n\n\n\n\n140447553019232\n\nMulBackward0\n\n\n\n140447553041264->140447553019232\n\n\n\n\n\n140447553148064\n\nstep0.a\n()\n\n\n\n140447553148064->140447553041264\n\n\n\n\n\n140447553041216\n\nMulBackward0\n\n\n\n140447553041216->140450536407152\n\n\n\n\n\n140447553041312\n\nDivBackward0\n\n\n\n140447553041312->140447553041216\n\n\n\n\n\n140447553041408\n\nDivBackward0\n\n\n\n140447553041408->140447553041312\n\n\n\n\n\n140447553043376\n\nAddBackward0\n\n\n\n140447553043376->140447553041408\n\n\n\n\n\n140447553041168\n\nMulBackward0\n\n\n\n140447553041168->140447553043376\n\n\n\n\n\n140447553042272\n\nAccumulateGrad\n\n\n\n140447553042272->140447553041168\n\n\n\n\n\n140450290826352\n\n()\n\n\n\n140450290826352->140447553042272\n\n\n\n\n\n140447553044432\n\nMulBackward0\n\n\n\n140447553044432->140447553043376\n\n\n\n\n\n140447553018320\n\nAddcmulBackward0\n\n\n\n140447553044432->140447553018320\n\n\n\n\n\n140447553044432->140447553018320\n\n\n\n\n\n140447553042080\n\nMseLossBackwardBackward0\n\n\n\n140447553042080->140447553044432\n\n\n\n\n\n140447553019232->140447553042080\n\n\n\n\n\n140447553019088\n\nPowBackward0\n\n\n\n140447553019088->140447553044432\n\n\n\n\n\n140447553019088->140447553019232\n\n\n\n\n\n140447553018464\n\nAccumulateGrad\n\n\n\n140447553018464->140447553019088\n\n\n\n\n\n140447553043328\n\nPowBackward0\n\n\n\n140447553018464->140447553043328\n\n\n\n\n\n140447553148144\n\nx\n()\n\n\n\n140447553148144->140447553018464\n\n\n\n\n\n140447553041456\n\nAddBackward0\n\n\n\n140447553041456->140447553041312\n\n\n\n\n\n140447553041360\n\nSqrtBackward0\n\n\n\n140447553041360->140447553041456\n\n\n\n\n\n140447553015920\n\nAddBackward0\n\n\n\n140447553015920->140447553041360\n\n\n\n\n\n140447553018560\n\nDivBackward0\n\n\n\n140447553018560->140447553015920\n\n\n\n\n\n140447553018320->140447553018560\n\n\n\n\n\n140447553018272\n\nMulBackward0\n\n\n\n140447553018272->140447553018320\n\n\n\n\n\n140447553018944\n\nAccumulateGrad\n\n\n\n140447553018944->140447553018272\n\n\n\n\n\n140450290824272\n\n()\n\n\n\n140450290824272->140447553018944\n\n\n\n\n\n140447553043328->140447553043424\n\n\n\n\n\n" }, "metadata": {}, "output_type": "display_data" @@ -212,10 +234,10 @@ ], "source": [ "net = Net()\n", - "x = nn.Parameter(torch.tensor(2.), requires_grad=True)\n", - "y = torch.tensor(1.)\n", + "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", + "y = torch.tensor(1.0)\n", "\n", - "optim = torchopt.MetaAdam(net, lr=1., moment_requires_grad=True)\n", + "optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=True)\n", "\n", "net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')\n", "inner_loss = F.mse_loss(net(x), y)\n", @@ -223,14 +245,18 @@ "net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.')\n", "\n", "outer_loss = F.mse_loss(net(x), y)\n", - "display(torchopt.visual.make_dot(outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]))" + "display(\n", + " torchopt.visual.make_dot(\n", + " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", + " )\n", + ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We can see that the additional moment terms are added into the computational graph when we set `moment_requires_grad = True`." + "We can see that the additional moment terms are added into the computational graph when we set `moment_requires_grad=True`." ] }, { @@ -248,36 +274,42 @@ "\n", "We observe that how to reinitialize the inner-loop parameter in a new bi-level process vary in different meta-learning algorithms. For instance, in algorithm like Model-Agnostic Meta-Learning (MAML) ([arXiv:1703.03400](https://arxiv.org/abs/1703.03400)), every time a new task comes, we need to reset the parameters to the initial ones. In other cases such as Meta-Gradient Reinforcement Learning (MGRL) ([arXiv:1805.09801](https://arxiv.org/abs/1805.09801)), the inner-loop network parameter just inherit previous updated parameter to continue the new bi-level process.\n", "\n", - "We provide the `torchopt.extract_state_dict` and `torchopt.recover_state_dict` functions to extract and restore the state of network and optimizer. By default, the extracted state dictionary is a reference (this design is for accumulating gradient of multi-task batch training, MAML for example). You can also set `copy=True` to extract the copy of state dictionary." + "We provide the `torchopt.extract_state_dict` and `torchopt.recover_state_dict` functions to extract and restore the state of network and optimizer. By default, the extracted state dictionary is a reference (this design is for accumulating gradient of multi-task batch training, MAML for example). You can also set `by='copy'` to extract the copy of state dictionary or set `by='deepcopy'` to have a detached copy." ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "a = tensor(-1., grad_fn=)\n", - "a = tensor(-1., grad_fn=)\n" + "a = tensor(-1.0000, grad_fn=)\n", + "a = tensor(-1.0000, grad_fn=)\n" ] } ], "source": [ "net = Net()\n", - "x = nn.Parameter(torch.tensor(2.), requires_grad=True)\n", + "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", "\n", - "optim = torchopt.MetaAdam(net, lr=1.)\n", + "optim = torchopt.MetaAdam(net, lr=1.0)\n", "\n", "# Get the reference of state dictionary\n", - "init_net_state = torchopt.extract_state_dict(net)\n", - "init_optim_state = torchopt.extract_state_dict(optim)\n", + "init_net_state = torchopt.extract_state_dict(net, by='reference')\n", + "init_optim_state = torchopt.extract_state_dict(optim, by='reference')\n", + "# If set `detach_buffers=True`, the parameters are referenced as references while buffers are detached copies\n", + "init_net_state = torchopt.extract_state_dict(net, by='reference', detach_buffers=True)\n", + "\n", + "# Set `copy` to get the copy of state dictionary\n", + "init_net_state_copy = torchopt.extract_state_dict(net, by='copy')\n", + "init_optim_state_copy = torchopt.extract_state_dict(optim, by='copy')\n", "\n", - "# Set `copy=True` to get the copy of state dictionary\n", - "init_net_state_copy = torchopt.extract_state_dict(net, copy=True)\n", - "init_optim_state_copy = torchopt.extract_state_dict(optim, copy=True)\n", + "# Set `deepcopy` to get the detached copy of state dictionary\n", + "init_net_state_deepcopy = torchopt.extract_state_dict(net, by='deepcopy')\n", + "init_optim_state_deepcopy = torchopt.extract_state_dict(optim, by='deepcopy')\n", "\n", "# Conduct 2 inner-loop optimization\n", "for i in range(2):\n", @@ -303,9 +335,9 @@ "source": [ "### 2.2 Multi-task Example with `extract_state_dict` and `recover_state_dict`\n", "\n", - "Let's move to another more complex setting. Meta Learning algorithms always fix network on several different tasks and accumulate outer loss of each task to the meta gradient.\n", + "Let's move to another more complex setting. Meta-Learning algorithms always fix network on several different tasks and accumulate outer loss of each task to the meta-gradient.\n", "\n", - "Assume $x$ is a meta parameter and $a$ is a normal parameter. We firstly update $a$ use inner loss $\\mathcal{L}_1^{\\textrm{in}} = a_0 \\cdot x^2$ to $a_1$. Then we use $a_1$ to compute the outer loss $\\mathcal{L}_1^{\\textrm{out}} = a_1 \\cdot x^2$ and back-propagate it. Then we use $a_0$ to compute the inner loss $\\mathcal{L}_2^{\\textrm{in}} = a_0 \\cdot x$ and update $a_0$ to $a_2 = a_0 - \\eta \\, \\frac{\\partial \\mathcal{L}_2^{\\textrm{in}}}{\\partial a_0} = a_0 - \\eta \\, x$. Then we compute outer loss $\\mathcal{L}_2^{\\textrm{out}} = a_2 \\cdot x$ and back-propagate it. So the accumulated meta gradient would be:\n", + "Assume $x$ is a meta-parameter and $a$ is a normal parameter. We firstly update $a$ use inner loss $\\mathcal{L}_1^{\\textrm{in}} = a_0 \\cdot x^2$ to $a_1$. Then we use $a_1$ to compute the outer loss $\\mathcal{L}_1^{\\textrm{out}} = a_1 \\cdot x^2$ and backpropagate it. Then we use $a_0$ to compute the inner loss $\\mathcal{L}_2^{\\textrm{in}} = a_0 \\cdot x$ and update $a_0$ to $a_2 = a_0 - \\eta \\, \\frac{\\partial \\mathcal{L}_2^{\\textrm{in}}}{\\partial a_0} = a_0 - \\eta \\, x$. Then we compute outer loss $\\mathcal{L}_2^{\\textrm{out}} = a_2 \\cdot x$ and backpropagate it. So the accumulated meta-gradient would be:\n", "\n", "$$\n", "\\begin{split}\n", @@ -328,26 +360,26 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "class Net2Tasks(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", - " self.a = nn.Parameter(torch.tensor(1.), requires_grad=True)\n", - " \n", + " self.a = nn.Parameter(torch.tensor(1.0), requires_grad=True)\n", + "\n", " def task1(self, x):\n", - " return self.a * x ** 2\n", - " \n", + " return self.a * x**2\n", + "\n", " def task2(self, x):\n", " return self.a * x\n", "\n", "\n", "net = Net2Tasks()\n", - "x = nn.Parameter(torch.tensor(2.), requires_grad=True)\n", + "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", "\n", - "optim = torchopt.MetaSGD(net, lr=1.)" + "optim = torchopt.MetaSGD(net, lr=1.0)" ] }, { @@ -359,14 +391,14 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "init_optim_state = ((EmptyState(), EmptyState()),)\n", + "init_optim_state = ((EmptyState(),),)\n", "Task 1: x.grad = tensor(-28.)\n", "Accumulated: x.grad = tensor(-31.)\n" ] @@ -374,8 +406,8 @@ ], "source": [ "# Get the reference of state dictionary\n", - "init_net_state = torchopt.extract_state_dict(net)\n", - "init_optim_state = torchopt.extract_state_dict(optim)\n", + "init_net_state = torchopt.extract_state_dict(net, by='reference')\n", + "init_optim_state = torchopt.extract_state_dict(optim, by='reference')\n", "# The `state_dict` is empty for vanilla SGD optimizer\n", "print(f'init_optim_state = {init_optim_state!r}')\n", "\n", @@ -430,7 +462,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -443,9 +475,12 @@ ], "source": [ "net = Net()\n", - "x = nn.Parameter(torch.tensor(2.), requires_grad=True)\n", + "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", "\n", - "optim_impl = torchopt.combine.chain(torchopt.clip.clip_grad_norm(max_norm=2.), torchopt.sgd(lr=1., moment_requires_grad=True))\n", + "optim_impl = torchopt.combine.chain(\n", + " torchopt.clip.clip_grad_norm(max_norm=2.0),\n", + " torchopt.sgd(lr=1.0, moment_requires_grad=True),\n", + ")\n", "optim = torchopt.MetaOptimizer(net, optim_impl)\n", "\n", "inner_loss = net(x)\n", @@ -465,9 +500,45 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 4. Accelerated Optimizer\n", + "## 4. Learning Rate Scheduler\n", + "\n", + "TorchOpt also provides implementation of learning rate scheduler, which can be used as:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "functional_adam = torchopt.adam(\n", + " lr=torchopt.schedule.linear_schedule(\n", + " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", + " )\n", + ")\n", + "\n", + "adam = torchopt.Adam(\n", + " net.parameters(),\n", + " lr=torchopt.schedule.linear_schedule(\n", + " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", + " ),\n", + ")\n", + "\n", + "meta_adam = torchopt.MetaAdam(\n", + " net,\n", + " lr=torchopt.schedule.linear_schedule(\n", + " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Accelerated Optimizer\n", "\n", - "Users can use accelerated optimizer by setting the `use_accelerated_op` as `True`. Currently we only support the Adam optimizer." + "Users can use accelerated optimizer by setting the `use_accelerated_op=True`. Currently we only support the Adam optimizer." ] }, { @@ -479,7 +550,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -496,7 +567,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -513,19 +584,19 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\n" + "\n" ] }, { "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140393102828544\n\nouter_loss\n ()\n\n\n\n140393111546128\n\nMseLossBackward0\n\n\n\n140393111546128->140393102828544\n\n\n\n\n\n140393111546032\n\nMulBackward0\n\n\n\n140393111546032->140393111546128\n\n\n\n\n\n140396237940288\n\nAddBackward0\n step1.a\n ()\n\n\n\n140396237940288->140393111546032\n\n\n\n\n\n140393111546464\n\nAccumulateGrad\n\n\n\n140393111546464->140396237940288\n\n\n\n\n\n140393102725760\n\nMulBackward0\n\n\n\n140393111546464->140393102725760\n\n\n\n\n\n140393102827744\n\nstep0.a\n ()\n\n\n\n140393102827744->140393111546464\n\n\n\n\n\n140393102725232\n\nMulBackward0\n\n\n\n140393102725232->140396237940288\n\n\n\n\n\n140393112318976\n\nUpdatesOpBackward\n\n\n\n140393112318976->140393102725232\n\n\n\n\n\n140396647894368\n\nMuOpBackward\n\n\n\n140396647894368->140393112318976\n\n\n\n\n\n140393102725472\n\nMulBackward0\n\n\n\n140393102725472->140396647894368\n\n\n\n\n\n140393112318736\n\nNuOpBackward\n\n\n\n140393102725472->140393112318736\n\n\n\n\n\n140393102725616\n\nMseLossBackwardBackward0\n\n\n\n140393102725616->140393102725472\n\n\n\n\n\n140393102725760->140393102725616\n\n\n\n\n\n140393102725568\n\nPowBackward0\n\n\n\n140393102725568->140393102725472\n\n\n\n\n\n140393102725568->140393102725760\n\n\n\n\n\n140393102725904\n\nAccumulateGrad\n\n\n\n140393102725904->140393102725568\n\n\n\n\n\n140393111543968\n\nPowBackward0\n\n\n\n140393102725904->140393111543968\n\n\n\n\n\n140393111485872\n\nx\n ()\n\n\n\n140393111485872->140393102725904\n\n\n\n\n\n140393102725328\n\nAccumulateGrad\n\n\n\n140393102725328->140396647894368\n\n\n\n\n\n140393111534224\n\n ()\n\n\n\n140393111534224->140396647894368\n\n\n\n\n\n140393111534224->140393102725328\n\n\n\n\n\n140393111531904\n\n ()\n\n\n\n140393111531904->140396647894368\n\n\n\n\n\n140393111531904->140393112318736\n\n\n\n\n\n140393112318736->140393112318976\n\n\n\n\n\n140393102725712\n\nAccumulateGrad\n\n\n\n140393102725712->140393112318736\n\n\n\n\n\n140393102827824\n\n ()\n\n\n\n140393102827824->140393112318736\n\n\n\n\n\n140393102827824->140393102725712\n\n\n\n\n\n140393102828784\n\n ()\n\n\n\n140393102828784->140393112318976\n\n\n\n\n\n140393102828144\n\n ()\n\n\n\n140393102828144->140393112318976\n\n\n\n\n\n140393102828224\n\n ()\n\n\n\n140393102828224->140393112318976\n\n\n\n\n\n140393111543968->140393111546032\n\n\n\n\n\n" + "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140450290825712\n\nouter_loss\n()\n\n\n\n140450533650240\n\nMseLossBackward0\n\n\n\n140450533650240->140450290825712\n\n\n\n\n\n140450533648560\n\nMulBackward0\n\n\n\n140450533648560->140450533650240\n\n\n\n\n\n140450533647456\n\nAddBackward0\nstep1.a\n()\n\n\n\n140450533647456->140450533648560\n\n\n\n\n\n140447435136640\n\nAccumulateGrad\n\n\n\n140447435136640->140450533647456\n\n\n\n\n\n140450533648416\n\nMulBackward0\n\n\n\n140447435136640->140450533648416\n\n\n\n\n\n140447435236512\n\nstep0.a\n()\n\n\n\n140447435236512->140447435136640\n\n\n\n\n\n140447435136688\n\nMulBackward0\n\n\n\n140447435136688->140450533647456\n\n\n\n\n\n140447554132144\n\nUpdatesOpBackward\n\n\n\n140447554132144->140447435136688\n\n\n\n\n\n140447554131664\n\nMuOpBackward\n\n\n\n140447554131664->140447554132144\n\n\n\n\n\n140447435134816\n\nMulBackward0\n\n\n\n140447435134816->140447554131664\n\n\n\n\n\n140447554131904\n\nNuOpBackward\n\n\n\n140447435134816->140447554131904\n\n\n\n\n\n140450533648992\n\nMseLossBackwardBackward0\n\n\n\n140450533648992->140447435134816\n\n\n\n\n\n140450533648416->140450533648992\n\n\n\n\n\n140450533646448\n\nPowBackward0\n\n\n\n140450533646448->140447435134816\n\n\n\n\n\n140450533646448->140450533648416\n\n\n\n\n\n140447553018176\n\nAccumulateGrad\n\n\n\n140447553018176->140450533646448\n\n\n\n\n\n140447435135536\n\nPowBackward0\n\n\n\n140447553018176->140447435135536\n\n\n\n\n\n140447553045424\n\nx\n()\n\n\n\n140447553045424->140447553018176\n\n\n\n\n\n140447435136592\n\nAccumulateGrad\n\n\n\n140447435136592->140447554131664\n\n\n\n\n\n140447552973856\n\n()\n\n\n\n140447552973856->140447554131664\n\n\n\n\n\n140447552973856->140447435136592\n\n\n\n\n\n140447553044544\n\n()\n\n\n\n140447553044544->140447554131664\n\n\n\n\n\n140447553044544->140447554131904\n\n\n\n\n\n140447554131904->140447554132144\n\n\n\n\n\n140450533648896\n\nAccumulateGrad\n\n\n\n140450533648896->140447554131904\n\n\n\n\n\n140447435236752\n\n()\n\n\n\n140447435236752->140447554131904\n\n\n\n\n\n140447435236752->140450533648896\n\n\n\n\n\n140447553045904\n\n()\n\n\n\n140447553045904->140447554132144\n\n\n\n\n\n140447435237152\n\n()\n\n\n\n140447435237152->140447554132144\n\n\n\n\n\n140447435237232\n\n()\n\n\n\n140447435237232->140447554132144\n\n\n\n\n\n140447435135536->140450533648560\n\n\n\n\n\n" }, "metadata": {}, "output_type": "display_data" @@ -533,24 +604,89 @@ ], "source": [ "net = Net().to(device='cuda')\n", - "x = nn.Parameter(torch.tensor(2., device=torch.device('cuda')), requires_grad=True)\n", - "y = torch.tensor(1., device=torch.device('cuda'))\n", + "x = nn.Parameter(torch.tensor(2.0, device=torch.device('cuda')), requires_grad=True)\n", + "y = torch.tensor(1.0, device=torch.device('cuda'))\n", "\n", - "optim = torchopt.MetaAdam(net, lr=1., moment_requires_grad=True, use_accelerated_op=True)\n", + "optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=True, use_accelerated_op=True)\n", "\n", - "net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')\n", + "net_state_0 = torchopt.extract_state_dict(\n", + " net, by='reference', enable_visual=True, visual_prefix='step0.'\n", + ")\n", "inner_loss = F.mse_loss(net(x), y)\n", "optim.step(inner_loss)\n", - "net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.')\n", + "net_state_1 = torchopt.extract_state_dict(\n", + " net, by='reference', enable_visual=True, visual_prefix='step1.'\n", + ")\n", "\n", "outer_loss = F.mse_loss(net(x), y)\n", - "display(torchopt.visual.make_dot(outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]))" + "display(\n", + " torchopt.visual.make_dot(\n", + " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Known Issues\n", + "\n", + "Here we record some common issues faced by users when using the meta-optimizer." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**1. Get `NaN` error when using `MetaAdam` or other meta-optimizers.**\n", + "\n", + "The `NaN` error is because of the numerical instability of the `Adam` in meta-learning. There exist an `sqrt` operation in `Adam`'s computation process. Backpropogating through the `Adam` operator introduces the second derivation of the `sqrt` operation, which is not numerical stable, i.e. ${\\left. \\frac{d^2 \\sqrt{x}}{{dx}^2} \\right\\rvert}_{x = 0} = \\texttt{NaN}$. You can also refer to issue [facebookresearch/higher#125](https://github.com/facebookresearch/higher/issues/125).\n", + "\n", + "For this problem, TorchOpt have two recommended solutions.\n", + "\n", + "* Put the `sqrt` operation into the whole equation, and compute the derivation of the output to the input manually. The second derivation of the `sqrt` operation will be eliminated. You can achieve this by setting the flag `use_accelerated_op=True`, you can follow the instructions in notebook [Functional Optimizer](1_Functional_Optimizer.ipynb) and Meta-Optimizer." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "inner_optim = torchopt.MetaAdam(net, lr=1.0, use_accelerated_op=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "* Register hook to the first-order gradients. During the backpropagation, the NaN gradients will be set to 0, which will have a similar effect to the first solution but much slower. " + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "impl = torchopt.chain(torchopt.hook.register_hook(torchopt.hook.zero_nan_hook), torchopt.adam(1e-1))\n", + "inner_optim = torchopt.MetaOptimizer(net, impl)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**2. Get `Trying to backward through the graph a second time` error when conducting multiple meta-optimization.**\n", + "\n", + "Please refer to the tutorial notebook [Stop Gradient](4_Stop_Gradient.ipynb) for more guidances." ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.13 ('torchopt')", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -564,7 +700,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.9.15" }, "vscode": { "interpreter": { diff --git a/tutorials/4_Stop_Gradient.ipynb b/tutorials/4_Stop_Gradient.ipynb index 604196ca..06e6b3c3 100644 --- a/tutorials/4_Stop_Gradient.ipynb +++ b/tutorials/4_Stop_Gradient.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "[](https://colab.research.google.com/drive/1jp_oPHIG6aaQMYGNxG72FSuWjABk1DHo?usp=sharing)" + "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/4_Stop_Gradient.ipynb)" ] }, { @@ -40,10 +40,11 @@ " def __init__(self, dim):\n", " super().__init__()\n", " self.fc = nn.Linear(dim, 1, bias=True)\n", - " \n", + "\n", " def forward(self, x):\n", " return self.fc(x)\n", "\n", + "\n", "loss_fn = F.mse_loss" ] }, @@ -81,7 +82,7 @@ "metadata": {}, "outputs": [], "source": [ - "meta_parameter = nn.Parameter(torch.tensor(1.), requires_grad=True)\n", + "meta_parameter = nn.Parameter(torch.tensor(1.0), requires_grad=True)\n", "\n", "optim = torchopt.MetaSGD(net, lr=1e-1)\n", "meta_optim = torch.optim.Adam([meta_parameter], lr=1e-1)" @@ -103,13 +104,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "inner loss: 0.5540\n", - "\n" + "inner loss: 0.3472\n", + "\n" ] }, { "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n139978828415600\n\ninner_loss\n ()\n\n\n\n139978603488640\n\nMseLossBackward0\n\n\n\n139978603488640->139978828415600\n\n\n\n\n\n139978603489744\n\nAddmmBackward0\n\n\n\n139978603489744->139978603488640\n\n\n\n\n\n139978603490800\n\nAccumulateGrad\n\n\n\n139978603490800->139978603489744\n\n\n\n\n\n139975938634512\n\nstep0.fc.bias\n (1)\n\n\n\n139975938634512->139978603490800\n\n\n\n\n\n139978603490224\n\nTBackward0\n\n\n\n139978603490224->139978603489744\n\n\n\n\n\n139978603490368\n\nAccumulateGrad\n\n\n\n139978603490368->139978603490224\n\n\n\n\n\n139975938634432\n\nstep0.fc.weight\n (1, 16)\n\n\n\n139975938634432->139978603490368\n\n\n\n\n\n" + "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140025091550880\n\ninner_loss\n()\n\n\n\n140028156253184\n\nMseLossBackward0\n\n\n\n140028156253184->140025091550880\n\n\n\n\n\n140028156436736\n\nAddmmBackward0\n\n\n\n140028156436736->140028156253184\n\n\n\n\n\n140025091526416\n\nAccumulateGrad\n\n\n\n140025091526416->140028156436736\n\n\n\n\n\n140028155952000\n\nstep0.fc.bias\n(1)\n\n\n\n140028155952000->140025091526416\n\n\n\n\n\n140025091525408\n\nTBackward0\n\n\n\n140025091525408->140028156436736\n\n\n\n\n\n140025091526224\n\nAccumulateGrad\n\n\n\n140025091526224->140025091525408\n\n\n\n\n\n140028155952880\n\nstep0.fc.weight\n(1, 16)\n\n\n\n140028155952880->140025091526224\n\n\n\n\n\n" }, "metadata": {}, "output_type": "display_data" @@ -122,12 +123,7 @@ "inner_loss = loss_fn(net(x), y)\n", "\n", "print(f'inner loss: {inner_loss:.4f}')\n", - "display(\n", - " torchopt.visual.make_dot(\n", - " inner_loss,\n", - " params=(init_net_state, {'inner_loss': inner_loss})\n", - " )\n", - ")" + "display(torchopt.visual.make_dot(inner_loss, params=(init_net_state, {'inner_loss': inner_loss})))" ] }, { @@ -168,13 +164,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "outer loss: 0.2297\n", - "\n" + "outer loss: 0.2039\n", + "\n" ] }, { "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n139975938634752\n\nouter_loss\n ()\n\n\n\n139975938188288\n\nMseLossBackward0\n\n\n\n139975938188288->139975938634752\n\n\n\n\n\n139975938188336\n\nAddmmBackward0\n\n\n\n139975938188336->139975938188288\n\n\n\n\n\n139975938188096\n\nAddBackward0\n step1.fc.bias\n (1)\n\n\n\n139975938188096->139975938188336\n\n\n\n\n\n139978603490800\n\nAccumulateGrad\n\n\n\n139978603490800->139975938188096\n\n\n\n\n\n139978603489744\n\nAddmmBackward0\n\n\n\n139978603490800->139978603489744\n\n\n\n\n\n139975938634512\n\nstep0.fc.bias\n (1)\n\n\n\n139975938634512->139978603490800\n\n\n\n\n\n139975938188480\n\nMulBackward0\n\n\n\n139975938188480->139975938188096\n\n\n\n\n\n139975938188144\n\nViewBackward0\n\n\n\n139975938188144->139975938188480\n\n\n\n\n\n139975938187664\n\nSumBackward1\n\n\n\n139975938187664->139975938188144\n\n\n\n\n\n139975938188720\n\nMseLossBackwardBackward0\n\n\n\n139975938188720->139975938187664\n\n\n\n\n\n139975938189200\n\nTBackward0\n\n\n\n139975938188720->139975938189200\n\n\n\n\n\n139975938188816\n\nMulBackward0\n\n\n\n139975938188816->139975938188720\n\n\n\n\n\n139975938188912\n\nAccumulateGrad\n\n\n\n139975938188912->139975938188816\n\n\n\n\n\n139975938635072\n\nmeta_parameter\n ()\n\n\n\n139975938635072->139975938188912\n\n\n\n\n\n139978603489744->139975938188720\n\n\n\n\n\n139978603490224\n\nTBackward0\n\n\n\n139978603490224->139978603489744\n\n\n\n\n\n139978603490368\n\nAccumulateGrad\n\n\n\n139978603490368->139978603490224\n\n\n\n\n\n139975938187808\n\nAddBackward0\n step1.fc.weight\n (1, 16)\n\n\n\n139978603490368->139975938187808\n\n\n\n\n\n139975938634432\n\nstep0.fc.weight\n (1, 16)\n\n\n\n139975938634432->139978603490368\n\n\n\n\n\n139975938188384\n\nTBackward0\n\n\n\n139975938188384->139975938188336\n\n\n\n\n\n139975938187808->139975938188384\n\n\n\n\n\n139975938188672\n\nMulBackward0\n\n\n\n139975938188672->139975938187808\n\n\n\n\n\n139975938189008\n\nTBackward0\n\n\n\n139975938189008->139975938188672\n\n\n\n\n\n139975938189104\n\nTBackward0\n\n\n\n139975938189104->139975938189008\n\n\n\n\n\n139975938188864\n\nMmBackward0\n\n\n\n139975938188864->139975938189104\n\n\n\n\n\n139975938189200->139975938188864\n\n\n\n\n\n" + "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140027829238416\n\nouter_loss\n()\n\n\n\n140025091525072\n\nMseLossBackward0\n\n\n\n140025091525072->140027829238416\n\n\n\n\n\n140025091525216\n\nAddmmBackward0\n\n\n\n140025091525216->140025091525072\n\n\n\n\n\n140025091526128\n\nAddBackward0\nstep1.fc.bias\n(1)\n\n\n\n140025091526128->140025091525216\n\n\n\n\n\n140025091526416\n\nAccumulateGrad\n\n\n\n140025091526416->140025091526128\n\n\n\n\n\n140028156436736\n\nAddmmBackward0\n\n\n\n140025091526416->140028156436736\n\n\n\n\n\n140028155952000\n\nstep0.fc.bias\n(1)\n\n\n\n140028155952000->140025091526416\n\n\n\n\n\n140025091524976\n\nMulBackward0\n\n\n\n140025091524976->140025091526128\n\n\n\n\n\n140025091526560\n\nViewBackward0\n\n\n\n140025091526560->140025091524976\n\n\n\n\n\n140025091525456\n\nSumBackward1\n\n\n\n140025091525456->140025091526560\n\n\n\n\n\n140025091524112\n\nMseLossBackwardBackward0\n\n\n\n140025091524112->140025091525456\n\n\n\n\n\n140024973742672\n\nTBackward0\n\n\n\n140025091524112->140024973742672\n\n\n\n\n\n140024973742288\n\nMulBackward0\n\n\n\n140024973742288->140025091524112\n\n\n\n\n\n140024973742384\n\nAccumulateGrad\n\n\n\n140024973742384->140024973742288\n\n\n\n\n\n140025091549440\n\nmeta_parameter\n()\n\n\n\n140025091549440->140024973742384\n\n\n\n\n\n140028156436736->140025091524112\n\n\n\n\n\n140025091525408\n\nTBackward0\n\n\n\n140025091525408->140028156436736\n\n\n\n\n\n140025091526224\n\nAccumulateGrad\n\n\n\n140025091526224->140025091525408\n\n\n\n\n\n140025091524928\n\nAddBackward0\nstep1.fc.weight\n(1, 16)\n\n\n\n140025091526224->140025091524928\n\n\n\n\n\n140028155952880\n\nstep0.fc.weight\n(1, 16)\n\n\n\n140028155952880->140025091526224\n\n\n\n\n\n140025091524448\n\nTBackward0\n\n\n\n140025091524448->140025091525216\n\n\n\n\n\n140025091524928->140025091524448\n\n\n\n\n\n140025091525600\n\nMulBackward0\n\n\n\n140025091525600->140025091524928\n\n\n\n\n\n140024973742144\n\nTBackward0\n\n\n\n140024973742144->140025091525600\n\n\n\n\n\n140024973742576\n\nTBackward0\n\n\n\n140024973742576->140024973742144\n\n\n\n\n\n140024973742480\n\nMmBackward0\n\n\n\n140024973742480->140024973742576\n\n\n\n\n\n140024973742672->140024973742480\n\n\n\n\n\n" }, "metadata": {}, "output_type": "display_data" @@ -191,7 +187,11 @@ "display(\n", " torchopt.visual.make_dot(\n", " outer_loss,\n", - " params=(init_net_state, one_step_net_state, {'meta_parameter': meta_parameter, 'outer_loss': outer_loss})\n", + " params=(\n", + " init_net_state,\n", + " one_step_net_state,\n", + " {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n", + " ),\n", " )\n", ")" ] @@ -200,7 +200,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Then we backward the loss to conduct outer-loop meta optimization." + "Then we backward the loss to conduct outer-loop meta-optimization." ] }, { @@ -212,8 +212,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "meta_parameter.grad = tensor(-0.2464)\n", - "meta_parameter = Parameter containing: tensor(1.1000, requires_grad=True)\n" + "meta_parameter.grad = tensor(-0.1205)\n", + "meta_parameter = Parameter containing:\n", + "tensor(1.1000, requires_grad=True)\n" ] } ], @@ -236,11 +237,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In general, the back-propagation only frees saved tensors (often used as auxiliary data for computing the gradient) but the computation graph remains. Once the outer iteration is finished, if you want to use any intermediate network parameters produced by the inner loop for the next bi-level iteration, you should detach them from the computation graph.\n", + "In general, the backpropagation only frees saved tensors (often used as auxiliary data for computing the gradient) but the computation graph remains. Once the outer iteration is finished, if you want to use any intermediate network parameters produced by the inner loop for the next bi-level iteration, you should detach them from the computation graph.\n", "\n", "There are two main reasons:\n", "\n", - "- The network parameters are still connected to the previous computation graph (`.grad_fn` is not `None`). If later the gradient back-propagate to these parameters, the PyTorch backward engine will try to back-propagate through the previous computation graph. This will raise a `RuntimeError`: Trying to backward through the graph a second time...\n", + "- The network parameters are still connected to the previous computation graph (`.grad_fn` is not `None`). If later the gradient backpropagate to these parameters, the PyTorch backward engine will try to backpropagate through the previous computation graph. This will raise a `RuntimeError`: Trying to backward through the graph a second time...\n", "- If we do not detach the computation graph, the computation graph connected to these parameters can not be freed by GC (Garbage Collector) until these parameters are collected by GC." ] }, @@ -260,12 +261,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "\n" + "\n" ] }, { "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n139978828415600\n\nouter_loss\n ()\n\n\n\n139975938626944\n\nMseLossBackward0\n\n\n\n139975938626944->139978828415600\n\n\n\n\n\n139975938626656\n\nAddmmBackward0\n\n\n\n139975938626656->139975938626944\n\n\n\n\n\n139975938188624\n\nAddBackward0\n\n\n\n139975938188624->139975938626656\n\n\n\n\n\n139975938188096\n\nAddBackward0\n step1.fc.bias\n (1)\n\n\n\n139975938188096->139975938188624\n\n\n\n\n\n139975938188144\n\nAddmmBackward0\n\n\n\n139975938188096->139975938188144\n\n\n\n\n\n139975938187424\n\nAccumulateGrad\n\n\n\n139975938187424->139975938188096\n\n\n\n\n\n139975938188912\n\nAddmmBackward0\n\n\n\n139975938187424->139975938188912\n\n\n\n\n\n139975938634512\n\nstep0.fc.bias\n (1)\n\n\n\n139975938634512->139975938187424\n\n\n\n\n\n139975938187856\n\nMulBackward0\n\n\n\n139975938187856->139975938188096\n\n\n\n\n\n139975938188768\n\nViewBackward0\n\n\n\n139975938188768->139975938187856\n\n\n\n\n\n139975938189200\n\nSumBackward1\n\n\n\n139975938189200->139975938188768\n\n\n\n\n\n139975938189008\n\nMseLossBackwardBackward0\n\n\n\n139975938189008->139975938189200\n\n\n\n\n\n139975938189728\n\nTBackward0\n\n\n\n139975938189008->139975938189728\n\n\n\n\n\n139975938188864\n\nMulBackward0\n\n\n\n139975938188864->139975938189008\n\n\n\n\n\n139975938187952\n\nAccumulateGrad\n\n\n\n139975938187952->139975938188864\n\n\n\n\n\n139975938187712\n\nMulBackward0\n\n\n\n139975938187952->139975938187712\n\n\n\n\n\n139975938635072\n\nmeta_parameter\n ()\n\n\n\n139975938635072->139975938187952\n\n\n\n\n\n139975938188912->139975938189008\n\n\n\n\n\n139975938188480\n\nTBackward0\n\n\n\n139975938188480->139975938188912\n\n\n\n\n\n139975938188384\n\nAccumulateGrad\n\n\n\n139975938188384->139975938188480\n\n\n\n\n\n139975938187808\n\nAddBackward0\n step1.fc.weight\n (1, 16)\n\n\n\n139975938188384->139975938187808\n\n\n\n\n\n139975938634432\n\nstep0.fc.weight\n (1, 16)\n\n\n\n139975938634432->139975938188384\n\n\n\n\n\n139975938187520\n\nMulBackward0\n\n\n\n139975938187520->139975938188624\n\n\n\n\n\n139975938189296\n\nViewBackward0\n\n\n\n139975938189296->139975938187520\n\n\n\n\n\n139975938188576\n\nSumBackward1\n\n\n\n139975938188576->139975938189296\n\n\n\n\n\n139975938188720\n\nMseLossBackwardBackward0\n\n\n\n139975938188720->139975938188576\n\n\n\n\n\n139975938189824\n\nTBackward0\n\n\n\n139975938188720->139975938189824\n\n\n\n\n\n139975938187712->139975938188720\n\n\n\n\n\n139975938188144->139975938188720\n\n\n\n\n\n139975938188816\n\nTBackward0\n\n\n\n139975938188816->139975938188144\n\n\n\n\n\n139975938187808->139975938188816\n\n\n\n\n\n139975938189104\n\nAddBackward0\n\n\n\n139975938187808->139975938189104\n\n\n\n\n\n139975938189248\n\nMulBackward0\n\n\n\n139975938189248->139975938187808\n\n\n\n\n\n139975938189344\n\nTBackward0\n\n\n\n139975938189344->139975938189248\n\n\n\n\n\n139975938189536\n\nTBackward0\n\n\n\n139975938189536->139975938189344\n\n\n\n\n\n139975938189440\n\nMmBackward0\n\n\n\n139975938189440->139975938189536\n\n\n\n\n\n139975938189728->139975938189440\n\n\n\n\n\n139975938187904\n\nTBackward0\n\n\n\n139975938187904->139975938626656\n\n\n\n\n\n139975938189104->139975938187904\n\n\n\n\n\n139975938188240\n\nMulBackward0\n\n\n\n139975938188240->139975938189104\n\n\n\n\n\n139975938188048\n\nTBackward0\n\n\n\n139975938188048->139975938188240\n\n\n\n\n\n139975938188528\n\nTBackward0\n\n\n\n139975938188528->139975938188048\n\n\n\n\n\n139975938189584\n\nMmBackward0\n\n\n\n139975938189584->139975938188528\n\n\n\n\n\n139975938189824->139975938189584\n\n\n\n\n\n" + "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140024973755152\n\nouter_loss\n()\n\n\n\n140027829363232\n\nMseLossBackward0\n\n\n\n140027829363232->140024973755152\n\n\n\n\n\n140027829363616\n\nAddmmBackward0\n\n\n\n140027829363616->140027829363232\n\n\n\n\n\n140027829366544\n\nAddBackward0\n\n\n\n140027829366544->140027829363616\n\n\n\n\n\n140025091526128\n\nAddBackward0\nstep1.fc.bias\n(1)\n\n\n\n140025091526128->140027829366544\n\n\n\n\n\n140025091725152\n\nAddmmBackward0\n\n\n\n140025091526128->140025091725152\n\n\n\n\n\n140025091526416\n\nAccumulateGrad\n\n\n\n140025091526416->140025091526128\n\n\n\n\n\n140028156436736\n\nAddmmBackward0\n\n\n\n140025091526416->140028156436736\n\n\n\n\n\n140028155952000\n\nstep0.fc.bias\n(1)\n\n\n\n140028155952000->140025091526416\n\n\n\n\n\n140025091524976\n\nMulBackward0\n\n\n\n140025091524976->140025091526128\n\n\n\n\n\n140025091526560\n\nViewBackward0\n\n\n\n140025091526560->140025091524976\n\n\n\n\n\n140025091525456\n\nSumBackward1\n\n\n\n140025091525456->140025091526560\n\n\n\n\n\n140025091524112\n\nMseLossBackwardBackward0\n\n\n\n140025091524112->140025091525456\n\n\n\n\n\n140024973742672\n\nTBackward0\n\n\n\n140025091524112->140024973742672\n\n\n\n\n\n140024973742288\n\nMulBackward0\n\n\n\n140024973742288->140025091524112\n\n\n\n\n\n140024973742384\n\nAccumulateGrad\n\n\n\n140024973742384->140024973742288\n\n\n\n\n\n140025091726064\n\nMulBackward0\n\n\n\n140024973742384->140025091726064\n\n\n\n\n\n140025091549440\n\nmeta_parameter\n()\n\n\n\n140025091549440->140024973742384\n\n\n\n\n\n140028156436736->140025091524112\n\n\n\n\n\n140025091525408\n\nTBackward0\n\n\n\n140025091525408->140028156436736\n\n\n\n\n\n140025091526224\n\nAccumulateGrad\n\n\n\n140025091526224->140025091525408\n\n\n\n\n\n140025091524928\n\nAddBackward0\nstep1.fc.weight\n(1, 16)\n\n\n\n140025091526224->140025091524928\n\n\n\n\n\n140028155952880\n\nstep0.fc.weight\n(1, 16)\n\n\n\n140028155952880->140025091526224\n\n\n\n\n\n140025091726784\n\nMulBackward0\n\n\n\n140025091726784->140027829366544\n\n\n\n\n\n140025091726688\n\nViewBackward0\n\n\n\n140025091726688->140025091726784\n\n\n\n\n\n140025091725680\n\nSumBackward1\n\n\n\n140025091725680->140025091726688\n\n\n\n\n\n140025091726112\n\nMseLossBackwardBackward0\n\n\n\n140025091726112->140025091725680\n\n\n\n\n\n140025091726880\n\nTBackward0\n\n\n\n140025091726112->140025091726880\n\n\n\n\n\n140025091726064->140025091726112\n\n\n\n\n\n140025091725152->140025091726112\n\n\n\n\n\n140025091725824\n\nTBackward0\n\n\n\n140025091725824->140025091725152\n\n\n\n\n\n140025091524928->140025091725824\n\n\n\n\n\n140025091726016\n\nAddBackward0\n\n\n\n140025091524928->140025091726016\n\n\n\n\n\n140025091525600\n\nMulBackward0\n\n\n\n140025091525600->140025091524928\n\n\n\n\n\n140024973742144\n\nTBackward0\n\n\n\n140024973742144->140025091525600\n\n\n\n\n\n140024973742576\n\nTBackward0\n\n\n\n140024973742576->140024973742144\n\n\n\n\n\n140024973742480\n\nMmBackward0\n\n\n\n140024973742480->140024973742576\n\n\n\n\n\n140024973742672->140024973742480\n\n\n\n\n\n140027829365632\n\nTBackward0\n\n\n\n140027829365632->140027829363616\n\n\n\n\n\n140025091726016->140027829365632\n\n\n\n\n\n140025091726544\n\nMulBackward0\n\n\n\n140025091726544->140025091726016\n\n\n\n\n\n140025091726448\n\nTBackward0\n\n\n\n140025091726448->140025091726544\n\n\n\n\n\n140025091725584\n\nTBackward0\n\n\n\n140025091725584->140025091726448\n\n\n\n\n\n140025091727024\n\nMmBackward0\n\n\n\n140025091727024->140025091725584\n\n\n\n\n\n140025091726880->140025091727024\n\n\n\n\n\n" }, "metadata": {}, "output_type": "display_data" @@ -273,67 +274,103 @@ { "data": { "text/html": [ - "
╭──────────────────────────── Traceback (most recent call last) ────────────────────────────╮\n",
-       " <ipython-input-8-5906690e2182>:17 in <cell line: 17>                                      \n",
-       " /home/TorchOpt/Miniconda3/envs/torchopt/lib/python3.8/site-packages/torch/_tensor.py:396  \n",
-       " in backward                                                                               \n",
-       "                                                                                           \n",
-       "    393 │   │   │   │   retain_graph=retain_graph,                                         \n",
-       "    394 │   │   │   │   create_graph=create_graph,                                         \n",
-       "    395 │   │   │   │   inputs=inputs)                                                     \n",
-       "  396 │   │   torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs \n",
-       "    397 │                                                                                  \n",
-       "    398 │   def register_hook(self, hook):                                                 \n",
-       "    399 │   │   r\"\"\"Registers a backward hook.                                             \n",
-       "                                                                                           \n",
-       " /home/TorchOpt/Miniconda3/envs/torchopt/lib/python3.8/site-packages/torch/autograd/__init \n",
-       " __.py:173 in backward                                                                     \n",
-       "                                                                                           \n",
-       "   170 │   # The reason we repeat same the comment below is that                           \n",
-       "   171 │   # some Python versions print out the first line of a multi-line function        \n",
-       "   172 │   # calls in the traceback and some print out the last line                       \n",
-       " 173 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run th \n",
-       "   174 │   │   tensors, grad_tensors_, retain_graph, create_graph, inputs,                 \n",
-       "   175 │   │   allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine  \n",
-       "   176                                                                                     \n",
-       "╰───────────────────────────────────────────────────────────────────────────────────────────╯\n",
-       "RuntimeError: Trying to backward through the graph a second time (or directly access saved \n",
-       "tensors after they have already been freed). Saved intermediate values of the graph are freed\n",
-       "when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to \n",
-       "backward through the graph a second time or if you need to access saved tensors after calling\n",
-       "backward.\n",
+       "
╭─────────────────────────────────────── Traceback (most recent call last) ───────────────────────────────────────╮\n",
+       " /tmp/ipykernel_3962266/4178930003.py:21 in <module>                                                             \n",
+       "                                                                                                                 \n",
+       " [Errno 2] No such file or directory: '/tmp/ipykernel_3962266/4178930003.py'                                     \n",
+       "                                                                                                                 \n",
+       " /home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/torch/_tensor.py:487 in backward           \n",
+       "                                                                                                                 \n",
+       "    484 │   │   │   │   create_graph=create_graph,                                                               \n",
+       "    485 │   │   │   │   inputs=inputs,                                                                           \n",
+       "    486 │   │   │   )                                                                                            \n",
+       "  487 │   │   torch.autograd.backward(                                                                         \n",
+       "    488 │   │   │   self, gradient, retain_graph, create_graph, inputs=inputs                                    \n",
+       "    489 │   │   )                                                                                                \n",
+       "    490                                                                                                          \n",
+       "                                                                                                                 \n",
+       " ╭───────────────────────── locals ──────────────────────────╮                                                   \n",
+       "  create_graph = False                                                                                         \n",
+       "      gradient = None                                                                                          \n",
+       "        inputs = None                                                                                          \n",
+       "  retain_graph = None                                                                                          \n",
+       "          self = tensor(0.1203, grad_fn=<MseLossBackward0>)                                                    \n",
+       " ╰───────────────────────────────────────────────────────────╯                                                   \n",
+       "                                                                                                                 \n",
+       " /home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/torch/autograd/__init__.py:197 in backward \n",
+       "                                                                                                                 \n",
+       "   194 │   # The reason we repeat same the comment below is that                                                 \n",
+       "   195 │   # some Python versions print out the first line of a multi-line function                              \n",
+       "   196 │   # calls in the traceback and some print out the last line                                             \n",
+       " 197 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the ba                   \n",
+       "   198 │   │   tensors, grad_tensors_, retain_graph, create_graph, inputs,                                       \n",
+       "   199 │   │   allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to r                   \n",
+       "   200                                                                                                           \n",
+       "                                                                                                                 \n",
+       " ╭──────────────────────────── locals ────────────────────────────╮                                              \n",
+       "    create_graph = False                                                                                       \n",
+       "    grad_tensors = None                                                                                        \n",
+       "   grad_tensors_ = (tensor(1.),)                                                                               \n",
+       "  grad_variables = None                                                                                        \n",
+       "          inputs = ()                                                                                          \n",
+       "    retain_graph = False                                                                                       \n",
+       "         tensors = (tensor(0.1203, grad_fn=<MseLossBackward0>),)                                               \n",
+       " ╰────────────────────────────────────────────────────────────────╯                                              \n",
+       "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
+       "RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have \n",
+       "already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad().\n",
+       "Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved \n",
+       "tensors after calling backward.\n",
        "
\n" ], "text/plain": [ - "\u001b[91m╭─\u001b[0m\u001b[91m─────────────────────────── \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[91m ───────────────────────────\u001b[0m\u001b[91m─╮\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[33m\u001b[0m:\u001b[94m17\u001b[0m in \u001b[92m\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2;33m/home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.8/site-packages/torch/\u001b[0m\u001b[1;33m_tensor.py\u001b[0m:\u001b[94m396\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m in \u001b[92mbackward\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m 393 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mretain_graph=retain_graph, \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m 394 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mcreate_graph=create_graph, \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m 395 \u001b[0m\u001b[2m│ │ │ │ \u001b[0minputs=inputs) \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[31m❱ \u001b[0m 396 \u001b[2m│ │ \u001b[0mtorch.autograd.backward(\u001b[96mself\u001b[0m, gradient, retain_graph, create_graph, inputs \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m 397 \u001b[0m\u001b[2m│ \u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m 398 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mregister_hook\u001b[0m(\u001b[96mself\u001b[0m, hook): \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m 399 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[33mr\u001b[0m\u001b[33m\"\"\"Registers a backward hook.\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2;33m/home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.8/site-packages/torch/autograd/\u001b[0m\u001b[1;33m__ini\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[1;33mt__.py\u001b[0m:\u001b[94m173\u001b[0m in \u001b[92mbackward\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m170 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# The reason we repeat same the comment below is that\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m171 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# some Python versions print out the first line of a multi-line function\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m172 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# calls in the traceback and some print out the last line\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[31m❱ \u001b[0m173 \u001b[2m│ \u001b[0mVariable._execution_engine.run_backward( \u001b[2m# Calls into the C++ engine to run th\u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m174 \u001b[0m\u001b[2m│ │ \u001b[0mtensors, grad_tensors_, retain_graph, create_graph, inputs, \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m175 \u001b[0m\u001b[2m│ │ \u001b[0mallow_unreachable=\u001b[94mTrue\u001b[0m, accumulate_grad=\u001b[94mTrue\u001b[0m) \u001b[2m# Calls into the C++ engine \u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m│\u001b[0m \u001b[2m176 \u001b[0m \u001b[91m│\u001b[0m\n", - "\u001b[91m╰───────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", - "\u001b[1;91mRuntimeError: \u001b[0mTrying to backward through the graph a second time \u001b[1m(\u001b[0mor directly access saved \n", - "tensors after they have already been freed\u001b[1m)\u001b[0m. Saved intermediate values of the graph are freed\n", - "when you call \u001b[1;35m.backward\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m or \u001b[1;35mautograd.grad\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m. Specify \u001b[33mretain_graph\u001b[0m=\u001b[3;92mTrue\u001b[0m if you need to \n", - "backward through the graph a second time or if you need to access saved tensors after calling\n", - "backward.\n" + "\u001b[31m╭─\u001b[0m\u001b[31m────────────────────────────────────── \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m ──────────────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/tmp/ipykernel_3962266/\u001b[0m\u001b[1;33m4178930003.py\u001b[0m:\u001b[94m21\u001b[0m in \u001b[92m\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[3;31m[Errno 2] No such file or directory: '/tmp/ipykernel_3962266/4178930003.py'\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/torch/\u001b[0m\u001b[1;33m_tensor.py\u001b[0m:\u001b[94m487\u001b[0m in \u001b[92mbackward\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 484 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mcreate_graph=create_graph, \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 485 \u001b[0m\u001b[2m│ │ │ │ \u001b[0minputs=inputs, \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 486 \u001b[0m\u001b[2m│ │ │ \u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 487 \u001b[2m│ │ \u001b[0mtorch.autograd.backward( \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 488 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[96mself\u001b[0m, gradient, retain_graph, create_graph, inputs=inputs \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 489 \u001b[0m\u001b[2m│ │ \u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 490 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m╭─\u001b[0m\u001b[33m──────────────────────── locals ─────────────────────────\u001b[0m\u001b[33m─╮\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m create_graph = \u001b[94mFalse\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m gradient = \u001b[94mNone\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m inputs = \u001b[94mNone\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m retain_graph = \u001b[94mNone\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m self = \u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[94m0.1203\u001b[0m, \u001b[33mgrad_fn\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mMseLossBackward0\u001b[0m\u001b[1m>\u001b[0m\u001b[1m)\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m╰───────────────────────────────────────────────────────────╯\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/torch/autograd/\u001b[0m\u001b[1;33m__init__.py\u001b[0m:\u001b[94m197\u001b[0m in \u001b[92mbackward\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m194 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# The reason we repeat same the comment below is that\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m195 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# some Python versions print out the first line of a multi-line function\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m196 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# calls in the traceback and some print out the last line\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m197 \u001b[2m│ \u001b[0mVariable._execution_engine.run_backward( \u001b[2m# Calls into the C++ engine to run the ba\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m198 \u001b[0m\u001b[2m│ │ \u001b[0mtensors, grad_tensors_, retain_graph, create_graph, inputs, \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m199 \u001b[0m\u001b[2m│ │ \u001b[0mallow_unreachable=\u001b[94mTrue\u001b[0m, accumulate_grad=\u001b[94mTrue\u001b[0m) \u001b[2m# Calls into the C++ engine to r\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m200 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m╭─\u001b[0m\u001b[33m─────────────────────────── locals ───────────────────────────\u001b[0m\u001b[33m─╮\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m create_graph = \u001b[94mFalse\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m grad_tensors = \u001b[94mNone\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m grad_tensors_ = \u001b[1m(\u001b[0m\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[94m1\u001b[0m.\u001b[1m)\u001b[0m,\u001b[1m)\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m grad_variables = \u001b[94mNone\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m inputs = \u001b[1m(\u001b[0m\u001b[1m)\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m retain_graph = \u001b[94mFalse\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m tensors = \u001b[1m(\u001b[0m\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[94m0.1203\u001b[0m, \u001b[33mgrad_fn\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mMseLossBackward0\u001b[0m\u001b[1m>\u001b[0m\u001b[1m)\u001b[0m,\u001b[1m)\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m╰────────────────────────────────────────────────────────────────╯\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", + "\u001b[1;91mRuntimeError: \u001b[0mTrying to backward through the graph a second time \u001b[1m(\u001b[0mor directly access saved tensors after they have \n", + "already been freed\u001b[1m)\u001b[0m. Saved intermediate values of the graph are freed when you call \u001b[1;35m.backward\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m or \u001b[1;35mautograd.grad\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m.\n", + "Specify \u001b[33mretain_graph\u001b[0m=\u001b[3;92mTrue\u001b[0m if you need to backward through the graph a second time or if you need to access saved \n", + "tensors after calling backward.\n" ] }, "metadata": {}, @@ -351,7 +388,11 @@ "display(\n", " torchopt.visual.make_dot(\n", " outer_loss,\n", - " params=(init_net_state, one_step_net_state, {'meta_parameter': meta_parameter, 'outer_loss': outer_loss})\n", + " params=(\n", + " init_net_state,\n", + " one_step_net_state,\n", + " {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n", + " ),\n", " )\n", ")\n", "\n", @@ -397,14 +438,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "meta_parameter.grad = tensor(-0.0914)\n", - "meta_parameter = Parameter containing: tensor(1.1887, requires_grad=True)\n", - "\n" + "meta_parameter.grad = tensor(-0.0635)\n", + "meta_parameter = Parameter containing:\n", + "tensor(1.1940, requires_grad=True)\n", + "\n" ] }, { "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n139975938621248\n\nouter_loss\n ()\n\n\n\n139975251126352\n\nMseLossBackward0\n\n\n\n139975251126352->139975938621248\n\n\n\n\n\n139975251126592\n\nAddmmBackward0\n\n\n\n139975251126592->139975251126352\n\n\n\n\n\n139975251125920\n\nAddBackward0\n\n\n\n139975251125920->139975251126592\n\n\n\n\n\n139975251126400\n\nAccumulateGrad\n\n\n\n139975251126400->139975251125920\n\n\n\n\n\n139975251127120\n\nAddmmBackward0\n\n\n\n139975251126400->139975251127120\n\n\n\n\n\n139975938636032\n\nstep1.detached.fc.bias\n (1)\n\n\n\n139975938636032->139975251126400\n\n\n\n\n\n139975251126304\n\nMulBackward0\n\n\n\n139975251126304->139975251125920\n\n\n\n\n\n139975251127072\n\nViewBackward0\n\n\n\n139975251127072->139975251126304\n\n\n\n\n\n139975251128080\n\nSumBackward1\n\n\n\n139975251128080->139975251127072\n\n\n\n\n\n139975251126448\n\nMseLossBackwardBackward0\n\n\n\n139975251126448->139975251128080\n\n\n\n\n\n139975251127456\n\nTBackward0\n\n\n\n139975251126448->139975251127456\n\n\n\n\n\n139975251127312\n\nMulBackward0\n\n\n\n139975251127312->139975251126448\n\n\n\n\n\n139975251126016\n\nAccumulateGrad\n\n\n\n139975251126016->139975251127312\n\n\n\n\n\n139975938635072\n\nmeta_parameter\n ()\n\n\n\n139975938635072->139975251126016\n\n\n\n\n\n139975251127120->139975251126448\n\n\n\n\n\n139975251126880\n\nTBackward0\n\n\n\n139975251126880->139975251127120\n\n\n\n\n\n139975251126544\n\nAccumulateGrad\n\n\n\n139975251126544->139975251126880\n\n\n\n\n\n139975251128272\n\nAddBackward0\n\n\n\n139975251126544->139975251128272\n\n\n\n\n\n139975938635552\n\nstep1.detached.fc.weight\n (1, 16)\n\n\n\n139975938635552->139975251126544\n\n\n\n\n\n139975251126256\n\nTBackward0\n\n\n\n139975251126256->139975251126592\n\n\n\n\n\n139975251128272->139975251126256\n\n\n\n\n\n139975251127744\n\nMulBackward0\n\n\n\n139975251127744->139975251128272\n\n\n\n\n\n139975251126112\n\nTBackward0\n\n\n\n139975251126112->139975251127744\n\n\n\n\n\n139975251126640\n\nTBackward0\n\n\n\n139975251126640->139975251126112\n\n\n\n\n\n139975251126976\n\nMmBackward0\n\n\n\n139975251126976->139975251126640\n\n\n\n\n\n139975251127456->139975251126976\n\n\n\n\n\n" + "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140024973754912\n\nouter_loss\n()\n\n\n\n140024956770528\n\nMseLossBackward0\n\n\n\n140024956770528->140024973754912\n\n\n\n\n\n140024956772112\n\nAddmmBackward0\n\n\n\n140024956772112->140024956770528\n\n\n\n\n\n140024956770720\n\nAddBackward0\n\n\n\n140024956770720->140024956772112\n\n\n\n\n\n140024962101312\n\nAccumulateGrad\n\n\n\n140024962101312->140024956770720\n\n\n\n\n\n140024973745552\n\nAddmmBackward0\n\n\n\n140024962101312->140024973745552\n\n\n\n\n\n140025091547520\n\nstep1.detached.fc.bias\n(1)\n\n\n\n140025091547520->140024962101312\n\n\n\n\n\n140024971586864\n\nMulBackward0\n\n\n\n140024971586864->140024956770720\n\n\n\n\n\n140024973742528\n\nViewBackward0\n\n\n\n140024973742528->140024971586864\n\n\n\n\n\n140024973743968\n\nSumBackward1\n\n\n\n140024973743968->140024973742528\n\n\n\n\n\n140024973742768\n\nMseLossBackwardBackward0\n\n\n\n140024973742768->140024973743968\n\n\n\n\n\n140024973744400\n\nTBackward0\n\n\n\n140024973742768->140024973744400\n\n\n\n\n\n140024973744688\n\nMulBackward0\n\n\n\n140024973744688->140024973742768\n\n\n\n\n\n140024973745264\n\nAccumulateGrad\n\n\n\n140024973745264->140024973744688\n\n\n\n\n\n140025091549440\n\nmeta_parameter\n()\n\n\n\n140025091549440->140024973745264\n\n\n\n\n\n140024973745552->140024973742768\n\n\n\n\n\n140024973745168\n\nTBackward0\n\n\n\n140024973745168->140024973745552\n\n\n\n\n\n140024973744256\n\nAccumulateGrad\n\n\n\n140024973744256->140024973745168\n\n\n\n\n\n140024973745984\n\nAddBackward0\n\n\n\n140024973744256->140024973745984\n\n\n\n\n\n140027828983424\n\nstep1.detached.fc.weight\n(1, 16)\n\n\n\n140027828983424->140024973744256\n\n\n\n\n\n140024956771632\n\nTBackward0\n\n\n\n140024956771632->140024956772112\n\n\n\n\n\n140024973745984->140024956771632\n\n\n\n\n\n140024973743728\n\nMulBackward0\n\n\n\n140024973743728->140024973745984\n\n\n\n\n\n140024973743344\n\nTBackward0\n\n\n\n140024973743344->140024973743728\n\n\n\n\n\n140024973745312\n\nTBackward0\n\n\n\n140024973745312->140024973743344\n\n\n\n\n\n140024973743200\n\nMmBackward0\n\n\n\n140024973743200->140024973745312\n\n\n\n\n\n140024973744400->140024973743200\n\n\n\n\n\n" }, "metadata": {}, "output_type": "display_data" @@ -414,7 +456,9 @@ "# Stop gradient and make them become the leaf node\n", "torchopt.stop_gradient(net)\n", "torchopt.stop_gradient(optim)\n", - "one_step_net_state_detached = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.detached.')\n", + "one_step_net_state_detached = torchopt.extract_state_dict(\n", + " net, enable_visual=True, visual_prefix='step1.detached.'\n", + ")\n", "\n", "# Inner update\n", "inner_loss = loss_fn(net(x), y)\n", @@ -432,7 +476,10 @@ "display(\n", " torchopt.visual.make_dot(\n", " outer_loss,\n", - " params=(one_step_net_state_detached, {'meta_parameter': meta_parameter, 'outer_loss': outer_loss})\n", + " params=(\n", + " one_step_net_state_detached,\n", + " {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n", + " ),\n", " )\n", ")" ] @@ -447,7 +494,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.13 ('torchopt')", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -461,7 +508,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.9.15" }, "vscode": { "interpreter": { diff --git a/tutorials/5_Implicit_Differentiation.ipynb b/tutorials/5_Implicit_Differentiation.ipynb new file mode 100644 index 00000000..c2913101 --- /dev/null +++ b/tutorials/5_Implicit_Differentiation.ipynb @@ -0,0 +1,578 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8850c832-3b54-4971-8ee0-2cd64b585ea8", + "metadata": {}, + "source": [ + "# TorchOpt for Implicit Differentiation" + ] + }, + { + "cell_type": "markdown", + "id": "2b547376", + "metadata": {}, + "source": [ + "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/5_Implicit_Differentiation.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "8d7f9865-dc02-43d4-be90-da1160c4e4dd", + "metadata": {}, + "source": [ + "By treating the solution $\\phi^{\\star}$ as an implicit function of $\\theta$, the idea of implicit differentiation is to directly get analytical best-response derivatives $\\partial \\phi^{\\star}(\\theta)/ \\partial \\theta$ by implicit function theorem. This is suitable for algorithms when the inner-level optimal solution is achieved ${\\left. \\frac{\\partial F (\\phi, \\theta)}{\\partial \\phi} \\right\\rvert}_{\\phi = \\phi^{\\star}} = 0$ or reaches some stationary conditions $F (\\phi^{\\star}, \\theta) = 0$, such as [iMAML](https://arxiv.org/abs/1909.04630) and [DEQ](https://arxiv.org/abs/1909.01377)." + ] + }, + { + "cell_type": "markdown", + "id": "d7e4b9e1-115f-45ad-a9b3-ea338bcfe6dd", + "metadata": {}, + "source": [ + "In this tutorial, we will introduce how TorchOpt can be used to conduct implicit differentiation." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "8f13ae67-e328-409f-84a8-1fc425c03a66", + "metadata": {}, + "outputs": [], + "source": [ + "import functorch\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import torchopt" + ] + }, + { + "cell_type": "markdown", + "id": "0cdaac49-4b94-4900-9bb5-a39057ac8b21", + "metadata": {}, + "source": [ + "## 1. Functional API\n", + "\n", + "The basic functional API is `torchopt.diff.implicit.custom_root`, which is used as the decorator for the forward process implicit gradient procedures. Users are required to implement the stationary conditions for the inner-loop process, which will be used as the input of custom_root decorator. We show the pseudo code in the following part." + ] + }, + { + "cell_type": "markdown", + "id": "c0b4400b-a491-4f07-926c-c421ac5a2069", + "metadata": {}, + "source": [ + "```python\n", + "# Functional API for implicit gradient\n", + "def stationary(params, meta_params, data):\n", + " # stationary condition construction\n", + " return stationary condition\n", + "\n", + "# Decorator that wraps the function\n", + "# Optionally specify the linear solver (conjugate gradient or Neumann series)\n", + "@torchopt.diff.implicit.custom_root(stationary, solve=linear_solver)\n", + "def solve(params, meta_params, data):\n", + " # Forward optimization process for params\n", + " return optimal_params\n", + "\n", + "# Define params, meta_params and get data\n", + "params, meta_prams, data = ..., ..., ...\n", + "optimal_params = solve(params, meta_params, data)\n", + "loss = outer_loss(optimal_params)\n", + "\n", + "meta_grads = torch.autograd.grad(loss, meta_params)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "dbef87df-2164-4f1d-8919-37a6fbdc5011", + "metadata": {}, + "source": [ + "Here we use the example of [iMAML](https://arxiv.org/abs/1909.04630) as a real example. For iMAML, the inner-loop objective is described by the following equation.\n", + "\n", + "$$\n", + "{\\mathcal{Alg}}^{\\star} \\left( \\boldsymbol{\\theta}, \\mathcal{D}_{i}^{\\text{tr}} \\right) = \\underset{\\phi'}{\\operatorname{\\arg \\min}} ~ G \\left( \\boldsymbol{\\phi}', \\boldsymbol{\\theta} \\right) \\triangleq \\mathcal{L} \\left( \\boldsymbol{\\phi}', \\mathcal{D}_{i}^{\\text{tr}} \\right) + \\frac{\\lambda}{2} {\\left\\| \\boldsymbol{\\phi}' - \\boldsymbol{\\theta} \\right\\|}^{2}\n", + "$$\n", + "\n", + "According to this function, we can define the forward function `inner_solver`, where we solve this equation based on sufficient gradient descents. For such inner-loop process, the optimality condition is that the gradient w.r.t inner-loop parameter is $0$.\n", + "\n", + "$$\n", + "{\\left. \\nabla_{\\boldsymbol{\\phi}'} G \\left( \\boldsymbol{\\phi}', \\boldsymbol{\\theta} \\right) \\right\\rvert}_{\\boldsymbol{\\phi}' = \\boldsymbol{\\phi}^{\\star}} = 0\n", + "$$\n", + "\n", + "Thus we can define the optimality function by defining `imaml_objective` and make it first-order gradient w.r.t the inner-loop parameter as $0$. We achieve so by calling out `functorch.grad(imaml_objective, argnums=0)`. Finally, the forward function is decorated by the `@torchopt.diff.implicit.custom_root` decorator and the optimality condition we define." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8d623b2f-48ee-4df6-a2ce-cf306b4c9067", + "metadata": {}, + "outputs": [], + "source": [ + "# Inner-loop objective function\n", + "# The optimality function: grad(imaml_objective)\n", + "def imaml_objective(params, meta_params, data):\n", + " x, y, fmodel = data\n", + " y_pred = fmodel(params, x)\n", + " regularization_loss = 0.0\n", + " for p1, p2 in zip(params, meta_params):\n", + " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", + " loss = F.mse_loss(y_pred, y) + regularization_loss\n", + " return loss\n", + "\n", + "\n", + "# Optimality Condition is: the gradient w.r.t inner-loop optimal params is 0 (we achieve so by\n", + "# specifying argnums=0 in functorch.grad) the argnums=1 specify which meta-parameter we want to\n", + "# backpropogate, in this case we want to backpropogate to the initial parameters so we set it as 1.\n", + "# You can also set argnums as (1, 2) if you want to backpropogate through multiple meta-parameters\n", + "\n", + "# Here we pass argnums=1 to the custom_root. That means we want to compute the gradient of\n", + "# optimal_params w.r.t. the 1-indexed argument in inner_solver, i.e., params.\n", + "# torchopt.linear_solve.solve_normal_cg specify that we use the conjugate gradient based linear solver\n", + "@torchopt.diff.implicit.custom_root(\n", + " functorch.grad(imaml_objective, argnums=0), # optimality function\n", + " argnums=1,\n", + " solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", + ")\n", + "def inner_solver(params, meta_params, data):\n", + " # Initial functional optimizer based on TorchOpt\n", + " x, y, fmodel = data\n", + " optimizer = torchopt.sgd(lr=2e-2)\n", + " opt_state = optimizer.init(params)\n", + " with torch.enable_grad():\n", + " # Temporarily enable gradient computation for conducting the optimization\n", + " for i in range(100):\n", + " pred = fmodel(params, x)\n", + " loss = F.mse_loss(pred, y) # compute loss\n", + "\n", + " # Compute regularization loss\n", + " regularization_loss = 0.0\n", + " for p1, p2 in zip(params, meta_params):\n", + " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", + " final_loss = loss + regularization_loss\n", + "\n", + " grads = torch.autograd.grad(final_loss, params) # compute gradients\n", + " updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates\n", + " params = torchopt.apply_updates(params, updates, inplace=True)\n", + "\n", + " optimal_params = params\n", + " return optimal_params\n", + "\n", + "\n", + "# torchopt.linear_solve.solve_inv specify that we use the Neumann Series inversion linear solver\n", + "@torchopt.diff.implicit.custom_root(\n", + " functorch.grad(imaml_objective, argnums=0), # optimality function\n", + " argnums=1,\n", + " solve=torchopt.linear_solve.solve_inv(ns=True, maxiter=100, alpha=0.1),\n", + ")\n", + "def inner_solver_inv_ns(params, meta_params, data):\n", + " # Initial functional optimizer based on TorchOpt\n", + " x, y, fmodel = data\n", + " optimizer = torchopt.sgd(lr=2e-2)\n", + " opt_state = optimizer.init(params)\n", + " with torch.enable_grad():\n", + " # Temporarily enable gradient computation for conducting the optimization\n", + " for i in range(100):\n", + " pred = fmodel(params, x)\n", + " loss = F.mse_loss(pred, y) # compute loss\n", + "\n", + " # Compute regularization loss\n", + " regularization_loss = 0.0\n", + " for p1, p2 in zip(params, meta_params):\n", + " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", + " final_loss = loss + regularization_loss\n", + "\n", + " grads = torch.autograd.grad(final_loss, params) # compute gradients\n", + " updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates\n", + " params = torchopt.apply_updates(params, updates, inplace=True)\n", + "\n", + " optimal_params = params\n", + " return optimal_params" + ] + }, + { + "cell_type": "markdown", + "id": "32a75c81-d479-4120-a73d-5b2b488358d0", + "metadata": {}, + "source": [ + "In the next step, we consider a specific case for one layer neural network to fit the linear data." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "fb95538b-1fd9-4ec8-9f57-6360bedc05b7", + "metadata": {}, + "outputs": [], + "source": [ + "torch.manual_seed(0)\n", + "x = torch.randn(20, 4)\n", + "w = torch.randn(4, 1)\n", + "b = torch.randn(1)\n", + "y = x @ w + b + 0.5 * torch.randn(20, 1)" + ] + }, + { + "cell_type": "markdown", + "id": "eeb1823a-2231-4471-bb68-cce7724f2578", + "metadata": {}, + "source": [ + "We instantiate an one layer neural network, where the weights and bias are initialized with constant." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d50a7bfe-ac69-4089-8cf8-3cbd69d6d4e7", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "class Net(nn.Module):\n", + " def __init__(self, dim):\n", + " super().__init__()\n", + " self.fc = nn.Linear(dim, 1, bias=True)\n", + " nn.init.ones_(self.fc.weight)\n", + " nn.init.zeros_(self.fc.bias)\n", + "\n", + " def forward(self, x):\n", + " return self.fc(x)\n", + "\n", + "\n", + "model = Net(4)\n", + "fmodel, meta_params = functorch.make_functional(model)\n", + "data = (x, y, fmodel)\n", + "\n", + "# Clone function for parameters\n", + "def clone(params):\n", + " cloned = []\n", + " for item in params:\n", + " if isinstance(item, torch.Tensor):\n", + " cloned.append(item.clone().detach_().requires_grad_(True))\n", + " else:\n", + " cloned.append(item)\n", + " return tuple(cloned)" + ] + }, + { + "cell_type": "markdown", + "id": "065c36c4-89e2-4a63-8213-63db6ee3b08e", + "metadata": {}, + "source": [ + "We take the forward process by calling out the forward function, then we pass the optimal params into the outer-loop loss function." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "115e79c6-911f-4743-a2ed-e50a71c3a813", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "optimal_params = inner_solver(clone(meta_params), meta_params, data)\n", + "\n", + "outer_loss = fmodel(optimal_params, x).mean()" + ] + }, + { + "cell_type": "markdown", + "id": "e2812351-f635-496e-9732-c80831ac04a6", + "metadata": {}, + "source": [ + "Finally, we can get the meta-gradient as shown below." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6bdcbe8d-2336-4f80-b124-eb43c5a2fc0a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" + ] + } + ], + "source": [ + "torch.autograd.grad(outer_loss, meta_params)" + ] + }, + { + "cell_type": "markdown", + "id": "926ae8bb", + "metadata": {}, + "source": [ + "Also we can switch to the Neumann Series inversion linear solver." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "43df0374", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" + ] + } + ], + "source": [ + "optimal_params = inner_solver_inv_ns(clone(meta_params), meta_params, data)\n", + "outer_loss = fmodel(optimal_params, x).mean()\n", + "torch.autograd.grad(outer_loss, meta_params)" + ] + }, + { + "cell_type": "markdown", + "id": "c92e67ea-b220-4a14-a1ea-4eb3c5f52b6b", + "metadata": {}, + "source": [ + "## 2. OOP API\n", + "\n", + "The basic OOP class is the class `ImplicitMetaGradientModule`. We make the network as an `nn.Module` following a classical PyTorch style. Users need to define the stationary condition/objective function and the inner-loop solve function to enable implicit gradient computation. We show the pseudo code in the following part.\n", + "\n", + "```python\n", + "from torchopt.nn import ImplicitMetaGradientModule\n", + "\n", + "# Inherited from the class ImplicitMetaGradientModule\n", + "# Optionally specify the linear solver (conjugate gradient or Neumann series)\n", + "class InnerNet(ImplicitMetaGradientModule, linear_solve=linear_solver):\n", + " def __init__(self, meta_module):\n", + " ...\n", + "\n", + " def forward(self, batch):\n", + " # Forward process\n", + " ...\n", + "\n", + " def optimality(self, batch, labels):\n", + " # Stationary condition construction for calculating implicit gradient\n", + " # NOTE: If this method is not implemented, it will be automatically derived from the\n", + " # gradient of the `objective` function.\n", + " ...\n", + "\n", + " def objective(self, batch, labels):\n", + " # Define the inner-loop optimization objective\n", + " # NOTE: This method is optional if method `optimality` is implemented.\n", + " ...\n", + "\n", + " def solve(self, batch, labels):\n", + " # Conduct the inner-loop optimization\n", + " ...\n", + " return self # optimized module\n", + "\n", + "# Get meta_params and data\n", + "meta_params, data = ..., ...\n", + "inner_net = InnerNet()\n", + "\n", + "# Solve for inner-loop process related with the meta-parameters\n", + "optimal_inner_net = inner_net.solve(meta_params, *data)\n", + "\n", + "# Get outer-loss and solve for meta-gradient\n", + "loss = outer_loss(optimal_inner_net)\n", + "meta_grad = torch.autograd.grad(loss, meta_params)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "62fbe520-11d0-41ff-9b0a-c6508b1d01cf", + "metadata": {}, + "source": [ + "The class `ImplicitMetaGradientModule` is to enable the gradient flow from `self.parameters()` to `self.meta_parameters()`. In `__init__` function, users need to define the inner parameters and meta-parameters. By default, `ImplicitMetaGradientModule` treats all tensors and modules from input as `self.meta_parameters()`, and all tensors and modules defined in the `__init__` are regarded as `self.parameters()`. Users can also register `self.parameters()` and `self.meta_parameters()` by calling `self.register_parameter()` and `self.register_meta_parameter()` respectively." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c3999684-f4d3-4bc0-86ab-a7e803b2fe80", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" + ] + } + ], + "source": [ + "class InnerNet(\n", + " torchopt.nn.ImplicitMetaGradientModule,\n", + " linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", + "):\n", + " def __init__(self, meta_net, n_inner_iter, reg_param):\n", + " super().__init__()\n", + " # Declaration of the meta-parameter\n", + " self.meta_net = meta_net\n", + " # Get a deepcopy, register inner-parameter\n", + " self.net = torchopt.module_clone(meta_net, by='deepcopy', detach_buffers=True)\n", + " self.n_inner_iter = n_inner_iter\n", + " self.reg_param = reg_param\n", + "\n", + " def forward(self, x):\n", + " return self.net(x)\n", + "\n", + " def objective(self, x, y):\n", + " # We do not implement the optimality conditions, so it will be automatically derived from\n", + " # the gradient of the `objective` function.\n", + " y_pred = self(x)\n", + " loss = F.mse_loss(y_pred, y)\n", + " regularization_loss = 0\n", + " for p1, p2 in zip(\n", + " self.parameters(), # parameters of `self.net`\n", + " self.meta_parameters(), # parameters of `self.meta_net`\n", + " ):\n", + " regularization_loss += (\n", + " 0.5 * self.reg_param * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", + " )\n", + " return loss + regularization_loss\n", + "\n", + " def solve(self, x, y):\n", + " params = tuple(self.parameters())\n", + " inner_optim = torchopt.SGD(params, lr=2e-2)\n", + " with torch.enable_grad():\n", + " # Temporarily enable gradient computation for conducting the optimization\n", + " for _ in range(self.n_inner_iter):\n", + " loss = self.objective(x, y)\n", + " inner_optim.zero_grad()\n", + " # NOTE: The parameter inputs should be explicitly specified in `backward` function\n", + " # as argument `inputs`. Otherwise, if not provided, the gradient is accumulated into\n", + " # all the leaf Tensors (including the meta-parameters) that were used to compute the\n", + " # objective output. Alternatively, please use `torch.autograd.grad` instead.\n", + " loss.backward(inputs=params) # backward pass in inner-loop\n", + " inner_optim.step() # update inner parameters\n", + " return self\n", + "\n", + "\n", + "# Initialize the meta-network\n", + "meta_net = Net(4)\n", + "inner_net = InnerNet(meta_net, 100, reg_param=1)\n", + "\n", + "# Solve for inner-loop\n", + "optimal_inner_net = inner_net.solve(x, y)\n", + "outer_loss = optimal_inner_net(x).mean()\n", + "\n", + "# Derive the meta-gradient\n", + "torch.autograd.grad(outer_loss, meta_net.parameters())" + ] + }, + { + "cell_type": "markdown", + "id": "2b69a5d6-b5e4-4f08-af0a-40afc2382b45", + "metadata": {}, + "source": [ + "We also show an example on how to implement implicit gradient calculation when the inner-level optimal solution reaches some stationary conditions $F (\\phi^{\\star}, \\theta) = 0$, such as [DEQ](https://arxiv.org/abs/1909.01377), based on the OOP API. " + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "de87c308-d847-4491-9aa1-bc393e6dd1d8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(\n", + "│ tensor([[ 0.0272, 0.0031, -0.0156, -0.0238],\n", + "│ │ [ 0.1004, 0.0113, -0.0573, -0.0878],\n", + "│ │ [ 0.0666, 0.0075, -0.0380, -0.0583],\n", + "│ │ [ 0.1446, 0.0163, -0.0826, -0.1265]]),\n", + "│ tensor([0.0574, 0.2114, 0.1403, 0.3046])\n", + ")\n" + ] + } + ], + "source": [ + "class Net(nn.Module):\n", + " def __init__(self, dim):\n", + " super().__init__()\n", + " self.fc = nn.Linear(dim, dim)\n", + "\n", + " def forward(self, x):\n", + " return self.fc(x)\n", + "\n", + "\n", + "class InnerNet(\n", + " torchopt.nn.ImplicitMetaGradientModule,\n", + " linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", + "):\n", + " def __init__(self, meta_net, x0):\n", + " super().__init__()\n", + " # Register meta-parameter\n", + " self.meta_net = meta_net\n", + " # Declaration of the inner-parameter, register inner-parameter\n", + " self.x = nn.Parameter(x0.clone().detach_(), requires_grad=True)\n", + "\n", + " def forward(self, x):\n", + " return self.meta_net(x)\n", + "\n", + " def optimality(self):\n", + " # Fixed-point condition\n", + " return (self.x - self(self.x),)\n", + "\n", + " def solve(self):\n", + " # Solving inner-loop fixed-point iteration\n", + " # This is just an illustrating example for solving fixed-point iteration\n", + " # one can use more advanced method to solve fixed-point iteration\n", + " # such as anderson acceleration.\n", + " for _ in range(10):\n", + " self.x.copy_(self(self.x))\n", + " return self\n", + "\n", + "\n", + "# Initialize meta-network\n", + "torch.manual_seed(0)\n", + "meta_net = Net(4)\n", + "x0 = torch.randn(1, 4)\n", + "inner_net = InnerNet(meta_net, x0)\n", + "\n", + "# Solve for inner-loop\n", + "optimal_inner_net = inner_net.solve()\n", + "outer_loss = optimal_inner_net.x.mean()\n", + "\n", + "# Derive the meta-gradient\n", + "torch.autograd.grad(outer_loss, meta_net.parameters())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + }, + "vscode": { + "interpreter": { + "hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/6_Zero_Order_Differentiation.ipynb b/tutorials/6_Zero_Order_Differentiation.ipynb new file mode 100644 index 00000000..c8d1e551 --- /dev/null +++ b/tutorials/6_Zero_Order_Differentiation.ipynb @@ -0,0 +1,212 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8850c832-3b54-4971-8ee0-2cd64b585ea8", + "metadata": {}, + "source": [ + "# TorchOpt for Zero-Order Differentiation" + ] + }, + { + "cell_type": "markdown", + "id": "2b547376", + "metadata": {}, + "source": [ + "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/6_Zero_Order_Differentiation.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "8d7f9865-dc02-43d4-be90-da1160c4e4dd", + "metadata": {}, + "source": [ + "When the inner-loop process is non-differentiable or one wants to eliminate the heavy computation burdens in the previous two modes (brought by Hessian), one can choose ZD. ZD typically gets gradients based on zero-order estimation, such as finite-difference, or Evolutionary Strategy.\n", + "\n", + "TorchOpt offers API for ES-based differentiation. Instead of optimizing the objective $F$, ES optimizes a Gaussion smoothing objective defined as $\\tilde{f}_{\\sigma} (\\theta) = \\mathbb{E}_{{z} \\sim \\mathcal{N}( {0}, {I}_d )} [ f ({\\theta} + \\sigma \\, z) ]$, where $\\sigma$ denotes precision. The gradient of such objective is $\\nabla_\\theta \\tilde{f}_{\\sigma} (\\theta) = \\frac{1}{\\sigma} \\mathbb{E}_{{z} \\sim \\mathcal{N}( {0}, {I}_d )} [ f({\\theta} + \\sigma \\, z) \\cdot z ]$. Refer to [ES-MAML](https://arxiv.org/pdf/1910.01215.pdf) for more details." + ] + }, + { + "cell_type": "markdown", + "id": "d7e4b9e1-115f-45ad-a9b3-ea338bcfe6dd", + "metadata": {}, + "source": [ + "In this tutorial, we will introduce how TorchOpt can be used to ES-based differentiation." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "8f13ae67-e328-409f-84a8-1fc425c03a66", + "metadata": {}, + "outputs": [], + "source": [ + "import functorch\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import torchopt" + ] + }, + { + "cell_type": "markdown", + "id": "0cdaac49-4b94-4900-9bb5-a39057ac8b21", + "metadata": {}, + "source": [ + "## 1. Functional API\n", + "\n", + "The basic functional API is `torchopt.diff.zero_order.zero_order`, which is used as the decorator for the forward process zero-order gradient procedures. Users are required to implement the noise sampling function, which will be used as the input of zero_order decorator. Here we show the specific meaning for each parameter used in the decorator.\n", + "\n", + "- `distribution` for noise sampling distribution\n", + "- `method` for different kind of algorithms, we support `'naive'` ([ES-RL](https://arxiv.org/abs/1703.03864)), `'forward'` ([Forward-FD](http://proceedings.mlr.press/v80/choromanski18a/choromanski18a.pdf)), and `'antithetic'` ([antithetic](https://d1wqtxts1xzle7.cloudfront.net/75609515/coredp2011_1web-with-cover-page-v2.pdf?Expires=1670215467&Signature=RfP~mQhhhI7aGknwXbRBgSggFrKuNTPYdyUSdMmfTxOa62QoOJAm-Xhr3F1PLyjUQc2JVxmKIKGGuyYvyfCTpB31dfmMtuVQxZMWVF-SfErTN05SliC93yjA1x1g2kjhn8bkBFdQqGl~1RQSKnhj88BakgSeDNzyCxwbD5VgR89BXRs4YIK5RBIKYtgLhoyz5jar7wHS3TJhRzs3WNeTIAjAmLqJ068oGFZ0Jr7maGquTe3w~8LEEIprJ6cyCMc6b1UUJkmwjNq0RLTVbxgFjfi4Z9kyxyJB9IOS1J25OOON4jfwh5JlXS7MVskuONUyHJim1TQ8OwCraKlBsQLPQw__&Key-Pair-Id=APKAJLOHF5GGSLRBV4ZA)).\n", + "- `argnums` specifies which parameter we want to trace the meta-gradient.\n", + "- `sigma` is for precision.\n", + "- `num_samples` specifies how many times we want to conduct the sampling.\n", + "\n", + "We show the pseudo code in the following part." + ] + }, + { + "cell_type": "markdown", + "id": "c0b4400b-a491-4f07-926c-c421ac5a2069", + "metadata": {}, + "source": [ + "```python\n", + "# Functional API for zero-order differentiation\n", + "# 1. Customize the noise distribution via a distribution class\n", + "class Distribution:\n", + " def sample(self, sample_shape = torch.Size()):\n", + " # sampling function for noise\n", + " return noise_batch\n", + "\n", + "distribution = Distribution()\n", + "\n", + "# 2. Customize the noise distribution via a sampling function\n", + "def distribution(sample_shape = torch.Size()):\n", + " # sampling function for noise\n", + " return noise_batch\n", + "\n", + "# 3. Distribution can also be an instance of `torch.distributions.Distribution`, e.g., `torch.distributions.Normal(...)`\n", + "distribution = torch.distributions.Normal(loc=0, scale=1)\n", + "\n", + "# Decorator that wraps the function\n", + "@torchopt.diff.zero_order(distribution=distribution, method='naive', argnums=0, sigma=0.01, num_samples=100)\n", + "def forward(params, data):\n", + " # Forward optimization process for params\n", + " return output\n", + "\n", + "# Define params and get data\n", + "params, data = ..., ...\n", + "loss = forward(params, data)\n", + "\n", + "meta_grads = torch.autograd.grad(loss, params)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "dbef87df-2164-4f1d-8919-37a6fbdc5011", + "metadata": {}, + "source": [ + "Here we use the example of a linear layer as an example, note that this is just an example to show linear layer can work with ES." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8d623b2f-48ee-4df6-a2ce-cf306b4c9067", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "001: tensor(0.0269, grad_fn=)\n", + "002: tensor(0.0246, grad_fn=)\n", + "003: tensor(0.0225, grad_fn=)\n", + "004: tensor(0.0205, grad_fn=)\n", + "005: tensor(0.0187, grad_fn=)\n", + "006: tensor(0.0171, grad_fn=)\n", + "007: tensor(0.0156, grad_fn=)\n", + "008: tensor(0.0144, grad_fn=)\n", + "009: tensor(0.0134, grad_fn=)\n", + "010: tensor(0.0128, grad_fn=)\n", + "011: tensor(0.0122, grad_fn=)\n", + "012: tensor(0.0118, grad_fn=)\n", + "013: tensor(0.0120, grad_fn=)\n", + "014: tensor(0.0117, grad_fn=)\n", + "015: tensor(0.0117, grad_fn=)\n", + "016: tensor(0.0118, grad_fn=)\n", + "017: tensor(0.0121, grad_fn=)\n", + "018: tensor(0.0117, grad_fn=)\n", + "019: tensor(0.0118, grad_fn=)\n", + "020: tensor(0.0118, grad_fn=)\n", + "021: tensor(0.0115, grad_fn=)\n", + "022: tensor(0.0117, grad_fn=)\n", + "023: tensor(0.0117, grad_fn=)\n", + "024: tensor(0.0116, grad_fn=)\n", + "025: tensor(0.0113, grad_fn=)\n" + ] + } + ], + "source": [ + "torch.random.manual_seed(0)\n", + "\n", + "fmodel, params = functorch.make_functional(torch.nn.Linear(32, 1))\n", + "x = torch.randn(64, 32) * 0.1\n", + "y = torch.randn(64) * 0.1\n", + "distribution = torch.distributions.Normal(loc=0, scale=1)\n", + "\n", + "\n", + "@torchopt.diff.zero_order.zero_order(\n", + " distribution=distribution, method='forward', argnums=0, sigma=0.01, num_samples=1000\n", + ")\n", + "def forward_process(params, fn, x, y):\n", + " y_pred = fn(params, x)\n", + " loss = torch.mean((y - y_pred) ** 2)\n", + " return loss\n", + "\n", + "\n", + "optimizer = torchopt.adam(lr=0.01)\n", + "opt_state = optimizer.init(params)\n", + "\n", + "for i in range(25):\n", + " opt_state = optimizer.init(params) # init optimizer\n", + " loss = forward_process(params, fmodel, x, y) # compute loss\n", + "\n", + " grads = torch.autograd.grad(loss, params) # compute gradients\n", + " updates, opt_state = optimizer.update(grads, opt_state) # get updates\n", + " params = torchopt.apply_updates(params, updates) # update network parameters\n", + "\n", + " print(f'{i + 1:03d}: {loss!r}')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.15 ('torchopt')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + }, + "vscode": { + "interpreter": { + "hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/requirements.txt b/tutorials/requirements.txt index 5fe3b1ad..ff5a5c42 100644 --- a/tutorials/requirements.txt +++ b/tutorials/requirements.txt @@ -1,8 +1,11 @@ ---extra-index-url https://download.pytorch.org/whl/cu116 -torch >= 1.12 +--extra-index-url https://download.pytorch.org/whl/cu117 +# Sync with project.dependencies +torch >= 1.13 torchvision -functorch >= 0.2 --requirement ../requirements.txt ipykernel +jax[cpu] >= 0.3 +jaxopt +optax pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy