diff --git a/.github/ISSUE_TEMPLATE/questions_help_support.md b/.github/ISSUE_TEMPLATE/questions_help_support.md new file mode 100644 index 00000000..072d2e52 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/questions_help_support.md @@ -0,0 +1,17 @@ +--- +name: Questions / Help / Support +about: Do you need support? +title: "[Question]" +labels: "question" +assignees: Benjamin-eecs + +--- + +## Questions + + + +## Checklist + +- [ ] I have checked that there is no similar issue in the repo (**required**) +- [ ] I have read the [documentation](https://torchopt.readthedocs.io/) (**required**) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 8b26e861..bebf3cf1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -20,6 +20,14 @@ on: - published # Allow to trigger the workflow manually workflow_dispatch: + inputs: + task: + description: "Task type" + type: choice + options: + - build-only + - build-and-publish + required: true permissions: contents: read @@ -28,11 +36,15 @@ concurrency: group: "${{ github.workflow }}-${{ github.ref }}" cancel-in-progress: ${{ github.event_name == 'pull_request' }} +env: + CUDA_VERSION: "11.6" + TEST_TORCH_SPECS: "cpu cu113 cu116" + jobs: - build: - runs-on: ubuntu-18.04 + build-sdist: + runs-on: ubuntu-latest if: github.repository == 'metaopt/TorchOpt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/')) - timeout-minutes: 45 + timeout-minutes: 10 steps: - name: Checkout uses: actions/checkout@v3 @@ -40,146 +52,93 @@ jobs: submodules: "recursive" fetch-depth: 1 - - name: Set up Python 3.7 - id: py37 + - name: Set up Python uses: actions/setup-python@v4 with: - python-version: "3.7" - update-environment: false + python-version: "3.7 - 3.10" + update-environment: true - - name: Set up Python 3.8 - id: py38 - uses: actions/setup-python@v4 - with: - python-version: "3.8" - update-environment: false + - name: Install dependencies + run: python -m pip install --upgrade pip setuptools wheel build - - name: Set up Python 3.9 - id: py39 - uses: actions/setup-python@v4 - with: - python-version: "3.9" - update-environment: false + - name: Build sdist + run: python -m build --sdist - - name: Set up Python 3.10 - id: py310 - uses: actions/setup-python@v4 + - name: Upload artifact + uses: actions/upload-artifact@v3 with: - python-version: "3.10" - update-environment: false + name: sdist + path: dist/*.tar.gz + if-no-files-found: error - - name: Set up Python executable paths - run: | - echo "${{ steps.py37.outputs.python-path }}" > .python-paths - echo "${{ steps.py38.outputs.python-path }}" >> .python-paths - echo "${{ steps.py39.outputs.python-path }}" >> .python-paths - echo "${{ steps.py310.outputs.python-path }}" >> .python-paths - - - name: Setup CUDA Toolkit - uses: Jimver/cuda-toolkit@v0.2.7 - id: cuda-toolkit + build-wheels: + runs-on: ubuntu-latest + needs: [build-sdist] + if: github.repository == 'metaopt/TorchOpt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/')) + timeout-minutes: 60 + steps: + - name: Checkout + uses: actions/checkout@v3 with: - cuda: "11.6.2" - method: network - sub-packages: '["nvcc"]' - - run: | - CUDA_VERSION="${{steps.cuda-toolkit.outputs.cuda}}" - echo "CUDA_VERSION=${CUDA_VERSION}" >> "${GITHUB_ENV}" - TORCH_INDEX_URL="https://download.pytorch.org/whl/cu$(echo "${CUDA_VERSION}" | cut -d'.' 
-f-2 | tr -d '.')" - echo "TORCH_INDEX_URL=${TORCH_INDEX_URL}" >> "${GITHUB_ENV}" - - echo "Installed CUDA version is: ${CUDA_VERSION}" - echo "CUDA install location: ${{steps.cuda-toolkit.outputs.CUDA_PATH}}" - nvcc -V - echo "Torch index URL: ${TORCH_INDEX_URL}" - - - name: Build sdist and wheels - run: | - DEFAULT_PYTHON="$(head -n 1 .python-paths)" - - while read -r PYTHON; do - echo "Building wheel with Python: ${PYTHON} ($("${PYTHON}" --version))" - "${PYTHON}" -m pip install --upgrade pip setuptools wheel build - "${PYTHON}" -m pip install --extra-index-url "${TORCH_INDEX_URL}" \ - -r requirements.txt - if [[ "${PYTHON}" == "${DEFAULT_PYTHON}" ]]; then - "${PYTHON}" -m build - else - "${PYTHON}" -m build --wheel - fi - done < .python-paths - - - name: List built sdist and wheels - run: | - if [[ -n "$(find dist -maxdepth 0 -not -empty -print 2>/dev/null)" ]]; then - echo "Built sdist and wheels:" - ls -lh dist/ - else - echo "No sdist and wheels are built." - exit 1 - fi + submodules: "recursive" + fetch-depth: 1 - - name: Audit and repair wheels - run: | - while read -r PYTHON; do - PYVER="cp$("${PYTHON}" --version | cut -d ' ' -f2 | cut -d '.' -f-2 | tr -d '.')" - echo "Audit and repair wheel for Python: ${PYTHON} (${PYVER})" - LIBTORCH_PATH="$("${PYTHON}" -c 'import os, site; print(os.path.join(site.getsitepackages()[0], "torch", "lib"))')" - "${PYTHON}" -m pip install --upgrade git+https://github.com/XuehaiPan/auditwheel.git@torchopt - ( - export LD_LIBRARY_PATH="${LIBTORCH_PATH}${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" - echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" - "${PYTHON}" -m auditwheel show dist/torchopt-*-${PYVER}-*.whl && - "${PYTHON}" -m auditwheel repair --plat manylinux2014_x86_64 --wheel-dir wheelhouse dist/torchopt-*-${PYVER}-*.whl - ) - done < .python-paths - - rm dist/torchopt-*.whl - mv wheelhouse/torchopt-*manylinux*.whl dist/ - - - name: List built sdist and wheels - run: | - if [[ -n "$(find dist -maxdepth 0 -not -empty -print 2>/dev/null)" ]]; then - echo "Built sdist and wheels:" - ls -lh dist/ - else - echo "No sdist and wheels are built." - exit 1 - fi + - name: Build wheels + uses: pypa/cibuildwheel@v2.8.1 + with: + package-dir: . + output-dir: wheelhouse + config-file: "{package}/pyproject.toml" - - name: Test sdist and wheels - run: | - DEFAULT_PYTHON="$(head -n 1 .python-paths)" - while read -r PYTHON; do - PYVER="cp$("${PYTHON}" --version | cut -d ' ' -f2 | cut -d '.' 
-f-2 | tr -d '.')" - mkdir -p "temp-${PYVER}" - pushd "temp-${PYVER}" - if [[ "${PYTHON}" == "${DEFAULT_PYTHON}" ]]; then - echo "Testing sdist with Python: ${PYTHON} (${PYVER})" - "${PYTHON}" -m pip uninstall torch torchopt -y - "${PYTHON}" -m pip install --extra-index-url https://download.pytorch.org/whl/cpu \ - ../dist/torchopt-*.tar.gz - "${PYTHON}" -c 'import torchopt' - fi - echo "Testing wheel with Python: ${PYTHON} (${PYVER})" - "${PYTHON}" -m pip uninstall torch torchopt -y - "${PYTHON}" -m pip install --extra-index-url https://download.pytorch.org/whl/cpu \ - ../dist/torchopt-*-${PYVER}-*.whl - "${PYTHON}" -c 'import torchopt' - "${PYTHON}" -m pip uninstall torch torchopt -y - popd - done < .python-paths + - uses: actions/upload-artifact@v3 + with: + name: wheels + path: wheelhouse/*.whl + if-no-files-found: error + + publish: + runs-on: ubuntu-latest + needs: [build-sdist, build-wheels] + if: | + github.repository == 'metaopt/TorchOpt' && github.event_name != 'pull_request' && + (github.event_name != 'workflow_dispatch' || github.event.inputs.task == 'build-and-publish') && + (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/')) + timeout-minutes: 15 + steps: + - name: Set up Python + uses: actions/setup-python@v4 + if: startsWith(github.ref, 'refs/tags/') + with: + python-version: "3.7 - 3.10" + update-environment: true - name: Check consistency between the package version and release tag if: startsWith(github.ref, 'refs/tags/') run: | + PYTHON="$(head -n 1 .python-paths)" + PACKAGE_VER="v$("${PYTHON}" setup.py --version)" RELEASE_TAG="${GITHUB_REF#refs/*/}" - PACKAGE_VER="v$(python setup.py --version)" if [[ "${PACKAGE_VER}" != "${RELEASE_TAG}" ]]; then echo "package ver. (${PACKAGE_VER}) != release tag. (${RELEASE_TAG})" exit 1 fi + - name: Download built sdist + uses: actions/download-artifact@v3 + with: + # unpacks default artifact into dist/ + # if `name: artifact` is omitted, the action will create extra parent dir + name: sdist + path: dist + + - name: Download built wheels + uses: actions/download-artifact@v3 + with: + # unpacks default artifact into dist/ + # if `name: artifact` is omitted, the action will create extra parent dir + name: wheels + path: dist + - name: Publish to TestPyPI if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' uses: pypa/gh-action-pypi-publish@v1.5.0 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index f2393c77..274133de 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -5,6 +5,8 @@ on: branches: - main pull_request: + # Allow to trigger the workflow manually + workflow_dispatch: permissions: contents: read @@ -52,14 +54,9 @@ jobs: run: | python -m pip install --upgrade pip setuptools - - name: Install dependencies - run: | - python -m pip install --extra-index-url "${TORCH_INDEX_URL}" \ - -r tests/requirements.txt -r docs/requirements.txt - - name: Install TorchOpt run: | - python -m pip install -e . 
+ python -m pip install -vvv -e '.[lint]' - name: pre-commit run: | @@ -93,6 +90,11 @@ jobs: run: | make mypy + - name: Install dependencies + run: | + python -m pip install --extra-index-url "${TORCH_INDEX_URL}" \ + -r docs/requirements.txt + - name: docstyle run: | make docstyle diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5c62ff1b..5692839a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,6 +16,8 @@ on: - tests/** - torchopt/** - .github/workflows/tests.yml + # Allow to trigger the workflow manually + workflow_dispatch: permissions: contents: read @@ -70,7 +72,7 @@ jobs: - name: Install TorchOpt run: | - python -m pip install -e . + python -m pip install -vvv -e . - name: Test with pytest run: | diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 88b7a202..73e1e60f 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -10,6 +10,10 @@ build: os: ubuntu-20.04 tools: python: mambaforge-4.10 + jobs: + post_install: + - python -m pip install --upgrade pip setuptools + - python -m pip install --no-build-isolation --editable . # Optionally declare the Python requirements required to build your docs conda: @@ -24,9 +28,3 @@ sphinx: builder: html configuration: docs/source/conf.py fail_on_warning: true - -# Optionally declare the Python requirements required to build your docs -python: - install: - - method: pip - path: . diff --git a/CHANGELOG.md b/CHANGELOG.md index 70cbe2e8..62d6ab83 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,29 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +### Changed + +### Fixed + +### Removed + +------ + +## [0.4.3] - 2022-08-08 + +### Added + +- Bump PyTorch version to 1.12.1 by [@XuehaiPan](https://github.com/XuehaiPan) in [#49](https://github.com/metaopt/TorchOpt/pull/49). +- CPU-only build without `nvcc` requirement by [@XuehaiPan](https://github.com/XuehaiPan) in [#51](https://github.com/metaopt/TorchOpt/pull/51). +- Use [`cibuildwheel`](https://github.com/pypa/cibuildwheel) to build wheels by [@XuehaiPan](https://github.com/XuehaiPan) in [#45](https://github.com/metaopt/TorchOpt/pull/45). +- Use dynamic process number in CPU kernels by [@JieRen98](https://github.com/JieRen98) in [#42](https://github.com/metaopt/TorchOpt/pull/42). + +### Changed + +- Use correct Python Ctype for pybind11 function prototype [@XuehaiPan](https://github.com/XuehaiPan) in [#52](https://github.com/metaopt/TorchOpt/pull/52). 
+ ------ ## [0.4.2] - 2022-07-26 @@ -51,7 +74,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ------ -[Unreleased]: https://github.com/olivierlacan/keep-a-changelog/compare/v0.4.2...HEAD +[Unreleased]: https://github.com/olivierlacan/keep-a-changelog/compare/v0.4.3...HEAD +[0.4.3]: https://github.com/olivierlacan/keep-a-changelog/compare/v0.4.2...v0.4.3 [0.4.2]: https://github.com/olivierlacan/keep-a-changelog/compare/v0.4.1...v0.4.2 [0.4.1]: https://github.com/olivierlacan/keep-a-changelog/compare/v0.4.0...v0.4.1 [0.4.0]: https://github.com/olivierlacan/keep-a-changelog/releases/tag/v0.4.0 diff --git a/CITATION.cff b/CITATION.cff index 60c65cb3..bb49226a 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -28,7 +28,7 @@ authors: family-names: Yang affiliation: Peking University email: yaodong.yang@pku.edu.cn -version: 0.4.2 -date-released: "2022-07-26" +version: 0.4.3 +date-released: "2022-08-08" license: Apache-2.0 repository-code: "https://github.com/metaopt/TorchOpt" diff --git a/CMakeLists.txt b/CMakeLists.txt index 523dc849..b4b5400c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,33 +13,43 @@ # limitations under the License. # ============================================================================== -cmake_minimum_required(VERSION 3.4) -project(torchopt LANGUAGES CXX CUDA) +cmake_minimum_required(VERSION 3.8) +project(torchopt LANGUAGES CXX) if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) endif() -find_package(CUDA REQUIRED) -cuda_select_nvcc_arch_flags(CUDA_ARCH_FLAGS All) -list(APPEND CUDA_NVCC_FLAGS ${CUDA_ARCH_FLAGS}) - set(CMAKE_CXX_STANDARD 14) -set(CMAKE_CUDA_STANDARD 14) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -pthread -fPIC -fopenmp") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") -set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -O3") + +find_package(CUDA) + +if(CUDA_FOUND) + message(STATUS "Found CUDA, enabling CUDA support.") + enable_language(CUDA) + + cuda_select_nvcc_arch_flags(CUDA_ARCH_FLAGS All) + list(APPEND CUDA_NVCC_FLAGS ${CUDA_ARCH_FLAGS}) + set(CMAKE_CUDA_STANDARD 14) + set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -O3") +else() + message(STATUS "CUDA not found, build for CPU-only.") +endif() function(system) set(options STRIP) set(oneValueArgs OUTPUT_VARIABLE ERROR_VARIABLE WORKING_DIRECTORY) set(multiValueArgs COMMAND) - cmake_parse_arguments(SYSTEM - "${options}" - "${oneValueArgs}" - "${multiValueArgs}" - "${ARGN}") + cmake_parse_arguments( + SYSTEM + "${options}" + "${oneValueArgs}" + "${multiValueArgs}" + "${ARGN}" + ) if(NOT DEFINED SYSTEM_WORKING_DIRECTORY) set(SYSTEM_WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}") @@ -51,6 +61,7 @@ function(system) ERROR_VARIABLE STDERR WORKING_DIRECTORY "${SYSTEM_WORKING_DIRECTORY}" ) + if("${SYSTEM_STRIP}") string(STRIP "${STDOUT}" STDOUT) string(STRIP "${STDERR}" STDERR) @@ -144,8 +155,7 @@ endif() unset(TORCH_LIBRARIES) foreach(VAR_PATH ${TORCH_LIBRARY_PATH}) - file(GLOB TORCH_LIBRARY "${VAR_PATH}/*.so") - list(APPEND TORCH_LIBRARIES "${TORCH_LIBRARY}") + list(APPEND TORCH_LIBRARIES "${VAR_PATH}/libtorch_python.so") endforeach() message(STATUS "Detected Torch libraries: \"${TORCH_LIBRARIES}\"") diff --git a/README.md b/README.md index c73ae163..cf48f0d0 100644 --- a/README.md +++ b/README.md @@ -219,12 +219,20 @@ Requirements - (Optional) For visualizing computation graphs - [Graphviz](https://graphviz.org/download/) (for Linux users use `apt/yum install graphviz` or `conda install -c 
anaconda python-graphviz`) -Please follow the instructions at to install PyTorch in your Python environment first. Then run the following command to install TorchOpt from PyPI ([![PyPI](https://img.shields.io/pypi/v/torchopt?label=PyPI)](https://pypi.org/project/torchopt) / ![Status](https://img.shields.io/pypi/status/torchopt?label=Status)): +**Please follow the instructions at to install PyTorch in your Python environment first.** Then run the following command to install TorchOpt from PyPI ([![PyPI](https://img.shields.io/pypi/v/torchopt?label=PyPI)](https://pypi.org/project/torchopt) / ![Status](https://img.shields.io/pypi/status/torchopt?label=Status)): ```bash pip3 install torchopt ``` +If the minimum version of PyTorch is not satisfied, `pip` will install/upgrade it for you. Please be careful about the `torch` build for CPU / CUDA support (e.g. `cpu`, `cu102`, `cu113`). You may need to specify the extra index URL for the `torch` package: + +```bash +pip3 install torchopt --extra-index-url https://download.pytorch.org/whl/cu116 +``` + +See for more information about installing PyTorch. + You can also build shared libraries from source, use: ```bash @@ -243,7 +251,7 @@ cd TorchOpt CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe.yaml conda activate torchopt -pip3 install -e . +pip3 install --no-build-isolation --editable . ``` -------------------------------------------------------------------------------- @@ -252,7 +260,6 @@ pip3 install -e . - [ ] Support general implicit differentiation with functional programing. - [ ] Support more optimizers such as AdamW, RMSProp -- [ ] CPU-accelerated optimizer ## Changelog diff --git a/conda-recipe.yaml b/conda-recipe.yaml index 3c10a3ed..ed082846 100644 --- a/conda-recipe.yaml +++ b/conda-recipe.yaml @@ -17,11 +17,11 @@ dependencies: - pip # Learning - - pytorch::pytorch = 1.12 + - pytorch::pytorch >= 1.12 - pytorch::torchvision - pytorch::pytorch-mutex = *=*cuda* - pip: - - functorch + - functorch >= 0.2 - torchviz - sphinxcontrib-katex # for documentation - jax @@ -76,7 +76,7 @@ dependencies: - mypy - flake8 - flake8-bugbear - - doc8 + - doc8 < 1.0.0a0 - pydocstyle - clang-format - clang-tools # clang-tidy diff --git a/docs/conda-recipe.yaml b/docs/conda-recipe.yaml index d55c0f19..3e9f51cb 100644 --- a/docs/conda-recipe.yaml +++ b/docs/conda-recipe.yaml @@ -30,12 +30,12 @@ dependencies: - pip # Learning - - pytorch::pytorch = 1.12 + - pytorch::pytorch >= 1.12 - pytorch::torchvision - pytorch::pytorch-mutex = *=*cpu* - pip: - jax[cpu] >= 0.3 - - functorch + - functorch >= 0.2 - torchviz - sphinxcontrib-katex # for documentation - tensorboard diff --git a/docs/requirements.txt b/docs/requirements.txt index 61b877af..8837b2a9 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,7 +1,7 @@ ---extra-index-url https://download.pytorch.org/whl/cu116 -torch == 1.12 +--extra-index-url https://download.pytorch.org/whl/cpu +torch >= 1.12 torchvision -functorch +functorch >= 0.2 --requirement ../requirements.txt diff --git a/docs/source/developer/contributing.rst b/docs/source/developer/contributing.rst index 278e2900..302b9fb3 100644 --- a/docs/source/developer/contributing.rst +++ b/docs/source/developer/contributing.rst @@ -33,17 +33,17 @@ Then you are ready to rock. Thanks for contributing to TorchOpt! Install Develop Version ----------------------- -To install TorchOpt in an "editable" mode, run +To install TorchOpt in an "editable" mode, run: .. code-block:: bash - pip install -e . 
+ pip3 install --no-build-isolation --editable . -in the main directory. This installation is removable by +in the main directory. This installation is removable by: .. code-block:: bash - python setup.py develop --uninstall + pip3 uninstall torchopt Lint Check @@ -75,13 +75,46 @@ To check if everything conforms to the specification, run: Test Locally ------------ -This command will run automatic tests in the main directory +This command will run automatic tests in the main directory: .. code-block:: bash $ make test +Build Wheels +------------ + +To build compatible **manylinux2014** (:pep:`599`) wheels for distribution, you can use |cibuildwheel|_. You will need to install |docker|_ first. Then run the following command: + +.. code-block:: bash + + pip3 install --upgrade cibuildwheel + + export TEST_TORCH_SPECS="cpu cu113 cu116" # `torch` builds for testing + export CUDA_VERSION="11.6" # version of `nvcc` for compilation + python3 -m cibuildwheel --platform=linux --output-dir=wheelhouse --config-file=pyproject.toml + +It will install the CUDA compiler with ``CUDA_VERSION`` in the build container. Then build wheel binaries for all supported CPython versions. The outputs will be placed in the ``wheelhouse`` directory. + +To build a wheel for a specific CPython version, you can use the |CIBW_BUILD|_ environment variable. +For example, the following command will build a wheel for Python 3.7: + +.. code-block:: bash + + CIBW_BUILD="cp37*manylinux*" python3 -m cibuildwheel --platform=linux --output-dir=wheelhouse --config-file=pyproject.toml + +You can change ``cp37*`` to ``cp310*`` to build for Python 3.10. See https://cibuildwheel.readthedocs.io/en/stable/options for more options. + +.. |cibuildwheel| replace:: ``cibuildwheel`` +.. _cibuildwheel: https://github.com/pypa/cibuildwheel + +.. |CIBW_BUILD| replace:: ``CIBW_BUILD`` +.. _CIBW_BUILD: https://cibuildwheel.readthedocs.io/en/stable/options/#build-skip + +.. |docker| replace:: ``docker`` +.. _docker: https://www.docker.com + Documentation ------------- diff --git a/docs/source/index.rst b/docs/source/index.rst index 892a1090..157eb5ad 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -11,12 +11,12 @@ TorchOpt Installation ------------ -Requirements +Requirements: -- PyTorch -- JAX -- (Optional) For visualizing computation graphs - - `Graphviz `_ (for Linux users use ``apt/yum install graphviz`` or ``conda install -c anaconda python-graphviz``) +* PyTorch +* JAX +* (Optional) For visualizing computation graphs + * `Graphviz `_ (for Linux users use ``apt/yum install graphviz`` or ``conda install -c anaconda python-graphviz``) Please follow the instructions at https://pytorch.org to install PyTorch in your Python environment first. 
Then run the following command to install TorchOpt from PyPI: diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt index db1e67a1..4fc50e3c 100644 --- a/docs/source/spelling_wordlist.txt +++ b/docs/source/spelling_wordlist.txt @@ -69,3 +69,4 @@ sgd SGD CHANGELOG Changelog +CPython diff --git a/examples/MAML-RL/maml.png b/examples/MAML-RL/maml.png index 8aaad571..221462c8 100644 Binary files a/examples/MAML-RL/maml.png and b/examples/MAML-RL/maml.png differ diff --git a/examples/MAML-RL/maml.py b/examples/MAML-RL/maml.py index 8734e000..f2bb38e9 100644 --- a/examples/MAML-RL/maml.py +++ b/examples/MAML-RL/maml.py @@ -70,7 +70,7 @@ def sample_traj(env, task, policy): next_obs_buf[step][batch] = next_ob acs_buf[step][batch] = ac rews_buf[step][batch] = rew - gammas_buf[step][batch] = done * GAMMA + gammas_buf[step][batch] = (1 - done) * GAMMA ob = next_ob return Traj( obs=obs_buf, @@ -99,7 +99,6 @@ def a2c_loss(traj, policy, value_coef): advs = lambda_returns - torch.squeeze(values, -1) action_loss = -(advs.detach() * log_probs).mean() value_loss = advs.pow(2).mean() - a2c_loss = action_loss + value_coef * value_loss return a2c_loss @@ -107,7 +106,7 @@ def a2c_loss(traj, policy, value_coef): def evaluate(env, seed, task_num, policy): pre_reward_ls = [] post_reward_ls = [] - inner_opt = torchopt.MetaSGD(policy, lr=0.5) + inner_opt = torchopt.MetaSGD(policy, lr=0.1) env = gym.make( 'TabularMDP-v0', **dict( @@ -123,7 +122,6 @@ def evaluate(env, seed, task_num, policy): for idx in range(task_num): for _ in range(inner_iters): pre_trajs = sample_traj(env, tasks[idx], policy) - inner_loss = a2c_loss(pre_trajs, policy, value_coef=0.5) inner_opt.step(inner_loss) post_trajs = sample_traj(env, tasks[idx], policy) @@ -153,7 +151,7 @@ def main(args): ) # Policy policy = CategoricalMLPPolicy(input_size=STATE_DIM, output_size=ACTION_DIM) - inner_opt = torchopt.MetaSGD(policy, lr=0.5) + inner_opt = torchopt.MetaSGD(policy, lr=0.1) outer_opt = optim.Adam(policy.parameters(), lr=1e-3) train_pre_reward = [] train_post_reward = [] @@ -170,7 +168,6 @@ def main(args): policy_state_dict = torchopt.extract_state_dict(policy) optim_state_dict = torchopt.extract_state_dict(inner_opt) for idx in range(TASK_NUM): - for _ in range(inner_iters): pre_trajs = sample_traj(env, tasks[idx], policy) inner_loss = a2c_loss(pre_trajs, policy, value_coef=0.5) diff --git a/examples/requirements.txt b/examples/requirements.txt index 9e2e108e..9055e15e 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -1,7 +1,7 @@ --extra-index-url https://download.pytorch.org/whl/cu116 -torch == 1.12 +torch >= 1.12 torchvision -functorch +functorch >= 0.2 --requirement ../requirements.txt diff --git a/include/adam_op/adam_op.h b/include/adam_op/adam_op.h index 38ebd0cc..77ae6ca9 100644 --- a/include/adam_op/adam_op.h +++ b/include/adam_op/adam_op.h @@ -23,32 +23,35 @@ namespace torchopt { TensorArray<3> adamForwardInplace(const torch::Tensor& updates, const torch::Tensor& mu, - const torch::Tensor& nu, const float b1, - const float b2, const float eps, - const float eps_root, const int count); + const torch::Tensor& nu, const pyfloat_t b1, + const pyfloat_t b2, const pyfloat_t eps, + const pyfloat_t eps_root, + const pyuint_t count); torch::Tensor adamForwardMu(const torch::Tensor& updates, - const torch::Tensor& mu, const float b1); + const torch::Tensor& mu, const pyfloat_t b1); torch::Tensor adamForwardNu(const torch::Tensor& updates, - const torch::Tensor& nu, const float b2); + const 
torch::Tensor& nu, const pyfloat_t b2); torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu, - const torch::Tensor& new_nu, const float b1, - const float b2, const float eps, - const float eps_root, const int count); + const torch::Tensor& new_nu, + const pyfloat_t b1, const pyfloat_t b2, + const pyfloat_t eps, const pyfloat_t eps_root, + const pyuint_t count); TensorArray<2> adamBackwardMu(const torch::Tensor& dmu, const torch::Tensor& updates, - const torch::Tensor& mu, const float b1); + const torch::Tensor& mu, const pyfloat_t b1); TensorArray<2> adamBackwardNu(const torch::Tensor& dnu, const torch::Tensor& updates, - const torch::Tensor& nu, const float b2); + const torch::Tensor& nu, const pyfloat_t b2); TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates, const torch::Tensor& updates, const torch::Tensor& new_mu, - const torch::Tensor& new_nu, const float b1, - const float b2, const int count); + const torch::Tensor& new_nu, + const pyfloat_t b1, const pyfloat_t b2, + const pyuint_t count); } // namespace torchopt diff --git a/include/adam_op/adam_op_impl.h b/include/adam_op/adam_op_impl_cpu.h similarity index 67% rename from include/adam_op/adam_op_impl.h rename to include/adam_op/adam_op_impl_cpu.h index 87562fb1..c65408db 100644 --- a/include/adam_op/adam_op_impl.h +++ b/include/adam_op/adam_op_impl_cpu.h @@ -21,35 +21,36 @@ #include "include/common.h" namespace torchopt { -TensorArray<3> adamForwardInplaceCPU(const torch::Tensor& updates, - const torch::Tensor& mu, - const torch::Tensor& nu, const float b1, - const float b2, const float eps, - const float eps_root, const int count); +TensorArray<3> adamForwardInplaceCPU( + const torch::Tensor& updates, const torch::Tensor& mu, + const torch::Tensor& nu, const pyfloat_t b1, const pyfloat_t b2, + const pyfloat_t eps, const pyfloat_t eps_root, const pyuint_t count); torch::Tensor adamForwardMuCPU(const torch::Tensor& updates, - const torch::Tensor& mu, const float b1); + const torch::Tensor& mu, const pyfloat_t b1); torch::Tensor adamForwardNuCPU(const torch::Tensor& updates, - const torch::Tensor& nu, const float b2); + const torch::Tensor& nu, const pyfloat_t b2); torch::Tensor adamForwardUpdatesCPU(const torch::Tensor& new_mu, - const torch::Tensor& new_nu, const float b1, - const float b2, const float eps, - const float eps_root, const int count); + const torch::Tensor& new_nu, + const pyfloat_t b1, const pyfloat_t b2, + const pyfloat_t eps, + const pyfloat_t eps_root, + const pyuint_t count); TensorArray<2> adamBackwardMuCPU(const torch::Tensor& dmu, const torch::Tensor& updates, - const torch::Tensor& mu, const float b1); + const torch::Tensor& mu, const pyfloat_t b1); TensorArray<2> adamBackwardNuCPU(const torch::Tensor& dnu, const torch::Tensor& updates, - const torch::Tensor& nu, const float b2); + const torch::Tensor& nu, const pyfloat_t b2); TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor& dupdates, const torch::Tensor& updates, const torch::Tensor& new_mu, const torch::Tensor& new_nu, - const float b1, const float b2, - const int count); + const pyfloat_t b1, const pyfloat_t b2, + const pyuint_t count); } // namespace torchopt diff --git a/include/adam_op/adam_op_impl.cuh b/include/adam_op/adam_op_impl_cuda.cuh similarity index 67% rename from include/adam_op/adam_op_impl.cuh rename to include/adam_op/adam_op_impl_cuda.cuh index c9dcba85..406374aa 100644 --- a/include/adam_op/adam_op_impl.cuh +++ b/include/adam_op/adam_op_impl_cuda.cuh @@ -21,36 +21,36 @@ #include "include/common.h" 
namespace torchopt { -TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates, - const torch::Tensor &mu, - const torch::Tensor &nu, const float b1, - const float b2, const float eps, - const float eps_root, const int count); +TensorArray<3> adamForwardInplaceCUDA( + const torch::Tensor &updates, const torch::Tensor &mu, + const torch::Tensor &nu, const pyfloat_t b1, const pyfloat_t b2, + const pyfloat_t eps, const pyfloat_t eps_root, const pyuint_t count); torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates, - const torch::Tensor &mu, const float b1); + const torch::Tensor &mu, const pyfloat_t b1); torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates, - const torch::Tensor &nu, const float b2); + const torch::Tensor &nu, const pyfloat_t b2); torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu, const torch::Tensor &new_nu, - const float b1, const float b2, - const float eps, const float eps_root, - const int count); + const pyfloat_t b1, const pyfloat_t b2, + const pyfloat_t eps, + const pyfloat_t eps_root, + const pyuint_t count); TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu, const torch::Tensor &updates, - const torch::Tensor &mu, const float b1); + const torch::Tensor &mu, const pyfloat_t b1); TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu, const torch::Tensor &updates, - const torch::Tensor &nu, const float b2); + const torch::Tensor &nu, const pyfloat_t b2); TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates, const torch::Tensor &updates, const torch::Tensor &new_mu, const torch::Tensor &new_nu, - const float b1, const float b2, - const int count); + const pyfloat_t b1, const pyfloat_t b2, + const pyuint_t count); } // namespace torchopt diff --git a/include/common.h b/include/common.h index e4362013..801a85a4 100644 --- a/include/common.h +++ b/include/common.h @@ -17,6 +17,10 @@ #include #include +#include + +using pyfloat_t = double; +using pyuint_t = std::size_t; namespace torchopt { template diff --git a/pyproject.toml b/pyproject.toml index d76dd3dc..d60dcc99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,7 @@ +# Package ###################################################################### + [build-system] -requires = ["setuptools", "torch == 1.12", "numpy", "pybind11"] +requires = ["setuptools", "torch >= 1.12", "numpy", "pybind11"] build-backend = "setuptools.build_meta" [project] @@ -8,13 +10,13 @@ description = "A Jax-style optimizer for PyTorch." 
readme = "README.md" requires-python = ">= 3.7" authors = [ - {name = "TorchOpt Contributors"}, - {name = "Xuehai Pan", email = "XuehaiPan@pku.edu.cn"}, - {name = "Jie Ren", email = "jieren9806@gmail.com"}, - {name = "Xidong Feng", email = "xidong.feng.20@ucl.ac.uk"}, - {name = "Bo Liu", email = "benjaminliu.eecs@gmail.com"}, + { name = "TorchOpt Contributors" }, + { name = "Xuehai Pan", email = "XuehaiPan@pku.edu.cn" }, + { name = "Jie Ren", email = "jieren9806@gmail.com" }, + { name = "Xidong Feng", email = "xidong.feng.20@ucl.ac.uk" }, + { name = "Bo Liu", email = "benjaminliu.eecs@gmail.com" }, ] -license = {file = "LICENSE"} +license = { text = "Apache License, Version 2.0" } keywords = [ "PyTorch", "functorch", @@ -26,7 +28,7 @@ keywords = [ ] classifiers = [ "Development Status :: 4 - Beta", - "License :: OSI Approved :: Apache Software License 2.0 (Apache-2.0)", + "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", @@ -42,15 +44,13 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dependencies = [ - "torch == 1.12", + "torch >= 1.12", "jax[cpu] >= 0.3", "numpy", "graphviz", "typing-extensions", ] -dynamic = [ - "version", -] +dynamic = ["version"] [project.urls] Homepage = "https://github.com/metaopt/TorchOpt" @@ -58,9 +58,99 @@ Repository = "https://github.com/metaopt/TorchOpt" Documentation = "https://torchopt.readthedocs.io" "Bug Report" = "https://github.com/metaopt/TorchOpt/issues" +[project.optional-dependencies] +lint = [ + "isort", + "black >= 22.6.0", + "pylint", + "mypy", + "flake8", + "flake8-bugbear", + "doc8 < 1.0.0a0", + "pydocstyle", + "pyenchant", + "cpplint", + "pre-commit", +] +test = [ + 'torchvision', + 'functorch >= 0.2', + 'pytest', + 'pytest-cov', + 'pytest-xdist', +] + [tool.setuptools.packages.find] include = ["torchopt", "torchopt.*"] +# Wheel builder ################################################################ +# Reference: https://cibuildwheel.readthedocs.io +[tool.cibuildwheel] +archs = ["x86_64"] +build = "*manylinux*" +skip = "pp*" +build-frontend = "pip" +build-verbosity = 3 +environment.CUDACXX = "/usr/local/cuda/bin/nvcc" +environment.DEFAULT_CUDA_VERSION = "11.6" +environment.DEFAULT_TEST_TORCH_SPECS = "cpu cu113 cu116" +environment-pass = ["CUDA_VERSION", "TEST_TORCH_SPECS"] +container-engine = "docker" + +before-all = """ + CUDA_VERSION="${CUDA_VERSION:-"${DEFAULT_CUDA_VERSION}"}" + if [[ "${CUDA_VERSION}" == "None" || "${CUDA_VERSION}" == "none" ]]; then + sed -i -E "s/__version__\\s*=\\s*.*$/\\0 + '+cpu'/" torchopt/version.py + else + CUDA_VERSION="$(echo "${CUDA_VERSION}" | cut -d"." -f-2)" + CUDA_PKG_SUFFIX="$(echo "${CUDA_VERSION}" | tr "." 
"-")" + echo "CUDA_VERSION=${CUDA_VERSION}" + yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo + yum clean all + yum install -y nvidia-driver-latest-libs "cuda-minimal-build-${CUDA_PKG_SUFFIX}" + fi + echo "cat torchopt/version.py"; cat torchopt/version.py + """ +test-extras = ["test"] +test-command = """ + SITE_PACKAGES="$(python -c 'print(__import__("sysconfig").get_path("purelib"))')" + TORCH_LIB_PATH="${SITE_PACKAGES}/torch/lib" + echo "LD_LIBRARY_PATH='${LD_LIBRARY_PATH}'" + echo "ls ${TORCH_LIB_PATH}"; ls -lh "${TORCH_LIB_PATH}" + find "${SITE_PACKAGES}/torchopt" -name "*.so" -print0 | + xargs -0 -I '{}' bash -c "echo 'ldd {}'; ldd '{}'; echo 'patchelf --print-rpath {}'; patchelf --print-rpath '{}'" + make -C "{project}" test || exit 1 + TORCH_VERSION="$(python -c 'print(__import__("torch").__version__.partition("+")[0])')" + TEST_TORCH_SPECS="${TEST_TORCH_SPECS:-"${DEFAULT_TEST_TORCH_SPECS}"}" + for spec in ${TEST_TORCH_SPECS}; do + python -m pip uninstall -y torch torchvision + export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/${spec}" + echo "PIP_EXTRA_INDEX_URL='${PIP_EXTRA_INDEX_URL}'" + python -m pip install "torch==${TORCH_VERSION}" torchvision + echo "ls ${TORCH_LIB_PATH}"; ls -lh "${TORCH_LIB_PATH}" + find "${SITE_PACKAGES}/torchopt" -name "*.so" -print0 | + xargs -0 -I '{}' bash -c "echo 'ldd {}'; ldd '{}'; echo 'patchelf --print-rpath {}'; patchelf --print-rpath '{}'" + make -C "{project}" test || exit 1 + done + rm -rf ~/.pip/cache ~/.cache/pip + """ + +[tool.cibuildwheel.linux] +repair-wheel-command = """ + python -m pip install -r requirements.txt + SITE_PACKAGES="$(python -c 'print(__import__("sysconfig").get_path("purelib"))')" + TORCH_LIB_PATH="${SITE_PACKAGES}/torch/lib" + ( + export LD_LIBRARY_PATH="${TORCH_LIB_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/targets/x86_64-linux/lib/stubs${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" + echo "ls ${TORCH_LIB_PATH}"; ls -lh "${TORCH_LIB_PATH}" + python -m pip install --force-reinstall git+https://github.com/XuehaiPan/auditwheel.git@torchopt + python -m auditwheel lddtree "{wheel}" + python -m auditwheel repair --no-copy-site-libs --wheel-dir="{dest_dir}" "{wheel}" + ) + """ + +# Linter tools ################################################################# + [tool.black] safe = true line-length = 100 diff --git a/requirements.txt b/requirements.txt index 21fb120c..539a8145 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -torch == 1.12 +torch >= 1.12 jax[cpu] >= 0.3 numpy graphviz diff --git a/setup.py b/setup.py index 169a767c..a3dfe441 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ import shutil import sys -from setuptools import find_packages, setup +from setuptools import setup try: diff --git a/src/adam_op/CMakeLists.txt b/src/adam_op/CMakeLists.txt index cea0371b..3c7226e3 100644 --- a/src/adam_op/CMakeLists.txt +++ b/src/adam_op/CMakeLists.txt @@ -13,36 +13,13 @@ # limitations under the License. 
# ============================================================================== -# add_library( -# adam_op_CUDA SHARED -# adam_op_impl.cu -# ) +set(adam_op_src adam_op.cpp adam_op_impl_cpu.cpp) -# target_link_libraries( -# adam_op_CUDA -# ${TORCH_LIBRARIES} -# ) +if(CUDA_FOUND) + list(APPEND adam_op_src adam_op_impl_cuda.cu) +endif() -# add_library( -# adam_op_CPU SHARED -# adam_op_impl.cpp -# ) - -# target_link_libraries( -# adam_op_CPU -# ${TORCH_LIBRARIES} -# ) - -# pybind11_add_module(adam_op adam_op.cpp) - -# target_link_libraries( -# adam_op PRIVATE -# adam_op_CPU -# adam_op_CUDA -# ${TORCH_LIBRARIES} -# ) - -pybind11_add_module(adam_op adam_op.cpp adam_op_impl.cpp adam_op_impl.cu) +pybind11_add_module(adam_op "${adam_op_src}") target_link_libraries( adam_op PRIVATE diff --git a/src/adam_op/adam_op.cpp b/src/adam_op/adam_op.cpp index a11c0116..bb4531ac 100644 --- a/src/adam_op/adam_op.cpp +++ b/src/adam_op/adam_op.cpp @@ -18,29 +18,38 @@ #include #include -#include "include/adam_op/adam_op_impl.cuh" -#include "include/adam_op/adam_op_impl.h" +#include "include/adam_op/adam_op_impl_cpu.h" +#if defined(__CUDACC__) +#include "include/adam_op/adam_op_impl_cuda.cuh" +#endif namespace torchopt { TensorArray<3> adamForwardInplace(const torch::Tensor& updates, const torch::Tensor& mu, - const torch::Tensor& nu, const float b1, - const float b2, const float eps, - const float eps_root, const int count) { + const torch::Tensor& nu, const pyfloat_t b1, + const pyfloat_t b2, const pyfloat_t eps, + const pyfloat_t eps_root, + const pyuint_t count) { +#if defined(__CUDACC__) if (updates.device().is_cuda()) { return adamForwardInplaceCUDA(updates, mu, nu, b1, b2, eps, eps_root, count); - } else if (updates.device().is_cpu()) { + } +#endif + if (updates.device().is_cpu()) { return adamForwardInplaceCPU(updates, mu, nu, b1, b2, eps, eps_root, count); } else { throw std::runtime_error("Not implemented"); } } torch::Tensor adamForwardMu(const torch::Tensor& updates, - const torch::Tensor& mu, const float b1) { + const torch::Tensor& mu, const pyfloat_t b1) { +#if defined(__CUDACC__) if (updates.device().is_cuda()) { return adamForwardMuCUDA(updates, mu, b1); - } else if (updates.device().is_cpu()) { + } +#endif + if (updates.device().is_cpu()) { return adamForwardMuCPU(updates, mu, b1); } else { throw std::runtime_error("Not implemented"); @@ -48,10 +57,13 @@ torch::Tensor adamForwardMu(const torch::Tensor& updates, } torch::Tensor adamForwardNu(const torch::Tensor& updates, - const torch::Tensor& nu, const float b2) { + const torch::Tensor& nu, const pyfloat_t b2) { +#if defined(__CUDACC__) if (updates.device().is_cuda()) { return adamForwardNuCUDA(updates, nu, b2); - } else if (updates.device().is_cpu()) { + } +#endif + if (updates.device().is_cpu()) { return adamForwardNuCPU(updates, nu, b2); } else { throw std::runtime_error("Not implemented"); @@ -59,12 +71,16 @@ torch::Tensor adamForwardNu(const torch::Tensor& updates, } torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu, - const torch::Tensor& new_nu, const float b1, - const float b2, const float eps, - const float eps_root, const int count) { + const torch::Tensor& new_nu, + const pyfloat_t b1, const pyfloat_t b2, + const pyfloat_t eps, const pyfloat_t eps_root, + const pyuint_t count) { +#if defined(__CUDACC__) if (new_mu.device().is_cuda()) { return adamForwardUpdatesCUDA(new_mu, new_nu, b1, b2, eps, eps_root, count); - } else if (new_mu.device().is_cpu()) { + } +#endif + if (new_mu.device().is_cpu()) { return 
adamForwardUpdatesCPU(new_mu, new_nu, b1, b2, eps, eps_root, count); } else { throw std::runtime_error("Not implemented"); @@ -73,10 +89,13 @@ torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu, TensorArray<2> adamBackwardMu(const torch::Tensor& dmu, const torch::Tensor& updates, - const torch::Tensor& mu, const float b1) { + const torch::Tensor& mu, const pyfloat_t b1) { +#if defined(__CUDACC__) if (dmu.device().is_cuda()) { return adamBackwardMuCUDA(dmu, updates, mu, b1); - } else if (dmu.device().is_cpu()) { + } +#endif + if (dmu.device().is_cpu()) { return adamBackwardMuCPU(dmu, updates, mu, b1); } else { throw std::runtime_error("Not implemented"); @@ -85,10 +104,13 @@ TensorArray<2> adamBackwardMu(const torch::Tensor& dmu, TensorArray<2> adamBackwardNu(const torch::Tensor& dnu, const torch::Tensor& updates, - const torch::Tensor& nu, const float b2) { + const torch::Tensor& nu, const pyfloat_t b2) { +#if defined(__CUDACC__) if (dnu.device().is_cuda()) { return adamBackwardNuCUDA(dnu, updates, nu, b2); - } else if (dnu.device().is_cpu()) { + } +#endif + if (dnu.device().is_cpu()) { return adamBackwardNuCPU(dnu, updates, nu, b2); } else { throw std::runtime_error("Not implemented"); @@ -98,12 +120,16 @@ TensorArray<2> adamBackwardNu(const torch::Tensor& dnu, TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates, const torch::Tensor& updates, const torch::Tensor& new_mu, - const torch::Tensor& new_nu, const float b1, - const float b2, const int count) { + const torch::Tensor& new_nu, + const pyfloat_t b1, const pyfloat_t b2, + const pyuint_t count) { +#if defined(__CUDACC__) if (dupdates.device().is_cuda()) { return adamBackwardUpdatesCUDA(dupdates, updates, new_mu, new_nu, b1, b2, count); - } else if (dupdates.device().is_cpu()) { + } +#endif + if (dupdates.device().is_cpu()) { return adamBackwardUpdatesCPU(dupdates, updates, new_mu, new_nu, b1, b2, count); } else { diff --git a/src/adam_op/adam_op_impl.cpp b/src/adam_op/adam_op_impl_cpu.cpp similarity index 77% rename from src/adam_op/adam_op_impl.cpp rename to src/adam_op/adam_op_impl_cpu.cpp index 16be5251..c50a4cd4 100644 --- a/src/adam_op/adam_op_impl.cpp +++ b/src/adam_op/adam_op_impl_cpu.cpp @@ -13,7 +13,7 @@ // limitations under the License. 
// ============================================================================== -#include "include/adam_op/adam_op_impl.h" +#include "include/adam_op/adam_op_impl_cpu.h" #include #include @@ -31,7 +31,7 @@ void adamForwardInplaceCPUKernel( const other_t inv_one_minus_pow_b2, const other_t eps, const other_t eps_root, const size_t n, scalar_t* __restrict__ updates_ptr, scalar_t* __restrict__ mu_ptr, scalar_t* __restrict__ nu_ptr) { -#pragma omp parallel for num_threads(32) +#pragma omp parallel for num_threads(omp_get_num_procs()) for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t mu = mu_ptr[tid]; @@ -50,20 +50,20 @@ void adamForwardInplaceCPUKernel( } } // namespace -TensorArray<3> adamForwardInplaceCPU(const torch::Tensor& updates, - const torch::Tensor& mu, - const torch::Tensor& nu, const float b1, - const float b2, const float eps, - const float eps_root, const int count) { - using other_t = float; - const float inv_one_minus_pow_b1 = 1 / (1 - std::pow(b1, count)); - const float inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count)); +TensorArray<3> adamForwardInplaceCPU( + const torch::Tensor& updates, const torch::Tensor& mu, + const torch::Tensor& nu, const pyfloat_t b1, const pyfloat_t b2, + const pyfloat_t eps, const pyfloat_t eps_root, const pyuint_t count) { + using other_t = pyfloat_t; + const other_t inv_one_minus_pow_b1 = 1 / (1 - std::pow(b1, count)); + const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count)); const size_t n = getTensorPlainSize(updates); AT_DISPATCH_FLOATING_TYPES( updates.scalar_type(), "adamForwardInplaceCPU", ([&] { - adamForwardInplaceCPUKernel( - b1, inv_one_minus_pow_b1, b2, inv_one_minus_pow_b2, eps, eps_root, + adamForwardInplaceCPUKernel( + scalar_t(b1), scalar_t(inv_one_minus_pow_b1), scalar_t(b2), + scalar_t(inv_one_minus_pow_b2), scalar_t(eps), scalar_t(eps_root), n, updates.data_ptr(), mu.data_ptr(), nu.data_ptr()); })); @@ -76,7 +76,7 @@ void adamForwardMuCPUKernel(const scalar_t* __restrict__ updates_ptr, const scalar_t* __restrict__ mu_ptr, const other_t b1, const size_t n, scalar_t* __restrict__ mu_out_ptr) { -#pragma omp parallel for num_threads(32) +#pragma omp parallel for num_threads(omp_get_num_procs()) for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t mu = mu_ptr[tid]; @@ -87,16 +87,14 @@ void adamForwardMuCPUKernel(const scalar_t* __restrict__ updates_ptr, } // namespace torch::Tensor adamForwardMuCPU(const torch::Tensor& updates, - const torch::Tensor& mu, const float b1) { - using other_t = float; - + const torch::Tensor& mu, const pyfloat_t b1) { auto mu_out = torch::empty_like(mu); const size_t n = getTensorPlainSize(updates); AT_DISPATCH_FLOATING_TYPES(updates.scalar_type(), "adamForwardMuCPU", ([&] { - adamForwardMuCPUKernel( + adamForwardMuCPUKernel( updates.data_ptr(), - mu.data_ptr(), b1, n, + mu.data_ptr(), scalar_t(b1), n, mu_out.data_ptr()); })); return mu_out; @@ -108,7 +106,7 @@ void adamForwardNuCPUKernel(const scalar_t* __restrict__ updates_ptr, const scalar_t* __restrict__ nu_ptr, const other_t b2, const size_t n, scalar_t* __restrict__ nu_out_ptr) { -#pragma omp parallel for num_threads(32) +#pragma omp parallel for num_threads(omp_get_num_procs()) for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t nu = nu_ptr[tid]; @@ -120,16 +118,14 @@ void adamForwardNuCPUKernel(const scalar_t* __restrict__ updates_ptr, } // namespace torch::Tensor adamForwardNuCPU(const 
torch::Tensor& updates, - const torch::Tensor& nu, const float b2) { - using other_t = float; - + const torch::Tensor& nu, const pyfloat_t b2) { auto nu_out = torch::empty_like(nu); const size_t n = getTensorPlainSize(updates); AT_DISPATCH_FLOATING_TYPES(updates.scalar_type(), "adamForwardNuCPU", ([&] { - adamForwardNuCPUKernel( + adamForwardNuCPUKernel( updates.data_ptr(), - nu.data_ptr(), b2, n, + nu.data_ptr(), scalar_t(b2), n, nu_out.data_ptr()); })); return nu_out; @@ -144,7 +140,7 @@ void adamForwardUpdatesCPUKernel(const scalar_t* __restrict__ new_mu_ptr, const other_t eps, const other_t eps_root, const size_t n, scalar_t* __restrict__ updates_out_ptr) { -#pragma omp parallel for num_threads(32) +#pragma omp parallel for num_threads(omp_get_num_procs()) for (size_t tid = 0; tid < n; ++tid) { const scalar_t new_mu = new_mu_ptr[tid]; const scalar_t new_nu = new_nu_ptr[tid]; @@ -156,10 +152,12 @@ void adamForwardUpdatesCPUKernel(const scalar_t* __restrict__ new_mu_ptr, } // namespace torch::Tensor adamForwardUpdatesCPU(const torch::Tensor& new_mu, - const torch::Tensor& new_nu, const float b1, - const float b2, const float eps, - const float eps_root, const int count) { - using other_t = float; + const torch::Tensor& new_nu, + const pyfloat_t b1, const pyfloat_t b2, + const pyfloat_t eps, + const pyfloat_t eps_root, + const pyuint_t count) { + using other_t = pyfloat_t; auto updates_out = torch::empty_like(new_mu); @@ -171,9 +169,10 @@ torch::Tensor adamForwardUpdatesCPU(const torch::Tensor& new_mu, const size_t n = getTensorPlainSize(new_mu); AT_DISPATCH_FLOATING_TYPES( new_mu.scalar_type(), "adamForwardUpdatesCPU", ([&] { - adamForwardUpdatesCPUKernel( + adamForwardUpdatesCPUKernel( new_mu.data_ptr(), new_nu.data_ptr(), - inv_one_minus_pow_b1, inv_one_minus_pow_b2, eps, eps_root, n, + scalar_t(inv_one_minus_pow_b1), scalar_t(inv_one_minus_pow_b2), + scalar_t(eps), scalar_t(eps_root), n, updates_out.data_ptr()); })); return updates_out; @@ -185,7 +184,7 @@ void adamBackwardMuCPUKernel(const scalar_t* __restrict__ dmu_ptr, const other_t b1, const size_t n, scalar_t* __restrict__ dupdates_out_ptr, scalar_t* __restrict__ dmu_out_ptr) { -#pragma omp parallel for num_threads(32) +#pragma omp parallel for num_threads(omp_get_num_procs()) for (size_t tid = 0; tid < n; ++tid) { const scalar_t dmu = dmu_ptr[tid]; @@ -197,16 +196,14 @@ void adamBackwardMuCPUKernel(const scalar_t* __restrict__ dmu_ptr, TensorArray<2> adamBackwardMuCPU(const torch::Tensor& dmu, const torch::Tensor& updates, - const torch::Tensor& mu, const float b1) { - using other_t = float; - + const torch::Tensor& mu, const pyfloat_t b1) { auto dupdates_out = torch::empty_like(updates); auto dmu_out = torch::empty_like(mu); const size_t n = getTensorPlainSize(dmu); AT_DISPATCH_FLOATING_TYPES(dmu.scalar_type(), "adamBackwardMuCPU", ([&] { - adamBackwardMuCPUKernel( - dmu.data_ptr(), b1, n, + adamBackwardMuCPUKernel( + dmu.data_ptr(), scalar_t(b1), n, dupdates_out.data_ptr(), dmu_out.data_ptr()); })); @@ -220,7 +217,7 @@ void adamBackwardNuCPUKernel(const scalar_t* __restrict__ dnu_ptr, const other_t b2, const size_t n, scalar_t* __restrict__ dupdates_out_ptr, scalar_t* __restrict__ dnu_out_ptr) { -#pragma omp parallel for num_threads(32) +#pragma omp parallel for num_threads(omp_get_num_procs()) for (size_t tid = 0; tid < n; ++tid) { const scalar_t dnu = dnu_ptr[tid]; const scalar_t updates = updates_ptr[tid]; @@ -233,19 +230,18 @@ void adamBackwardNuCPUKernel(const scalar_t* __restrict__ dnu_ptr, TensorArray<2> 
adamBackwardNuCPU(const torch::Tensor& dnu, const torch::Tensor& updates, - const torch::Tensor& nu, const float b2) { - using other_t = float; - + const torch::Tensor& nu, const pyfloat_t b2) { auto dupdates_out = torch::empty_like(updates); auto dnu_out = torch::empty_like(nu); const size_t n = getTensorPlainSize(dnu); - AT_DISPATCH_FLOATING_TYPES( - dnu.scalar_type(), "adamForwardNuCPU", ([&] { - adamBackwardNuCPUKernel( - dnu.data_ptr(), updates.data_ptr(), b2, n, - dupdates_out.data_ptr(), dnu_out.data_ptr()); - })); + AT_DISPATCH_FLOATING_TYPES(dnu.scalar_type(), "adamForwardNuCPU", ([&] { + adamBackwardNuCPUKernel( + dnu.data_ptr(), + updates.data_ptr(), scalar_t(b2), + n, dupdates_out.data_ptr(), + dnu_out.data_ptr()); + })); return TensorArray<2>{std::move(dupdates_out), std::move(dnu_out)}; } @@ -259,7 +255,7 @@ void adamBackwardUpdatesCPUKernel(const scalar_t* __restrict__ dupdates_ptr, const size_t n, scalar_t* __restrict__ dnew_mu_out_ptr, scalar_t* __restrict__ dnew_nu_out_ptr) { -#pragma omp parallel for num_threads(32) +#pragma omp parallel for num_threads(omp_get_num_procs()) for (size_t tid = 0; tid < n; ++tid) { const scalar_t dupdates = dupdates_ptr[tid]; const scalar_t updates = updates_ptr[tid]; @@ -286,9 +282,9 @@ TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor& dupdates, const torch::Tensor& updates, const torch::Tensor& new_mu, const torch::Tensor& new_nu, - const float b1, const float b2, - const int count) { - using other_t = float; + const pyfloat_t b1, const pyfloat_t b2, + const pyuint_t count) { + using other_t = pyfloat_t; auto dmu_out = torch::empty_like(new_mu); auto dnu_out = torch::empty_like(new_nu); @@ -300,10 +296,11 @@ TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor& dupdates, const size_t n = getTensorPlainSize(dupdates); AT_DISPATCH_FLOATING_TYPES( dupdates.scalar_type(), "adamBackwardUpdatesCPU", ([&] { - adamBackwardUpdatesCPUKernel( + adamBackwardUpdatesCPUKernel( dupdates.data_ptr(), updates.data_ptr(), - new_mu.data_ptr(), one_minus_pow_b1, inv_one_minus_pow_b2, - n, dmu_out.data_ptr(), dnu_out.data_ptr()); + new_mu.data_ptr(), scalar_t(one_minus_pow_b1), + scalar_t(inv_one_minus_pow_b2), n, dmu_out.data_ptr(), + dnu_out.data_ptr()); })); return TensorArray<2>{std::move(dmu_out), std::move(dnu_out)}; } diff --git a/src/adam_op/adam_op_impl.cu b/src/adam_op/adam_op_impl_cuda.cu similarity index 80% rename from src/adam_op/adam_op_impl.cu rename to src/adam_op/adam_op_impl_cuda.cu index b10942eb..16441157 100644 --- a/src/adam_op/adam_op_impl.cu +++ b/src/adam_op/adam_op_impl_cuda.cu @@ -17,7 +17,7 @@ #include -#include "include/adam_op/adam_op_impl.cuh" +#include "include/adam_op/adam_op_impl_cuda.cuh" #include "include/utils.h" namespace torchopt { @@ -49,22 +49,22 @@ __global__ void adamForwardInplaceCUDAKernel( } } // namespace -TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates, - const torch::Tensor &mu, - const torch::Tensor &nu, const float b1, - const float b2, const float eps, - const float eps_root, const int count) { - using other_t = float; - const float inv_one_minus_pow_b1 = 1 / (1 - std::pow(b1, count)); - const float inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count)); +TensorArray<3> adamForwardInplaceCUDA( + const torch::Tensor &updates, const torch::Tensor &mu, + const torch::Tensor &nu, const pyfloat_t b1, const pyfloat_t b2, + const pyfloat_t eps, const pyfloat_t eps_root, const pyuint_t count) { + using other_t = pyfloat_t; + const other_t inv_one_minus_pow_b1 = 1 / (1 - std::pow(b1, 
count)); + const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count)); const size_t n = getTensorPlainSize(updates); const dim3 block(std::min(n, size_t(256))); const dim3 grid((n - 1) / block.x + 1); AT_DISPATCH_FLOATING_TYPES( updates.scalar_type(), "adamForwardInplaceCUDA", ([&] { - adamForwardInplaceCUDAKernel<<>>( - b1, inv_one_minus_pow_b1, b2, inv_one_minus_pow_b2, eps, eps_root, + adamForwardInplaceCUDAKernel<<>>( + scalar_t(b1), scalar_t(inv_one_minus_pow_b1), scalar_t(b2), + scalar_t(inv_one_minus_pow_b2), scalar_t(eps), scalar_t(eps_root), n, updates.data_ptr(), mu.data_ptr(), nu.data_ptr()); })); @@ -89,9 +89,7 @@ __global__ void adamForwardMuCUDAKernel( } // namespace torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates, - const torch::Tensor &mu, const float b1) { - using other_t = float; - + const torch::Tensor &mu, const pyfloat_t b1) { auto mu_out = torch::empty_like(mu); const size_t n = getTensorPlainSize(updates); @@ -99,9 +97,9 @@ torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates, const dim3 grid((n - 1) / block.x + 1); AT_DISPATCH_FLOATING_TYPES( updates.scalar_type(), "adamForwardMuCUDA", ([&] { - adamForwardMuCUDAKernel<<>>( - updates.data_ptr(), mu.data_ptr(), b1, n, - mu_out.data_ptr()); + adamForwardMuCUDAKernel<<>>( + updates.data_ptr(), mu.data_ptr(), scalar_t(b1), + n, mu_out.data_ptr()); })); return mu_out; } @@ -126,9 +124,7 @@ __global__ void adamForwardNuCUDAKernel( } // namespace torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates, - const torch::Tensor &nu, const float b2) { - using other_t = float; - + const torch::Tensor &nu, const pyfloat_t b2) { auto nu_out = torch::empty_like(nu); const size_t n = getTensorPlainSize(updates); @@ -136,9 +132,9 @@ torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates, const dim3 grid((n - 1) / block.x + 1); AT_DISPATCH_FLOATING_TYPES( updates.scalar_type(), "adamForwardNuCUDA", ([&] { - adamForwardNuCUDAKernel<<>>( - updates.data_ptr(), nu.data_ptr(), b2, n, - nu_out.data_ptr()); + adamForwardNuCUDAKernel<<>>( + updates.data_ptr(), nu.data_ptr(), scalar_t(b2), + n, nu_out.data_ptr()); })); return nu_out; } @@ -166,10 +162,11 @@ __global__ void adamForwardUpdatesCUDAKernel( torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu, const torch::Tensor &new_nu, - const float b1, const float b2, - const float eps, const float eps_root, - const int count) { - using other_t = float; + const pyfloat_t b1, const pyfloat_t b2, + const pyfloat_t eps, + const pyfloat_t eps_root, + const pyuint_t count) { + using other_t = pyfloat_t; auto updates_out = torch::empty_like(new_mu); @@ -183,9 +180,10 @@ torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu, const dim3 grid((n - 1) / block.x + 1); AT_DISPATCH_FLOATING_TYPES( new_mu.scalar_type(), "adamForwardUpdatesCUDA", ([&] { - adamForwardUpdatesCUDAKernel<<>>( + adamForwardUpdatesCUDAKernel<<>>( new_mu.data_ptr(), new_nu.data_ptr(), - inv_one_minus_pow_b1, inv_one_minus_pow_b2, eps, eps_root, n, + scalar_t(inv_one_minus_pow_b1), scalar_t(inv_one_minus_pow_b2), + scalar_t(eps), scalar_t(eps_root), n, updates_out.data_ptr()); })); return updates_out; @@ -211,9 +209,7 @@ __global__ void adamBackwardMuCUDAKernel( TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu, const torch::Tensor &updates, - const torch::Tensor &mu, const float b1) { - using other_t = float; - + const torch::Tensor &mu, const pyfloat_t b1) { auto dupdates_out = torch::empty_like(updates); auto dmu_out = torch::empty_like(mu); @@ -222,9 +218,9 @@ 
TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu, const dim3 grid((n - 1) / block.x + 1); AT_DISPATCH_FLOATING_TYPES( dmu.scalar_type(), "adamBackwardMuCUDA", ([&] { - adamBackwardMuCUDAKernel<<>>( - dmu.data_ptr(), b1, n, dupdates_out.data_ptr(), - dmu_out.data_ptr()); + adamBackwardMuCUDAKernel<<>>( + dmu.data_ptr(), scalar_t(b1), n, + dupdates_out.data_ptr(), dmu_out.data_ptr()); })); return TensorArray<2>{std::move(dupdates_out), std::move(dmu_out)}; } @@ -251,9 +247,7 @@ __global__ void adamBackwardNuCUDAKernel( TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu, const torch::Tensor &updates, - const torch::Tensor &nu, const float b2) { - using other_t = float; - + const torch::Tensor &nu, const pyfloat_t b2) { auto dupdates_out = torch::empty_like(updates); auto dnu_out = torch::empty_like(nu); @@ -262,9 +256,10 @@ TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu, const dim3 grid((n - 1) / block.x + 1); AT_DISPATCH_FLOATING_TYPES( dnu.scalar_type(), "adamForwardNuCUDA", ([&] { - adamBackwardNuCUDAKernel<<>>( - dnu.data_ptr(), updates.data_ptr(), b2, n, - dupdates_out.data_ptr(), dnu_out.data_ptr()); + adamBackwardNuCUDAKernel<<>>( + dnu.data_ptr(), updates.data_ptr(), + scalar_t(b2), n, dupdates_out.data_ptr(), + dnu_out.data_ptr()); })); return TensorArray<2>{std::move(dupdates_out), std::move(dnu_out)}; } @@ -307,9 +302,9 @@ TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates, const torch::Tensor &updates, const torch::Tensor &new_mu, const torch::Tensor &new_nu, - const float b1, const float b2, - const int count) { - using other_t = float; + const pyfloat_t b1, const pyfloat_t b2, + const pyuint_t count) { + using other_t = pyfloat_t; auto dmu_out = torch::empty_like(new_mu); auto dnu_out = torch::empty_like(new_nu); @@ -323,10 +318,11 @@ TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates, const dim3 grid((n - 1) / block.x + 1); AT_DISPATCH_FLOATING_TYPES( dupdates.scalar_type(), "adamBackwardUpdatesCUDA", ([&] { - adamBackwardUpdatesCUDAKernel<<>>( + adamBackwardUpdatesCUDAKernel<<>>( dupdates.data_ptr(), updates.data_ptr(), - new_mu.data_ptr(), one_minus_pow_b1, inv_one_minus_pow_b2, - n, dmu_out.data_ptr(), dnu_out.data_ptr()); + new_mu.data_ptr(), scalar_t(one_minus_pow_b1), + scalar_t(inv_one_minus_pow_b2), n, dmu_out.data_ptr(), + dnu_out.data_ptr()); })); return TensorArray<2>{std::move(dmu_out), std::move(dnu_out)}; } diff --git a/tests/requirements.txt b/tests/requirements.txt index 6cf7a2a1..c6990c45 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,7 +1,7 @@ --extra-index-url https://download.pytorch.org/whl/cu116 -torch == 1.12 +torch >= 1.12 torchvision -functorch +functorch >= 0.2 --requirement ../requirements.txt @@ -14,7 +14,7 @@ pylint mypy flake8 flake8-bugbear -doc8 +doc8 < 1.0.0a0 pydocstyle pyenchant cpplint diff --git a/torchopt/version.py b/torchopt/version.py index 784a9a63..78aa7e73 100644 --- a/torchopt/version.py +++ b/torchopt/version.py @@ -14,4 +14,4 @@ # ============================================================================== """TorchOpt: a high-performance optimizer library built upon PyTorch.""" -__version__ = '0.4.2' +__version__ = '0.4.3' diff --git a/tutorials/requirements.txt b/tutorials/requirements.txt index 00cb5228..5fe3b1ad 100644 --- a/tutorials/requirements.txt +++ b/tutorials/requirements.txt @@ -1,7 +1,7 @@ --extra-index-url https://download.pytorch.org/whl/cu116 -torch == 1.12 +torch >= 1.12 torchvision -functorch +functorch >= 0.2 --requirement 
../requirements.txt