Skip to content

feat: add AdamW optimizer #44

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 33 commits into from
Sep 5, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
70b8a47
feat(torchopt): init adamw optimizer
Benjamin-eecs Jul 27, 2022
3300142
Merge remote-tracking branch 'upstream/main' into feature/adamw
Benjamin-eecs Aug 4, 2022
253cc2a
fix(torchopt): pass adamw tests
Benjamin-eecs Aug 4, 2022
17d5784
fix: force add adamw.py
Benjamin-eecs Aug 4, 2022
cdc3836
feat: add MetaAdamW test and pass lint
Benjamin-eecs Aug 5, 2022
cc3a3c7
feat: add MetaAdamW test and pass lint
Benjamin-eecs Aug 5, 2022
a071550
fix: pass lint and pass MetaAdamW tests
Benjamin-eecs Aug 5, 2022
89fac53
fix: rewrite MetaOptimizer test, pass MetaAdamW tests with error tol
Benjamin-eecs Aug 5, 2022
b50abe0
merge: resolve conflicts
Benjamin-eecs Aug 24, 2022
47ff9f3
merge: resolve conflicts
Benjamin-eecs Aug 24, 2022
476332e
fix: update adamw low level test
Benjamin-eecs Aug 26, 2022
8175181
merge: resolve conflicts
Benjamin-eecs Sep 1, 2022
bb82209
fix(tests): use new test
Benjamin-eecs Sep 4, 2022
4b01c7e
Merge remote-tracking branch 'upstream/main' into feature/adamw
Benjamin-eecs Sep 4, 2022
d935014
fix: pass lint
Benjamin-eecs Sep 4, 2022
47cfa45
fix: pass test
Benjamin-eecs Sep 4, 2022
9b32e7b
Merge remote-tracking branch 'upstream/main' into feature/adamw
Benjamin-eecs Sep 4, 2022
42ed8a5
fix: pass test
Benjamin-eecs Sep 4, 2022
1e64877
fix: pass test
Benjamin-eecs Sep 4, 2022
872b8d4
fix: update docstring
Benjamin-eecs Sep 4, 2022
824d1c5
fix: update docstring
Benjamin-eecs Sep 4, 2022
e920c74
fix: update docstring
Benjamin-eecs Sep 4, 2022
8ee3c41
fix: correct already_flattened
Benjamin-eecs Sep 4, 2022
0f129c0
fix: correct weight_decay range check
Benjamin-eecs Sep 4, 2022
e75671e
fix: already_flattened of mask
Benjamin-eecs Sep 4, 2022
c791bba
style: format code
XuehaiPan Sep 5, 2022
24690a0
feat: add shortcut
XuehaiPan Sep 5, 2022
fec6f99
chore: reorganize code structure
XuehaiPan Sep 5, 2022
d3ad838
feat: inplace support for AdamW
XuehaiPan Sep 5, 2022
c685954
docs: update docstrings
XuehaiPan Sep 5, 2022
8114286
docs(CHANGELOG): update CHANGELOG.md
XuehaiPan Sep 5, 2022
c075533
docs: update docstrings
XuehaiPan Sep 5, 2022
0f5c90a
docs: update docstrings
XuehaiPan Sep 5, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
merge: resolve conflicts
  • Loading branch information
Benjamin-eecs committed Sep 1, 2022
commit 81751816b3a0581d3a2aeb9c5d344eb1b945983c
2 changes: 1 addition & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax `close #15213` if this solves the issue #15213

- [ ] I have raised an issue to propose this change ([required](https://github.com/metaopt/TorchOpt/issues) for new features and bug fixes)
- [ ] I have raised an issue to propose this change ([required](https://github.com/metaopt/torchopt/issues) for new features and bug fixes)

## Types of changes

Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ env:
jobs:
build-sdist:
runs-on: ubuntu-latest
if: github.repository == 'metaopt/TorchOpt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
timeout-minutes: 10
steps:
- name: Checkout
Expand Down Expand Up @@ -74,7 +74,7 @@ jobs:
build-wheels:
runs-on: ubuntu-latest
needs: [build-sdist]
if: github.repository == 'metaopt/TorchOpt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
timeout-minutes: 60
steps:
- name: Checkout
Expand All @@ -100,7 +100,7 @@ jobs:
runs-on: ubuntu-latest
needs: [build-sdist, build-wheels]
if: |
github.repository == 'metaopt/TorchOpt' && github.event_name != 'pull_request' &&
github.repository == 'metaopt/torchopt' && github.event_name != 'pull_request' &&
(github.event_name != 'workflow_dispatch' || github.event.inputs.task == 'build-and-publish') &&
(github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
timeout-minutes: 15
Expand Down
9 changes: 4 additions & 5 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ jobs:
- run: |
CUDA_VERSION="${{steps.cuda-toolkit.outputs.cuda}}"
echo "CUDA_VERSION=${CUDA_VERSION}" >> "${GITHUB_ENV}"
TORCH_INDEX_URL="https://download.pytorch.org/whl/cu$(echo "${CUDA_VERSION}" | cut -d'.' -f-2 | tr -d '.')"
echo "TORCH_INDEX_URL=${TORCH_INDEX_URL}" >> "${GITHUB_ENV}"
PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu$(echo "${CUDA_VERSION}" | cut -d'.' -f-2 | tr -d '.')"
echo "PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}" >> "${GITHUB_ENV}"

echo "Installed CUDA version is: ${CUDA_VERSION}"
echo "CUDA install location: ${{steps.cuda-toolkit.outputs.CUDA_PATH}}"
nvcc -V
echo "Torch index URL: ${TORCH_INDEX_URL}"
echo "Torch index URL: ${PIP_EXTRA_INDEX_URL}"

- name: Upgrade pip
run: |
Expand Down Expand Up @@ -92,8 +92,7 @@ jobs:

- name: Install dependencies
run: |
python -m pip install --extra-index-url "${TORCH_INDEX_URL}" \
-r docs/requirements.txt
python -m pip install -r docs/requirements.txt

- name: docstyle
run: |
Expand Down
16 changes: 11 additions & 5 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,25 +53,31 @@ jobs:
- run: |
CUDA_VERSION="${{steps.cuda-toolkit.outputs.cuda}}"
echo "CUDA_VERSION=${CUDA_VERSION}" >> "${GITHUB_ENV}"
TORCH_INDEX_URL="https://download.pytorch.org/whl/cu$(echo "${CUDA_VERSION}" | cut -d'.' -f-2 | tr -d '.')"
echo "TORCH_INDEX_URL=${TORCH_INDEX_URL}" >> "${GITHUB_ENV}"
PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu$(echo "${CUDA_VERSION}" | cut -d'.' -f-2 | tr -d '.')"
echo "PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}" >> "${GITHUB_ENV}"

echo "Installed CUDA version is: ${CUDA_VERSION}"
echo "CUDA install location: ${{steps.cuda-toolkit.outputs.CUDA_PATH}}"
nvcc -V
echo "Torch index URL: ${TORCH_INDEX_URL}"
echo "Torch index URL: ${PIP_EXTRA_INDEX_URL}"

- name: Upgrade pip
run: |
python -m pip install --upgrade pip setuptools

- name: Install PyTorch and FuncTorch nightly
run: |
export PIP_EXTRA_INDEX_URL="${PIP_EXTRA_INDEX_URL//whl/whl\/nightly}"
python -m pip install 'torch >= 1.13.0dev' ninja
python -m pip install git+https://github.com/pytorch/functorch.git

- name: Install dependencies
run: |
python -m pip install --extra-index-url "${TORCH_INDEX_URL}" \
-r tests/requirements.txt
python -m pip install -r tests/requirements.txt

- name: Install TorchOpt
run: |
export PIP_EXTRA_INDEX_URL="${PIP_EXTRA_INDEX_URL//whl/whl\/nightly}"
python -m pip install -vvv -e .

- name: Test with pytest
Expand Down
10 changes: 5 additions & 5 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -394,12 +394,12 @@ fabric.properties

##### Vim.gitignore #####
# Swap
[._]*.s[a-v][a-z]
.*.s[a-v][a-z]
!*.svg # comment out if you don't need vector files
[._]*.sw[a-p]
[._]s[a-rt-v][a-z]
[._]ss[a-gi-z]
[._]sw[a-p]
.*.sw[a-p]
.s[a-rt-v][a-z]
.ss[a-gi-z]
.sw[a-p]

# Session
Session.vim
Expand Down
55 changes: 29 additions & 26 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add option `maximize` option to optimizers by [@XuehaiPan](https://github.com/XuehaiPan) in [#64](https://github.com/metaopt/TorchOpt/pull/64).
- Refactor tests using `pytest.mark.parametrize` and enabling parallel testing by [@XuehaiPan](https://github.com/XuehaiPan) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#55](https://github.com/metaopt/TorchOpt/pull/55).
- Add maml-omniglot few-shot classification example using functorch.vmap by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#39](https://github.com/metaopt/TorchOpt/pull/39).
- Add parallel training on one GPU using functorch.vmap example by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#32](https://github.com/metaopt/TorchOpt/pull/32).
- Add question/help/support issue template by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#43](https://github.com/metaopt/TorchOpt/pull/43).
- Add optional argument `params` to update function in gradient transformations by [@XuehaiPan](https://github.com/XuehaiPan) in [#65](https://github.com/metaopt/torchopt/pull/65).
- Add option `weight_decay` option to optimizers by [@XuehaiPan](https://github.com/XuehaiPan) in [#65](https://github.com/metaopt/torchopt/pull/65).
- Add option `maximize` option to optimizers by [@XuehaiPan](https://github.com/XuehaiPan) in [#64](https://github.com/metaopt/torchopt/pull/64).
- Refactor tests using `pytest.mark.parametrize` and enabling parallel testing by [@XuehaiPan](https://github.com/XuehaiPan) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#55](https://github.com/metaopt/torchopt/pull/55).
- Add maml-omniglot few-shot classification example using functorch.vmap by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#39](https://github.com/metaopt/torchopt/pull/39).
- Add parallel training on one GPU using functorch.vmap example by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#32](https://github.com/metaopt/torchopt/pull/32).
- Add question/help/support issue template by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#43](https://github.com/metaopt/torchopt/pull/43).

### Changed

- Replace JAX PyTrees with OpTree by [@XuehaiPan](https://github.com/XuehaiPan) in [#62](https://github.com/metaopt/TorchOpt/pull/62).
- Update image link in README to support PyPI rendering by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#56](https://github.com/metaopt/TorchOpt/pull/56).
- Align argument names with PyTorch by [@XuehaiPan](https://github.com/XuehaiPan) in [#65](https://github.com/metaopt/torchopt/pull/65).
- Replace JAX PyTrees with OpTree by [@XuehaiPan](https://github.com/XuehaiPan) in [#62](https://github.com/metaopt/torchopt/pull/62).
- Update image link in README to support PyPI rendering by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#56](https://github.com/metaopt/torchopt/pull/56).

### Fixed

- Fix RMSProp optimizer by [@XuehaiPan](https://github.com/XuehaiPan) in [#55](https://github.com/metaopt/TorchOpt/pull/55).
- Fix momentum tracing by [@XuehaiPan](https://github.com/XuehaiPan) in [#58](https://github.com/metaopt/TorchOpt/pull/58).
- Fix CUDA build for accelerated OP by [@XuehaiPan](https://github.com/XuehaiPan) in [#53](https://github.com/metaopt/TorchOpt/pull/53).
- Fix gamma error in MAML-RL implementation by [@Benjamin-eecs](https://github.com/Benjamin-eecs) [#47](https://github.com/metaopt/TorchOpt/pull/47).
- Fix RMSProp optimizer by [@XuehaiPan](https://github.com/XuehaiPan) in [#55](https://github.com/metaopt/torchopt/pull/55).
- Fix momentum tracing by [@XuehaiPan](https://github.com/XuehaiPan) in [#58](https://github.com/metaopt/torchopt/pull/58).
- Fix CUDA build for accelerated OP by [@XuehaiPan](https://github.com/XuehaiPan) in [#53](https://github.com/metaopt/torchopt/pull/53).
- Fix gamma error in MAML-RL implementation by [@Benjamin-eecs](https://github.com/Benjamin-eecs) [#47](https://github.com/metaopt/torchopt/pull/47).

### Removed

Expand All @@ -39,37 +42,37 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Bump PyTorch version to 1.12.1 by [@XuehaiPan](https://github.com/XuehaiPan) in [#49](https://github.com/metaopt/TorchOpt/pull/49).
- CPU-only build without `nvcc` requirement by [@XuehaiPan](https://github.com/XuehaiPan) in [#51](https://github.com/metaopt/TorchOpt/pull/51).
- Use [`cibuildwheel`](https://github.com/pypa/cibuildwheel) to build wheels by [@XuehaiPan](https://github.com/XuehaiPan) in [#45](https://github.com/metaopt/TorchOpt/pull/45).
- Use dynamic process number in CPU kernels by [@JieRen98](https://github.com/JieRen98) in [#42](https://github.com/metaopt/TorchOpt/pull/42).
- Bump PyTorch version to 1.12.1 by [@XuehaiPan](https://github.com/XuehaiPan) in [#49](https://github.com/metaopt/torchopt/pull/49).
- CPU-only build without `nvcc` requirement by [@XuehaiPan](https://github.com/XuehaiPan) in [#51](https://github.com/metaopt/torchopt/pull/51).
- Use [`cibuildwheel`](https://github.com/pypa/cibuildwheel) to build wheels by [@XuehaiPan](https://github.com/XuehaiPan) in [#45](https://github.com/metaopt/torchopt/pull/45).
- Use dynamic process number in CPU kernels by [@JieRen98](https://github.com/JieRen98) in [#42](https://github.com/metaopt/torchopt/pull/42).

### Changed

- Use correct Python Ctype for pybind11 function prototype [@XuehaiPan](https://github.com/XuehaiPan) in [#52](https://github.com/metaopt/TorchOpt/pull/52).
- Use correct Python Ctype for pybind11 function prototype [@XuehaiPan](https://github.com/XuehaiPan) in [#52](https://github.com/metaopt/torchopt/pull/52).

------

## [0.4.2] - 2022-07-26

### Added

- Read the Docs integration by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#34](https://github.com/metaopt/TorchOpt/pull/34).
- Update documentation and code styles by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#22](https://github.com/metaopt/TorchOpt/pull/22).
- Update tutorial notebooks by [@XuehaiPan](https://github.com/XuehaiPan) in [#27](https://github.com/metaopt/TorchOpt/pull/27).
- Bump PyTorch version to 1.12 by [@XuehaiPan](https://github.com/XuehaiPan) in [#25](https://github.com/metaopt/TorchOpt/pull/25).
- Support custom Python executable path in `CMakeLists.txt` by [@XuehaiPan](https://github.com/XuehaiPan) in [#18](https://github.com/metaopt/TorchOpt/pull/18).
- Add citation information by [@waterhorse1](https://github.com/waterhorse1) in [#14](https://github.com/metaopt/TorchOpt/pull/14) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#15](https://github.com/metaopt/TorchOpt/pull/15).
- Implement RMSProp optimizer by [@future-xy](https://github.com/future-xy) in [#8](https://github.com/metaopt/TorchOpt/pull/8).
- Read the Docs integration by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#34](https://github.com/metaopt/torchopt/pull/34).
- Update documentation and code styles by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#22](https://github.com/metaopt/torchopt/pull/22).
- Update tutorial notebooks by [@XuehaiPan](https://github.com/XuehaiPan) in [#27](https://github.com/metaopt/torchopt/pull/27).
- Bump PyTorch version to 1.12 by [@XuehaiPan](https://github.com/XuehaiPan) in [#25](https://github.com/metaopt/torchopt/pull/25).
- Support custom Python executable path in `CMakeLists.txt` by [@XuehaiPan](https://github.com/XuehaiPan) in [#18](https://github.com/metaopt/torchopt/pull/18).
- Add citation information by [@waterhorse1](https://github.com/waterhorse1) in [#14](https://github.com/metaopt/torchopt/pull/14) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#15](https://github.com/metaopt/torchopt/pull/15).
- Implement RMSProp optimizer by [@future-xy](https://github.com/future-xy) in [#8](https://github.com/metaopt/torchopt/pull/8).

### Changed

- Use `pyproject.toml` for packaging and update GitHub Action workflows by [@XuehaiPan](https://github.com/XuehaiPan) in [#31](https://github.com/metaopt/TorchOpt/pull/31).
- Rename the package from `TorchOpt` to `torchopt` by [@XuehaiPan](https://github.com/XuehaiPan) in [#20](https://github.com/metaopt/TorchOpt/pull/20).
- Use `pyproject.toml` for packaging and update GitHub Action workflows by [@XuehaiPan](https://github.com/XuehaiPan) in [#31](https://github.com/metaopt/torchopt/pull/31).
- Rename the package from `TorchOpt` to `torchopt` by [@XuehaiPan](https://github.com/XuehaiPan) in [#20](https://github.com/metaopt/torchopt/pull/20).

### Fixed

- Fixed errors while building from the source and add `conda` environment recipe by [@XuehaiPan](https://github.com/XuehaiPan) in [#24](https://github.com/metaopt/TorchOpt/pull/24).
- Fixed errors while building from the source and add `conda` environment recipe by [@XuehaiPan](https://github.com/XuehaiPan) in [#24](https://github.com/metaopt/torchopt/pull/24).

------

Expand Down
2 changes: 1 addition & 1 deletion CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ authors:
version: 0.4.3
date-released: "2022-08-08"
license: Apache-2.0
repository-code: "https://github.com/metaopt/TorchOpt"
repository-code: "https://github.com/metaopt/torchopt"
13 changes: 6 additions & 7 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ RUN echo "export PS1='[\[\e[1;33m\]\u\[\e[0m\]:\[\e[1;35m\]\w\[\e[0m\]]\$ '" >>

# Setup virtual environment
RUN /usr/bin/python3.9 -m venv --upgrade-deps ~/venv && rm -rf ~/.pip/cache
RUN TORCH_INDEX_URL="https://download.pytorch.org/whl/cu$(echo "${CUDA_VERSION}" | cut -d'.' -f-2 | tr -d '.')" && \
echo "export TORCH_INDEX_URL='${TORCH_INDEX_URL}'" >> ~/venv/bin/activate && \
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu$(echo "${CUDA_VERSION}" | cut -d'.' -f-2 | tr -d '.')" && \
echo "export PIP_EXTRA_INDEX_URL='${PIP_EXTRA_INDEX_URL}'" >> ~/venv/bin/activate && \
echo "source /home/torchopt/venv/bin/activate" >> ~/.bashrc

# Install dependencies
WORKDIR /home/torchopt/TorchOpt
WORKDIR /home/torchopt/torchopt
COPY --chown=torchopt requirements.txt requirements.txt
RUN source ~/venv/bin/activate && \
python -m pip install --extra-index-url "${TORCH_INDEX_URL}" -r requirements.txt && \
python -m pip install -r requirements.txt && \
rm -rf ~/.pip/cache ~/.cache/pip

####################################################################################################
Expand All @@ -63,8 +63,7 @@ RUN go install github.com/google/addlicense@latest
COPY --chown=torchopt tests/requirements.txt tests/requirements.txt
COPY --chown=torchopt tutorials/requirements.txt tutorials/requirements.txt
RUN source ~/venv/bin/activate && \
python -m pip install --extra-index-url "${TORCH_INDEX_URL}" \
-r tests/requirements.txt -r tutorials/requirements.txt && \
python -m pip install -r tests/requirements.txt -r tutorials/requirements.txt && \
rm -rf ~/.pip/cache ~/.cache/pip

####################################################################################################
Expand All @@ -84,4 +83,4 @@ ENTRYPOINT [ "/bin/bash", "--login" ]

FROM devel-builder AS devel

COPY --from=base /home/torchopt/TorchOpt .
COPY --from=base /home/torchopt/torchopt .
Loading
You are viewing a condensed version of this merge commit. You can view the full changes here.
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy