diff --git a/.editorconfig b/.editorconfig index 96ef7342..3ae9f69a 100644 --- a/.editorconfig +++ b/.editorconfig @@ -14,7 +14,7 @@ insert_final_newline = true indent_size = 4 src_paths=torchopt,tests,examples -[*.{yaml,yml}] +[*.{yaml,yml,json}] indent_size = 2 [*.md] @@ -25,8 +25,18 @@ x-soft-wrap-text = true indent_size = 4 x-soft-wrap-text = true +[*.{bib,tex}] +indent_size = 2 + [Makefile] indent_style = tab +[*.sh] +indent_style = tab + +[*.bat] +end_of_line = crlf +indent_style = tab + [*.{cpp,h,cu,cuh}] indent_size = 2 diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 4b90bb84..6d381b28 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -25,10 +25,9 @@ body: - type: input id: version attributes: - label: | - What version of TorchOpt are you using? - value: | - python3 -m pip show torchopt + label: What version of TorchOpt are you using? + description: Run command `python3 -c 'print(__import__("torchopt").__version__)'` in your shell and paste the output here. + placeholder: E.g., 0.6.0 validations: required: true @@ -36,7 +35,7 @@ body: id: system-info attributes: label: System information - value: | + description: | Describe the characteristic of your environment: - Describe how the library was installed (pip, conda, source, ...) @@ -55,7 +54,7 @@ body: id: description attributes: label: Problem description - placeholder: | + description: >- Provide a short description, state the expected behavior and what actually happens. Include relevant information like what version of TorchOpt you are using, what system you are on, and any useful commands / output. @@ -66,18 +65,18 @@ body: id: code attributes: label: Reproducible example code + description: >- + The code should be minimal, have minimal external dependencies, and isolate the functions + that cause breakage. Submit matched and complete snippets that can be easily run to diagnose + the issue. value: | - - The Python snippets: ```python ``` - Run the snippets with the following commands: + Command lines: ```bash @@ -88,6 +87,12 @@ body: ```text ``` + + Steps to reproduce: + + 1. + 2. + 3. validations: required: true @@ -95,9 +100,8 @@ body: id: traceback attributes: label: Traceback + description: Put the Python traceback information here. placeholder: | - Put the Python traceback information here. - Traceback (most recent call last): File ... render: pytb @@ -106,14 +110,13 @@ body: id: expected attributes: label: Expected behavior - placeholder: | - Provide a clear and concise description of what you expected to happen. + description: Provide a clear and concise description of what you expected to happen. - type: textarea id: additional-context attributes: label: Additional context - placeholder: | + description: >- Add any other context about the problem here. Screenshots may also be helpful. If you know or suspect the reason for this bug, paste the code lines and suggest modifications. diff --git a/.github/ISSUE_TEMPLATE/feature-request.yml b/.github/ISSUE_TEMPLATE/feature-request.yml index 959ec909..ee76e770 100644 --- a/.github/ISSUE_TEMPLATE/feature-request.yml +++ b/.github/ISSUE_TEMPLATE/feature-request.yml @@ -19,6 +19,7 @@ body: id: motivation attributes: label: Motivation + description: Outline the motivation for the proposal. value: | @@ -247,7 +246,7 @@ Users need to define the stationary condition/objective function and the inner-l ```python # Inherited from the class ImplicitMetaGradientModule # Optionally specify the linear solver (conjugate gradient or Neumann series) -class InnerNet(ImplicitMetaGradientModule, linear_solver): +class InnerNet(ImplicitMetaGradientModule, linear_solve=linear_solver): def __init__(self, meta_param): super().__init__() self.meta_param = meta_param @@ -275,7 +274,7 @@ class InnerNet(ImplicitMetaGradientModule, linear_solver): meta_params, data = ..., ... inner_net = InnerNet(meta_params) -# Solve for inner-loop process related with the meta-parameters +# Solve for inner-loop process related to the meta-parameters optimal_inner_net = inner_net.solve(data) # Get outer loss and solve for meta-gradient @@ -293,17 +292,58 @@ Refer to the tutorial notebook [Zero-order Differentiation](tutorials/6_Zero_Ord #### Functional API +For zero-order differentiation, users need to define the forward pass calculation and the noise sampling procedure. TorchOpt provides the decorator to wrap the forward function for enabling zero-order differentiation. + ```python # Customize the noise sampling function in ES -def sample(sample_shape): +def distribution(sample_shape): + # Generate a batch of noise samples + # NOTE: The distribution should be spherical symmetric and with a constant variance of 1. ... - return sample_noise + return noise_batch + +# Distribution can also be an instance of `torch.distributions.Distribution`, e.g., `torch.distributions.Normal(...)` +distribution = torch.distributions.Normal(loc=0, scale=1) # Specify method and hyper-parameter of ES -@torchopt.diff.zero_order(sample, method) +@torchopt.diff.zero_order(distribution, method) def forward(params, batch, labels): - # forward process - return output + # Forward process + ... + return objective # the returned tensor should be a scalar tensor +``` + +#### OOP API + +TorchOpt also offer an OOP API, users need to inherit from the class `torchopt.nn.ZeroOrderGradientModule` to construct the network as an `nn.Module` following a classical PyTorch style. +Users need to define the forward process zero-order gradient procedures `forward()` and a noise sampling function `sample()`. + +```python +# Inherited from the class ZeroOrderGradientModule +# Optionally specify the `method` and/or `num_samples` and/or `sigma` used for sampling +class Net(ZeroOrderGradientModule, method=method, num_samples=num_samples, sigma=sigma): + def __init__(self, ...): + ... + + def forward(self, batch): + # Forward process + ... + return objective # the returned tensor should be a scalar tensor + + def sample(self, sample_shape=torch.Size()): + # Generate a batch of noise samples + # NOTE: The distribution should be spherical symmetric and with a constant variance of 1. + ... + return noise_batch + +# Get model and data +net = Net(...) +data = ... + +# Forward pass +loss = Net(data) +# Backward pass using zero-order differentiation +grads = torch.autograd.grad(loss, net.parameters()) ``` -------------------------------------------------------------------------------- @@ -315,7 +355,7 @@ def forward(params, batch, labels): We take the optimizer as a whole instead of separating it into several basic operators (e.g., `sqrt` and `div`). Therefore, by manually writing the forward and backward functions, we can perform the symbolic reduction. In addition, we can store some intermediate data that can be reused during the backpropagation. -We write the accelerated functions in C++ OpenMP and CUDA, bind them by [`pybind11`](https://github.com/pybind/pybind11) to allow they can be called by Python, and then we define the forward and backward behavior using `torch.autograd.Function`. +We write the accelerated functions in C++ OpenMP and CUDA, bind them by [`pybind11`](https://github.com/pybind/pybind11) to allow they can be called by Python, and then define the forward and backward behavior using `torch.autograd.Function`. Users can use by simply setting the `use_accelerated_op` flag as `True`. Refer to the corresponding sections in tutorials [Functional Optimizer](tutorials/1_Functional_Optimizer.ipynb) and [Meta-Optimizer](tutorials/3_Meta_Optimizer.ipynb) @@ -327,22 +367,22 @@ optimizer = torchopt.MetaAdam(model, lr, use_accelerated_op=True) `TorchOpt` provides distributed training features based on the PyTorch RPC module for better training speed and multi-node multi-GPU support. Different from the MPI-like parallelization paradigm, which uses multiple homogenous workers and requires carefully designed communication hooks, the RPC APIs allow users to build their optimization pipeline more flexibly. -Experimental results show that we achieve approximately linear relationship between the speed-up ratio and the number of workers. -Check out the [distributed MAML example](https://github.com/metaopt/torchopt/tree/main/examples/distributed/few-shot) for more specific guidance. +Experimental results show that we achieve an approximately linear relationship between the speed-up ratio and the number of workers. +Check out the [Distributed Training Documentation](https://torchopt.readthedocs.io/en/latest/distributed/distributed.html) and [distributed MAML example](https://github.com/metaopt/torchopt/tree/main/examples/distributed/few-shot) for more specific guidance. ### OpTree -We implement the *PyTree* to enable fast nested structure flatten using C++. +We implement the *PyTree* to enable fast nested structure flattening using C++. The tree operations (e.g., flatten and unflatten) are very important in enabling functional and Just-In-Time (JIT) features of deep learning frameworks. -By implementing it in C++, we can use some cache/memory friendly structures (e.g., `absl::InlinedVector`) to improve the performance. -For more guidance and comparison results, please refer to our open source project [`OpTree`](https://github.com/metaopt/optree). +By implementing it in C++, we can use some cache/memory-friendly structures (e.g., `absl::InlinedVector`) to improve the performance. +For more guidance and comparison results, please refer to our open-source project [`OpTree`](https://github.com/metaopt/optree). -------------------------------------------------------------------------------- ## Visualization Complex gradient flow in meta-learning brings in a great challenge for managing the gradient flow and verifying the correctness of it. -TorchOpt provides a visualization tool that draw variable (e.g., network parameters or meta-parameters) names on the gradient graph for better analyzing. +TorchOpt provides a visualization tool that draws variable (e.g., network parameters or meta-parameters) names on the gradient graph for better analysis. The visualization tool is modified from [`torchviz`](https://github.com/szagoruyko/pytorchviz). Refer to the example [visualization code](examples/visualize.py) and the tutorial notebook [Visualization](tutorials/2_Visualization.ipynb) for more details. @@ -357,7 +397,7 @@ Compared with [`torchviz`](https://github.com/szagoruyko/pytorchviz), TorchOpt f ## Examples -In the [`examples`](examples) directory, we offer several examples of functional optimizer and light-weight meta-learning examples with TorchOpt. +In the [`examples`](examples) directory, we offer several examples of functional optimizers and light-weight meta-learning examples with TorchOpt. - [Model-Agnostic Meta-Learning (MAML) - Supervised Learning](https://arxiv.org/abs/1703.03400) (ICML 2017) - [Learning to Reweight Examples for Robust Deep Learning](https://arxiv.org/abs/1803.09050) (ICML 2018) @@ -378,13 +418,15 @@ Requirements - (Optional) For visualizing computation graphs - [Graphviz](https://graphviz.org/download) (for Linux users use `apt/yum install graphviz` or `conda install -c anaconda python-graphviz`) -**Please follow the instructions at to install PyTorch in your Python environment first.** Then run the following command to install TorchOpt from PyPI ([![PyPI](https://img.shields.io/pypi/v/torchopt?label=pypi&logo=pypi)](https://pypi.org/project/torchopt) / ![Status](https://img.shields.io/pypi/status/torchopt?label=status)): +**Please follow the instructions at to install PyTorch in your Python environment first.** +Then run the following command to install TorchOpt from PyPI ([![PyPI](https://img.shields.io/pypi/v/torchopt?label=pypi&logo=pypi)](https://pypi.org/project/torchopt) / ![Status](https://img.shields.io/pypi/status/torchopt?label=status)): ```bash pip3 install torchopt ``` -If the minimum version of PyTorch is not satisfied, `pip` will install/upgrade it for you. Please be careful about the `torch` build for CPU / CUDA support (e.g. `cpu`, `cu116`, `cu117`). You may need to specify the extra index URL for the `torch` package: +If the minimum version of PyTorch is not satisfied, `pip` will install/upgrade it for you. Please be careful about the `torch` build for CPU / CUDA support (e.g. `cpu`, `cu116`, `cu117`). +You may need to specify the extra index URL for the `torch` package: ```bash pip3 install torchopt --extra-index-url https://download.pytorch.org/whl/cu117 @@ -400,7 +442,8 @@ cd torchopt pip3 install . ``` -We provide a [conda](https://github.com/conda/conda) environment recipe to install the build toolchain such as `cmake`, `g++`, and `nvcc`: +We provide a [conda](https://github.com/conda/conda) environment recipe to install the build toolchain such as `cmake`, `g++`, and `nvcc`. +You can use the following commands with [`conda`](https://github.com/conda/conda) / [`mamba`](https://github.com/mamba-org/mamba) to create a new isolated environment. ```bash git clone https://github.com/metaopt/torchopt.git @@ -436,7 +479,7 @@ If you find TorchOpt useful, please cite it in your publications. ## The Team -TorchOpt is a work by [Jie Ren](https://github.com/JieRen98), [Xidong Feng](https://github.com/waterhorse1), [Bo Liu](https://github.com/Benjamin-eecs), [Xuehai Pan](https://github.com/XuehaiPan), [Luo Mai](https://luomai.github.io), and [Yaodong Yang](https://www.yangyaodong.com). +TorchOpt is a work by [Jie Ren](https://github.com/JieRen98), [Xidong Feng](https://github.com/waterhorse1), [Bo Liu](https://benjamin-eecs.github.io/), [Xuehai Pan](https://github.com/XuehaiPan), [Luo Mai](https://luomai.github.io), and [Yaodong Yang](https://www.yangyaodong.com). ## License diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 00000000..65b70e6e --- /dev/null +++ b/codecov.yml @@ -0,0 +1,9 @@ +coverage: + round: nearest + status: + project: + default: + threshold: 0.05% + patch: + default: + informational: true diff --git a/conda-recipe-minimal.yaml b/conda-recipe-minimal.yaml index 4ae91303..c3d155b8 100644 --- a/conda-recipe-minimal.yaml +++ b/conda-recipe-minimal.yaml @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,7 +27,7 @@ channels: - conda-forge dependencies: - - python = 3.9 + - python = 3.10 - pip # Learning @@ -44,7 +44,6 @@ dependencies: - cmake >= 3.11 - make - cxx-compiler - - gxx = 10 - nvidia/label/cuda-11.7.1::cuda-nvcc - nvidia/label/cuda-11.7.1::cuda-cudart-dev - pybind11 >= 2.10.1 diff --git a/conda-recipe.yaml b/conda-recipe.yaml index 9eacbfaa..faee0a7c 100644 --- a/conda-recipe.yaml +++ b/conda-recipe.yaml @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,7 +27,7 @@ channels: - conda-forge dependencies: - - python = 3.9 + - python = 3.10 - pip # Learning @@ -50,7 +50,6 @@ dependencies: - cmake >= 3.11 - make - cxx-compiler - - gxx = 10 - nvidia/label/cuda-11.7.1::cuda-nvcc - nvidia/label/cuda-11.7.1::cuda-cudart-dev - patchelf >= 0.14 @@ -68,7 +67,7 @@ dependencies: # Documentation - sphinx >= 5.2.1 - - sphinx_rtd_theme + - sphinx-rtd-theme - sphinx-autobuild - sphinx-copybutton - sphinxcontrib-spelling @@ -85,16 +84,15 @@ dependencies: - pytest - pytest-cov - pytest-xdist - - isort + - isort >= 5.11.0 - conda-forge::black-jupyter >= 22.6.0 - pylint >= 2.15.0 - mypy >= 0.990 - - types-setuptools - flake8 - flake8-bugbear - doc8 < 1.0.0a0 - pydocstyle - clang-format >= 14 - - clang-tools # clang-tidy + - clang-tools >= 14 # clang-tidy - cpplint - pre-commit diff --git a/docs/conda-recipe.yaml b/docs/conda-recipe.yaml index a26b613b..9a14af3f 100644 --- a/docs/conda-recipe.yaml +++ b/docs/conda-recipe.yaml @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,7 +27,7 @@ channels: - conda-forge dependencies: - - python = 3.9 + - python = 3.10 - pip # Learning @@ -42,7 +42,6 @@ dependencies: - cmake >= 3.11 - make - cxx-compiler - - gxx = 10 - nvidia/label/cuda-11.7.1::cuda-nvcc - nvidia/label/cuda-11.7.1::cuda-cudart-dev - pybind11 >= 2.10.1 @@ -58,7 +57,7 @@ dependencies: # Documentation - sphinx >= 5.2.1 - - sphinx_rtd_theme + - sphinx-rtd-theme - sphinx-autobuild - sphinx-copybutton - sphinxcontrib-spelling diff --git a/docs/requirements.txt b/docs/requirements.txt index 9ac98898..655c64ff 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -15,6 +15,6 @@ sphinx-autodoc-typehints >= 1.19.2 IPython ipykernel pandoc -myst_nb +myst-nb docutils matplotlib diff --git a/docs/source/_static/images/explicit-gradient.png b/docs/source/_static/images/explicit-gradient.png new file mode 100644 index 00000000..90cf4d4d Binary files /dev/null and b/docs/source/_static/images/explicit-gradient.png differ diff --git a/docs/source/_static/images/implicit-gradient.png b/docs/source/_static/images/implicit-gradient.png new file mode 100644 index 00000000..faf26486 Binary files /dev/null and b/docs/source/_static/images/implicit-gradient.png differ diff --git a/docs/source/_static/images/visualization-fig1.svg b/docs/source/_static/images/visualization-fig1.svg new file mode 100644 index 00000000..281e456b --- /dev/null +++ b/docs/source/_static/images/visualization-fig1.svg @@ -0,0 +1,57 @@ + + + + + + +%3 + + + +140534064715952 + +y +() + + + +140534064838304 + +MulBackward0 + + + +140534064838304->140534064715952 + + + + + +140534064837776 + +AccumulateGrad + + + +140534064837776->140534064838304 + + + + + +140534064714832 + +x +() + + + +140534064714832->140534064837776 + + + + + diff --git a/docs/source/_static/images/visualization-fig2.svg b/docs/source/_static/images/visualization-fig2.svg new file mode 100644 index 00000000..25db4e5a --- /dev/null +++ b/docs/source/_static/images/visualization-fig2.svg @@ -0,0 +1,106 @@ + + + + + + +%3 + + + +140534659780336 + +loss +() + + + +140531595570768 + +MseLossBackward0 + + + +140531595570768->140534659780336 + + + + + +140531595570576 + +AddmmBackward0 + + + +140531595570576->140531595570768 + + + + + +140531595570528 + +AccumulateGrad + + + +140531595570528->140531595570576 + + + + + +140531595583632 + +fc.bias +(1) + + + +140531595583632->140531595570528 + + + + + +140531595571104 + +TBackward0 + + + +140531595571104->140531595570576 + + + + + +140531595570432 + +AccumulateGrad + + + +140531595570432->140531595571104 + + + + + +140531595582816 + +fc.weight +(1, 5) + + + +140531595582816->140531595570432 + + + + + diff --git a/docs/source/_static/images/visualization-fig3.svg b/docs/source/_static/images/visualization-fig3.svg new file mode 100644 index 00000000..c041e0f6 --- /dev/null +++ b/docs/source/_static/images/visualization-fig3.svg @@ -0,0 +1,339 @@ + + + + + + +%3 + + + +140531595614064 + +loss +() + + + +140531595567168 + +MseLossBackward0 + + + +140531595567168->140531595614064 + + + + + +140531595569232 + +AddBackward0 + + + +140531595569232->140531595567168 + + + + + +140531595568800 + +AddmmBackward0 + + + +140531595568800->140531595569232 + + + + + +140534660247264 + +AddBackward0 +step1.fc.bias +(1) + + + +140534660247264->140531595568800 + + + + + +140534553595376 + +AccumulateGrad + + + +140534553595376->140534660247264 + + + + + +140534553592832 + +AddmmBackward0 + + + +140534553595376->140534553592832 + + + + + +140534064448352 + +step0.fc.bias +(1) + + + +140534064448352->140534553595376 + + + + + +140534553595616 + +MulBackward0 + + + +140534553595616->140534660247264 + + + + + +140534553594848 + +ViewBackward0 + + + +140534553594848->140534553595616 + + + + + +140534553594992 + +SumBackward1 + + + +140534553594992->140534553594848 + + + + + +140534553594800 + +MseLossBackwardBackward0 + + + +140534553594800->140534553594992 + + + + + +140531595617904 + +TBackward0 + + + +140534553594800->140531595617904 + + + + + +140534553593072 + +AddBackward0 + + + +140534553593072->140534553594800 + + + + + +140534553592832->140534553593072 + + + + + +140534553593456 + +TBackward0 + + + +140534553593456->140534553592832 + + + + + +140534553593888 + +AccumulateGrad + + + +140534553593888->140534553593456 + + + + + +140531595572368 + +AddBackward0 +step1.fc.weight +(1, 5) + + + +140534553593888->140531595572368 + + + + + +140531595612944 + +step0.fc.weight +(1, 5) + + + +140531595612944->140534553593888 + + + + + +140531595567888 + +AccumulateGrad + + + +140531595567888->140531595569232 + + + + + +140531595567888->140534553593072 + + + + + +140531595613184 + +meta_param +() + + + +140531595613184->140531595567888 + + + + + +140534553594272 + +TBackward0 + + + +140534553594272->140531595568800 + + + + + +140531595572368->140534553594272 + + + + + +140534553593504 + +MulBackward0 + + + +140534553593504->140531595572368 + + + + + +140534553592976 + +TBackward0 + + + +140534553592976->140534553593504 + + + + + +140534553593216 + +TBackward0 + + + +140534553593216->140534553592976 + + + + + +140534553593552 + +MmBackward0 + + + +140534553593552->140534553593216 + + + + + +140531595617904->140534553593552 + + + + + diff --git a/docs/source/_static/images/zero-order.png b/docs/source/_static/images/zero-order.png new file mode 100644 index 00000000..2c94d667 Binary files /dev/null and b/docs/source/_static/images/zero-order.png differ diff --git a/docs/source/api/api.rst b/docs/source/api/api.rst index 27d16a64..b2866407 100644 --- a/docs/source/api/api.rst +++ b/docs/source/api/api.rst @@ -131,7 +131,7 @@ Differentiable Meta-RMSProp Optimizer ------ -Implicit differentiation +Implicit Differentiation ======================== .. currentmodule:: torchopt.diff.implicit @@ -141,7 +141,7 @@ Implicit differentiation custom_root nn.ImplicitMetaGradientModule -Custom solvers +Custom Solvers ~~~~~~~~~~~~~~ .. autofunction:: custom_root @@ -157,7 +157,7 @@ Implicit Meta-Gradient Module ------ -Linear system solvers +Linear System Solvers ===================== .. currentmodule:: torchopt.linear_solve @@ -168,7 +168,7 @@ Linear system solvers solve_normal_cg solve_inv -Indirect solvers +Indirect Solvers ~~~~~~~~~~~~~~~~ .. autofunction:: solve_cg @@ -177,6 +177,32 @@ Indirect solvers ------ +Zero-Order Differentiation +========================== + +.. currentmodule:: torchopt.diff.zero_order + +.. autosummary:: + + zero_order + nn.ZeroOrderGradientModule + +Decorators +~~~~~~~~~~ + +.. autofunction:: zero_order + + +Zero-order Gradient Module +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. currentmodule:: torchopt.diff.zero_order.nn + +.. autoclass:: ZeroOrderGradientModule + :members: + +------ + Optimizer Hooks =============== @@ -259,6 +285,115 @@ Chain .. autofunction:: chain +Distributed Utilities +===================== + +.. currentmodule:: torchopt.distributed + +Initialization and Synchronization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + auto_init_rpc + barrier + +.. autofunction:: auto_init_rpc +.. autofunction:: barrier + +Process group information +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + get_world_info + get_world_rank + get_rank + get_world_size + get_local_rank + get_local_world_size + get_worker_id + +.. autofunction:: get_world_info +.. autofunction:: get_world_rank +.. autofunction:: get_rank +.. autofunction:: get_world_size +.. autofunction:: get_local_rank +.. autofunction:: get_local_world_size +.. autofunction:: get_worker_id + +Worker selection +~~~~~~~~~~~~~~~~ + +.. autosummary:: + + on_rank + not_on_rank + rank_zero_only + rank_non_zero_only + +.. autofunction:: on_rank +.. autofunction:: not_on_rank +.. autofunction:: rank_zero_only +.. autofunction:: rank_non_zero_only + +Remote Procedure Call (RPC) +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + remote_async_call + remote_sync_call + +.. autofunction:: remote_async_call +.. autofunction:: remote_sync_call + +Predefined partitioners and reducers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + dim_partitioner + batch_partitioner + mean_reducer + sum_reducer + +.. autofunction:: dim_partitioner +.. autofunction:: batch_partitioner +.. autofunction:: mean_reducer +.. autofunction:: sum_reducer + +Function parallelization wrappers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + parallelize + parallelize_async + parallelize_sync + +.. autofunction:: parallelize +.. autofunction:: parallelize_async +.. autofunction:: parallelize_sync + +Distributed Autograd +~~~~~~~~~~~~~~~~~~~~ + +.. currentmodule:: torchopt.distributed.autograd + +.. autosummary:: + + context + get_gradients + backward + grad + +.. autofunction:: context +.. autofunction:: get_gradients +.. autofunction:: backward +.. autofunction:: grad + + General Utilities ================= diff --git a/docs/source/basics/basics.rst b/docs/source/basics/basics.rst new file mode 100644 index 00000000..8b5d5acd --- /dev/null +++ b/docs/source/basics/basics.rst @@ -0,0 +1,34 @@ +Basics +====== + +This section describes useful concepts across TorchOpt. + +TorchOpt Types +-------------- + +.. autosummary:: + + torchopt.base.GradientTransformation + torchopt.base.TransformInitFn + torchopt.base.TransformUpdateFn + +PyTrees +------- + +`PyTrees `_ is an essential concept in TorchOpt. +They can be thought as a generalization of vectors. +They are a way to structure parameters or weights using tuples and dictionaries. +Many solvers in TorchOpt have native support for pytrees. + +Floating-Point Precision +------------------------ + +TorchOpt uses single (32-bit) floating precision (``torch.float32``) by default. +However, for some algorithms, this may not be enough. +Double (64-bit) floating precision (``torch.float64``) can be enabled by adding the following lines at the beginning of the file: + +.. code-block:: python + + import torch + + torch.set_default_dtype(torch.float64) diff --git a/docs/source/bibtex.json b/docs/source/bibtex.json index c2aa9165..7abea503 100644 --- a/docs/source/bibtex.json +++ b/docs/source/bibtex.json @@ -1,7 +1,7 @@ { - "cited": { - "examples/MAML": [ - "MAML", - ] - } + "cited": { + "examples/MAML": [ + "MAML", + ] + } } diff --git a/docs/source/conf.py b/docs/source/conf.py index 96736ebb..d8233da7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ # pylint: disable=all -# -- Path setup -------------------------------------------------------------- +# -- Path setup ---------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the @@ -62,16 +62,16 @@ def filter(self, record): sphinx_autodoc_typehints._LOGGER.logger.addFilter(RecursiveForwardRefFilter()) -# -- Project information ----------------------------------------------------- +# -- Project information ------------------------------------------------------- project = 'TorchOpt' -copyright = '2022 MetaOPT Team' +copyright = '2022-2023 MetaOPT Team' author = 'TorchOpt Contributors' # The full version, including alpha/beta/rc tags release = get_version() -# -- General configuration --------------------------------------------------- +# -- General configuration ----------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom @@ -129,8 +129,9 @@ def filter(self, record): # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'default' -# -- Options for autodoc ----------------------------------------------------- +# -- Options for autodoc ------------------------------------------------------- +autosummary_generate = False autodoc_default_options = { 'member-order': 'bysource', 'undoc-members': True, @@ -141,16 +142,21 @@ def filter(self, record): autoclass_content = 'both' simplify_optional_unions = False -# -- Options for bibtex ----------------------------------------------------- +# -- Options for autosummary --------------------------------------------------- + +autosummary_generate = False +# numpydoc_class_members_toctree = False + +# -- Options for bibtex -------------------------------------------------------- bibtex_bibfiles = ['references.bib'] -# -- Options for myst ------------------------------------------------------- +# -- Options for myst ---------------------------------------------------------- nb_execution_mode = 'force' nb_execution_allow_errors = False -# -- Options for katex ------------------------------------------------------ +# -- Options for katex --------------------------------------------------------- # See: https://sphinxcontrib-katex.readthedocs.io/en/0.4.1/macros.html latex_macros = r""" @@ -164,7 +170,7 @@ def filter(self, record): # Add LaTeX macros for LATEX builder latex_elements = {'preamble': latex_macros} -# -- Options for HTML output ------------------------------------------------- +# -- Options for HTML output --------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. @@ -203,27 +209,27 @@ def setup(app): # # html_sidebars = {} -# -- Source code links ------------------------------------------------------- +# -- Source code links --------------------------------------------------------- extlinks = { 'gitcode': ('https://github.com/metaopt/torchopt/blob/HEAD/%s', '%s'), 'issue': ('https://github.com/metaopt/torchopt/issues/%s', 'issue %s'), } -# -- Extension configuration ------------------------------------------------- +# -- Extension configuration --------------------------------------------------- -# -- Options for napoleon extension ------------------------------------------ +# -- Options for napoleon extension -------------------------------------------- napoleon_include_init_with_doc = True napoleon_include_private_with_doc = False napoleon_include_special_with_doc = True -# -- Options for intersphinx extension --------------------------------------- +# -- Options for intersphinx extension ----------------------------------------- # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = {'python': ('https://docs.python.org/3', None)} -# -- Options for todo extension ---------------------------------------------- +# -- Options for todo extension ------------------------------------------------ # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = True diff --git a/docs/source/developer/contributing.rst b/docs/source/developer/contributing.rst index b4c4c825..ee66f560 100644 --- a/docs/source/developer/contributing.rst +++ b/docs/source/developer/contributing.rst @@ -12,7 +12,7 @@ Before contributing to TorchOpt, please follow the instructions below to setup. git remote add upstream git@github.com:metaopt/torchopt.git -2. Setup a development environment via `conda `_: +2. Setup a development environment via `conda `_ / `mamba `_: .. code-block:: bash @@ -53,7 +53,7 @@ We use several tools to secure code quality, including: * PEP8 code style: ``black``, ``isort``, ``pylint``, ``flake8`` * Type hint check: ``mypy`` - * C++ Google-style: ``cpplint``, ``clang-format`` + * C++ Google-style: ``cpplint``, ``clang-format``, ``clang-tidy`` * License: ``addlicense`` * Documentation: ``pydocstyle``, ``doc8`` diff --git a/docs/source/developer/contributor.rst b/docs/source/developer/contributor.rst index 407b53b0..2358f963 100644 --- a/docs/source/developer/contributor.rst +++ b/docs/source/developer/contributor.rst @@ -3,5 +3,5 @@ Contributor We always welcome contributions to help make TorchOpt better. Below is an incomplete list of our contributors (find more on `this page `_). -* Yao Fu (`future-xy `_) -* Vincent Moens (`vmoens `_) +- Yao Fu (`future-xy `_) +- Vincent Moens (`vmoens `_) diff --git a/docs/source/distributed/distributed.rst b/docs/source/distributed/distributed.rst new file mode 100644 index 00000000..b6f00951 --- /dev/null +++ b/docs/source/distributed/distributed.rst @@ -0,0 +1,733 @@ +Distributed Training +==================== + +.. currentmodule:: torchopt.distributed + +Distributed training is a technique that allows you to train your pipeline on multiple workers/machines. +This is useful when you have a large model or computation graph that doesn't fit on a single GPU/machine, or when you want to train a model faster by using more resources. + +TorchOpt offers a simple API to train your model on multiple GPUs/machines based on the PyTorch |Distributed RPC|_. +Here are some key concepts that TorchOpt's distributed mechanism relies on: + +- **Remote Procedure Call (RPC)** supports running a function on the specified destination worker with the given arguments and getting the return value back or creating a reference to the return value. + + That is, you can treat the remote worker as an accelerator. You can call a function on a remote worker and get the result back to the local worker. + +- **Distributed Autograd** stitches together local autograd engines on all the workers involved in the forward pass, and automatically reach out to them during the backward pass to compute gradients. + + This is much more flexible to fit the meta-learning use case to have a complex task dependency tree. + +.. |Distributed RPC| replace:: Distributed RPC Framework (``torch.distributed.rpc``) +.. _Distributed RPC: https://pytorch.org/docs/stable/rpc.html + +Here are some useful resources to learn more about distributed training: + +- `Distributed RPC Framework `_ +- `Distributed Autograd Design `_ +- `Remote Reference Protocol `_ +- `RPC tutorials `_ +- `Autograd mechanics `_ +- **Example**: :ref:`Using TorchOpt with Distributed Training ` + +------ + +Why RPC-Based Distributed Training +---------------------------------- + +Due to the Global Interpreter Lock (GIL) in Python, only one thread can execute Python code at a time. +This means that you can't take advantage of multiple cores on your machine. +Distribute the workload across multiple processes, or namely workers, that will run in parallel to gain faster execution performance. +Each worker will have its own Python interpreter and memory namespace. + +Compare to single-process programming, you need to be aware of the following: + +- **Communication**: You need to explicitly send and receive messages between workers. +- **Synchronization**: You need to explicitly synchronize the states between workers. + +Message Passing Interface (MPI) and Distributed Data-Parallel Training (DDP) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +`MPI `_ is a standard for message passing between processes. +It is a popular choice for `Distributed Data-Parallel Training (DDP) `_. +PyTorch has implemented this with several `backends `_, including `Gloo `_, `MPI `_, and `NCCL `_. + +However, MPI-based parallelism has some drawbacks: + +- **MPI is not user-friendly**. + MPI-like APIs only provide low-level primitives for sending and receiving messages. + It requires the users to manage the message passing between workers manually. + The users should be aware of the communication pattern and the synchronization between workers. + +- **MPI is not flexible**. + MPI-like APIs are designed for `Distributed Data-Parallel Training (DDP) `_, which is a widely adopted `single-program multiple-data (SPMD) `_ training paradigm. + However, for meta-learning tasks, the task dependency tree is complex and dynamic. + It may not fit into the SPMD paradigm. + It is hard to implement the distributed autograd engine on top of MPI. + +- **MPI only communicates the value of tensors but not the gradients and graphs**. + This is a limitation of MPI. + The users need to handle the gradients manually across multiple workers. + For example, receive the gradients from other workers and put them as ``grad_outputs`` to function |torch.autograd.grad|_. + +.. |torch.autograd.grad| replace:: ``torch.autograd.grad`` +.. _torch.autograd.grad: https://pytorch.org/docs/stable/generated/torch.autograd.grad.html + +Distributed Autograd with Remote Procedure Call (RPC) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To address the needs of meta-learning tasks, which have complex and dynamic nature of the training process. +TorchOpt uses PyTorch |Distributed RPC|_ to implement the distributed training mechanism. +PyTorch implements the RPC communication operations with appropriate ``RpcSendBackward`` and ``RpcRecvBackward`` functions. +The `Distributed Autograd Engine `_ automatically calls these functions to send and receive the gradients between workers. + +With **RPC** and **Distributed Autograd**, TorchOpt distributes a **differentiable optimization** job across multiple workers and executes the workers in parallel. +It allows the users to build the whole computation graph (**both forward and backward**) across multiple workers. +The users can wrap code in the distributed autograd module and achieve substantial speedup in training time with only a few changes in existing training scripts. (:ref:`example `) + +Here is an example of distributed autograd graph using RPC from `Distributed Backward Pass `_ documentation: + +.. code-block:: python + :emphasize-lines: 13, 18, 28, 31 + + import torch + import torch.distributed.autograd as dist_autograd + import torch.distributed.rpc as rpc + + def my_add(t1, t2): + return torch.add(t1, t2) + + # On worker 0: + + # Setup the autograd context. Computations that take + # part in the distributed backward pass must be within + # the distributed autograd context manager. + with dist_autograd.context() as context_id: + t1 = torch.rand((3, 3), requires_grad=True) + t2 = torch.rand((3, 3), requires_grad=True) + + # Perform some computation remotely. + t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2)) + + # Perform some computation locally based on the remote result. + t4 = torch.rand((3, 3), requires_grad=True) + t5 = torch.mul(t3, t4) + + # Compute some loss. + loss = t5.sum() + + # Run the backward pass. + dist_autograd.backward(context_id, [loss]) + + # Retrieve the gradients from the context. + dist_autograd.get_gradients(context_id) + +.. image:: https://pytorch.org/docs/stable/_images/distributed_dependencies_computed.png + +For more details, please refer to the `Distributed Autograd Design `_ documentation. + +------ + +TorchOpt's Distributed Training +------------------------------- + +TorchOpt's distributed package is built upon the PyTorch |Distributed RPC|_ and |Distributed Autograd Framework|_. + +.. |Distributed Autograd Framework| replace:: Distributed Autograd Framework (``torch.distributed.autograd``) +.. _Distributed Autograd Framework: https://pytorch.org/docs/stable/rpc.html#distributed-autograd-framework + +TorchOpt provides some utility functions to make it easier to use the distributed training mechanism. + +Initialization and Synchronization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + torchopt.distributed.auto_init_rpc + torchopt.distributed.barrier + +Users can wrap their program entry function with the decorator :func:`torchopt.distributed.auto_init_rpc`: + +.. code-block:: python + :emphasize-lines: 13 + + import torchopt.distributed as todist + + def parse_arguments(): + parser = argparse.ArgumentParser() + ... + + return args + + def worker_init_fn(): + # set process title, seeding, etc. + ... + + @todist.auto_init_rpc(worker_init_fn) + def main(): + # Your code here + args = parse_arguments() + ... + + if __name__ == '__main__': + main() + +The decorator will initialize the RPC framework and synchronize the workers on startup. + +.. note:: + + By default, all tensors must move to the CPU before sending them to other workers. + If you want to send/receive the tensors directly between GPUs from different workers, you need to specify the ``rpc_backend_options`` with ``device_maps``. + Please refer to the documentation of |torch.distributed.rpc.init_rpc|_ for more details. + +.. |torch.distributed.rpc.init_rpc| replace:: ``torch.distributed.rpc.init_rpc`` +.. _torch.distributed.rpc.init_rpc: https://pytorch.org/docs/stable/rpc.html#torch.distributed.rpc.init_rpc + +Then, users can use |torchrun|_ to launch the program: + +.. code-block:: bash + + torchrun --nnodes=1 --nproc_per_node=8 YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) + +.. |torchrun| replace:: ``torchrun`` (Elastic Launch) +.. _torchrun: https://pytorch.org/docs/stable/elastic/run.html + +Process group information +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + torchopt.distributed.get_world_info + torchopt.distributed.get_world_rank + torchopt.distributed.get_rank + torchopt.distributed.get_world_size + torchopt.distributed.get_local_rank + torchopt.distributed.get_local_world_size + torchopt.distributed.get_worker_id + +After initializing the RPC server, users can use the above functions to get the process group information. + +For example, use :func:`torchopt.distributed.get_local_rank` to determine which GPU device to use: + +.. code-block:: python + + import torch + import torchopt.distributed as todist + + def worker_init_fn(): + local_rank = todist.get_local_rank() + torch.cuda.set_device(local_rank) + + @todist.auto_init_rpc(worker_init_fn) + def main(): + ... + +Worker selection +~~~~~~~~~~~~~~~~ + +.. autosummary:: + + torchopt.distributed.on_rank + torchopt.distributed.not_on_rank + torchopt.distributed.rank_zero_only + torchopt.distributed.rank_non_zero_only + +TorchOpt provides some decorators to execute the decorated function on specific workers. + +For example, use :func:`torchopt.distributed.rank_zero_only` to execute the function only on the main worker (``worker0``), such as saving checkpoints or logging the results: + +.. code-block:: python + :emphasize-lines: 3, 7, 11 + + import torchopt.distributed as todist + + @todist.rank_non_zero_only + def greet(): + print(f'Greetings from worker(rank={todist.get_rank()})!') + + @todist.rank_zero_only + def save_checkpoint(model): + ... + + @todist.rank_zero_only + def log_results(writer, results): + ... + + @todist.auto_init_rpc() + def main(): + greet() + + ... + + for epoch in range(args.epochs): + ... + + if epoch % args.log_interval == 0: + log_results(writer, results) + + if epoch % args.save_interval == 0: + save_checkpoint(model) + +Remote Procedure Call (RPC) +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + torchopt.distributed.remote_async_call + torchopt.distributed.remote_sync_call + +TorchOpt provides two functions to execute the remote procedure call (RPC) on remote workers. +The asynchronous version :func:`remote_async_call` function returns a |torch.Future|_ object, and the :func:`remote_sync_call` function executes and returns the result directly. + +.. |torch.Future| replace:: ``torch.Future`` +.. _torch.Future: https://pytorch.org/docs/stable/futures.html#torch.futures.Future + +Users can distribute their workload (a function) to a specific worker by: + +.. code-block:: python + :emphasize-lines: 12 + + import torchopt.distributed as todist + + @todist.auto_init_rpc(worker_init_fn) + def main(): + ... + + # Execute the function on the remote worker (asynchronously) + future = todist.remote_async_call( + func, + args=(arg1, arg2, ...), + kwargs={...}, + partitioner=worker_id, + ) + + # Wait for the result + result = future.wait() + + ... + +or + +.. code-block:: python + :emphasize-lines: 12 + + import torchopt.distributed as todist + + @todist.auto_init_rpc(worker_init_fn) + def main(): + ... + + # Execute the function on the remote worker + result = todist.remote_sync_call( + func, + args=(arg1, arg2, ...), + kwargs={...}, + partitioner=worker_id, + ) + + ... + +TorchOpt follows the `MapReduce programming model `_ to distribute the workload. + +The ``partitioner`` argument specifies the worker to execute the function. +The users can optionally specify the ``reducer`` argument to aggregate the results from the workers. +Finally, the caller will get a reference to the result on the local worker. + +- ``partitioner``: a function that takes the ``args`` and ``kwargs`` arguments and returns a list of triplets ``(worker_id, worker_args, worker_kwargs)``. + + The ``partitioner`` is responsible for partitioning the workload (inputs) and distributing them to the remote workers. + + If the ``partitioner`` is given by a worker ID (:class:`int` or :class:`str`), the function will be executed on the specified worker. + + If the ``partitioner`` is not given, the :func:`torchopt.distributed.batch_partitioner` will be used. + +- ``mapper``: the ``func`` argument to be executed on the remote worker. +- ``reducer`` (optional): aggregation function, takes a list of results from the remote workers and returns the final result. + + If the ``reducer`` is not given, returns the original unaggregated list. + +Predefined partitioners and reducers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + torchopt.distributed.dim_partitioner + torchopt.distributed.batch_partitioner + torchopt.distributed.mean_reducer + torchopt.distributed.sum_reducer + +We provide some predefined partitioners and reducers. +Users can combine the :func:`torchopt.distributed.batch_partitioner` and :func:`torchopt.distributed.mean_reducer` to achieve the distributed data parallelism (DDP) easily: + +.. code-block:: python + :emphasize-lines: 18, 19 + + import torchopt.distributed as todist + + def loss_fn(model, batch): + ... + + @todist.rank_zero_only + def train(args): + + for epoch in range(args.epochs): + ... + + for batch in dataloader: + # Partition the data on the batch (first) dimension and distribute them to the remote workers + # Aggregate the results from the remote workers and return the mean loss + loss = todist.remote_sync_call( + loss_fn, + args=(model, batch), + partitioner=todist.batch_partitioner, + reducer=todist.mean_reducer, + ) + + ... + +We also provide a :func:`torchopt.distributed.dim_partitioner` to partition the data on the specified dimension. +While implementing the **Model-Agnostic Meta-Learning** (MAML) :cite:`MAML` algorithm, users can use this to parallel the training for the inner loop: + +.. code-block:: python + :emphasize-lines: 29, 30 + + import torchopt.distributed as todist + + def inner_loop(model, task_batch, args): + # task_batch: shape = (B, *) + inner_model = torchopt.module_clone(model, by='reference', detach_buffers=True) + + # Inner optimization + for inner_step in range(args.inner_steps): + inner_loss = inner_loss_fn(inner_model, task_batch) + + # Update the inner model + ... + + # Compute the outer loss + outer_loss = inner_loss_fn(inner_model, task_batch) + return outer_loss + + @todist.rank_zero_only + def train(args): + + for epoch in range(args.epochs): + ... + + for batch in dataloader: + # batch: shape = (T, B, *) + outer_loss = todist.remote_sync_call( + inner_loop, + args=(model, batch), + partitioner=todist.dim_partitioner(0, exclusive=True, keepdim=False), + reducer=todist.mean_reducer, + ) + + ... + +The ``dim_partitioner(0, exclusive=True, keepdim=False)`` will split the batch of size ``(T, B, *)`` into ``T`` batches of size ``(B, *)``. +Each task will be executed on the remote worker **independently** (``exclusive=True``). +Finally, the results will be aggregated by the :func:`torchopt.distributed.mean_reducer` to compute the mean loss. +Inside the ``inner_loop`` function, users may use another RPC call to further parallelize the inner loop optimization. + +Function parallelization wrappers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + torchopt.distributed.parallelize + torchopt.distributed.parallelize_async + torchopt.distributed.parallelize_sync + +TorchOpt offers wrappers to parallelize the function execution on the remote workers. +It makes the function execution on the remote workers more transparent to the users and makes the code structure clear. + +.. code-block:: python + :emphasize-lines: 3, 9, 10, 11, 12 + + import torchopt.distributed as todist + + @todist.parallelize(partitioner=todist.batch_partitioner, reducer=todist.mean_reducer) + def distributed_data_parallelism(model, batch, args): + # Compute local loss of the given batch + ... + return loss + + @todist.parallelize( + partitioner=todist.dim_partitioner(0, exclusive=True, keepdim=False), + reducer=todist.mean_reducer, + ) + def inner_loop(model, batch, args): # distributed MAML inner loop + # batch: shape = (B, *) + inner_model = torchopt.module_clone(model, by='reference', detach_buffers=True) + + # Inner optimization + ... + + # Compute the outer loss + outer_loss = inner_loss_fn(inner_model, task_batch) + return outer_loss + + @todist.rank_zero_only + def train(args): + + for epoch in range(args.epochs): + ... + + for batch in dataloader: + # batch: shape = (T, B, *) + outer_loss = inner_loop(model, batch, args) + + ... + +Distributed Autograd +~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + torchopt.distributed.autograd.context + torchopt.distributed.autograd.get_gradients + torchopt.distributed.autograd.backward + torchopt.distributed.autograd.grad + +In this section, we will introduce the distributed autograd system. +Please refer to `Autograd mechanics `_ and `Distributed Autograd Design `_ first before going through this section. + +Recap: Autograd mechanics in single-process training +"""""""""""""""""""""""""""""""""""""""""""""""""""" + +In single-process training, the autograd engine will automatically track the operations on the forward pass and compute the gradients on the backward pass. +For each operation, if the input tensors have ``requires_grad=True`` set, the output tensor will have a ``grad_fn`` attribute to trace the computation graph. +On the backward pass, the autograd engine will traverse the computation graph from the output tensors to the input tensors and compute the gradients for each operation. + +The |torch.autograd.grad|_ function will compute the gradients of the given ``outputs`` with respect to the given ``inputs``. + +.. code-block:: python + + import torch + + model = build_model() + loss = compute_loss(model, data) + + params = tuple(model.parameters()) + grads = torch.autograd.grad(loss, params) + + print(grads) + +In practice, users usually use the PyTorch Autograd Engine with ``loss.backward()`` (or |torch.autograd.backward|_) and optimizers: + +.. code-block:: python + + import torch + import torch.optim as optim + + model = build_model() + optimizer = optim.SGD(model.parameters(), lr=lr) + + loss = compute_loss(model, data) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + +Compare to |torch.autograd.grad|_, the |torch.autograd.backward|_ function will sum and update the ``.grad`` attribute of the parameters. + +.. |torch.autograd.backward| replace:: ``torch.autograd.backward`` +.. _torch.autograd.backward: https://pytorch.org/docs/stable/generated/torch.autograd.backward.html + +RPC-based Distributed Autograd +"""""""""""""""""""""""""""""" + +PyTorch RPC framework implements the communication ``send-recv`` operations with appropriate backward functions (``RpcSendBackward`` and ``RpcRecvBackward``). +They can be tracked by the **Distributed Autograd Engine** like the single-process program we discussed above. + +The only difference between the single-process and distributed training is that users need to explicitly create a **Distributed Autograd Context** and wrap around the forward and backward passes. + +.. code-block:: python + :emphasize-lines: 4, 9, 12 + + import torch + import torch.distributed.autograd as dist_autograd + + with dist_autograd.context() as context_id: + # Forward pass + loss = ... # e.g. remote calls + + # Backward pass + dist_autograd.backward(context_id, [loss]) + + # Retrieve the gradients from the context. + grad_dict = dist_autograd.get_gradients(context_id) # type: Dict[Tensor, Tensor] + +.. warning:: + + Sending |torch.nn.Parameter|_\s over RPC will automatically detach from the autograd graph. + This is an intentional behavior of the PyTorch framework because the |torch.nn.Parameter|_\s are always leaf nodes in the graph. + The leaf tensors will not have ``grad_fn`` attribute and thus cannot be tracked by the autograd engine after sending them to other workers. + + To make the graph can be properly tracked across workers, users should convert the |torch.nn.Parameter|_\s to |torch.Tensor|_\s before sending them over RPC. + For example, explicitly ``clone()`` the parameters to tensors before taking them as arguments of the RPC call. + + .. code-block:: python + + import torch + import torch.distributed.rpc as rpc + + def compute_loss(param): + return param.mean() + + param = torch.nn.Parameter(torch.randn(2, 2), requires_grad=True) + + # The RPC call will detach the parameter from the autograd graph on worker1 + loss1 = rpc.rpc_sync('worker1', compute_loss, args=(param,)) + + # The RPC call will keep connection to the parameter in the autograd graph on worker1 + loss2 = rpc.rpc_sync('worker1', compute_loss, args=(param.clone(),)) + + Users can use :func:`torchopt.module_clone` function to clone the module and convert all its parameters to tensors. + The tensors will have a ``grad_fn`` attribute ``CloneBackward`` to track the computation graph to the original parameters. + + .. code-block:: python + + import torch + import torch.nn as nn + import torchopt + + def compute_loss(model, batch): + ... + return loss + + model = nn.Linear(2, 2) + tuple(model.parameters()) # -> `nn.Parameter`s + + cloned_model = torchopt.module_clone(model, by='clone') + tuple(cloned_model.parameters()) # -> `torch.Tensor`s with `CloneBackward` grad_fn + + # The RPC call will detach the parameter from the autograd graph on worker1 + loss1 = rpc.rpc_sync('worker1', compute_loss, args=(model, batch)) + + # The RPC call will keep the connection to the parameter in the autograd graph on worker1 + loss2 = rpc.rpc_sync('worker1', compute_loss, args=(cloned_model, batch)) + +.. |torch.nn.Parameter| replace:: ``torch.nn.Parameter`` +.. _torch.nn.Parameter: https://pytorch.org/docs/stable/generated/torch.nn.parameter.Parameter.html +.. |torch.Tensor| replace:: ``torch.Tensor`` +.. _torch.Tensor: https://pytorch.org/docs/stable/tensors.html + +TorchOpt wraps the distributed autograd context and provides a more convenient interface to use. + +.. code-block:: python + :emphasize-lines: 5, 10 + + import torchopt.distributed as todist + + model = build_model() + + with todist.autograd.context() as context_id: + # Forward pass + loss = ... # e.g. remote calls + + # Backward pass + grads = todist.autograd.grads(context_id, loss, model.parameters()) + +or + +.. code-block:: python + :emphasize-lines: 7, 13 + + import torch + import torchopt.distributed as todist + + model = build_model() + optimizer = torch.optim.SGD(model.parameters(), lr=lr) + + with todist.autograd.context() as context_id: + # Forward pass + loss = ... # e.g. remote calls + + # Backward pass + optimizer.zero_grad() + todist.autograd.backward(context_id, loss) + optimizer.step() + +.. warning:: + + The distributed autograd context is not thread-safe. + Users should not use the same context in multiple threads. + +Users can update their single-process training code to distributed training code with minimum changes: + +#. Add the distributed autograd context around the forward and backward passes. +#. Wrap the functions with :func:`torchopt.distributed.parallelize` to enable parallel execution. +#. Convert the parameters to tensors before sending them over RPC. +#. Replace the ``torch.autograd`` to ``torchopt.distributed.autograd``. + +Here is a full example of converting the single-process training code to distributed training code: + +.. code-block:: python + :emphasize-lines: 17, 32, 40, 42, 43, 47, 52 + :name: distributed-example + + import torch + import torch.nn as nn + import torchopt.distributed as todist + + def parse_arguments(): + parser = argparse.ArgumentParser(description='TorchOpt Distributed Training') + ... + + args = parser.parse_args() + return args + + def worker_init_fn(): + # set process title, seeding, etc. + setproctitle.setproctitle(f'Worker{todist.get_rank()}') + torch.manual_seed(args.seed + todist.get_rank()) + + @todist.parallelize(partitioner=todist.batch_partitioner, reducer=todist.mean_reducer) + def compute_loss(model, batch): + device = torch.device(f'cuda:{todist.get_local_rank()}') + model = model.to(device) + batch = batch.to(device) + + # Compute local loss of the given batch + ... + return loss.cpu() + + def build_model(): + return nn.Sequential( + ... + ) + + @todist.rank_zero_only + def train(args): + model = build_model() + optimizer = torch.optim.SGD(model.parameters(), lr=args.lr) + train_loader = ... + + for epoch in range(args.epochs): + for batch in train_loader: + with todist.autograd.context() as context_id: + # Forward pass + cloned_model = todist.module_clone(model, by='clone') + loss = compute_loss(cloned_model, batch) + + # Backward pass + optimizer.zero_grad() + todist.autograd.backward(context_id, loss) + + # Update parameters + optimizer.step() + + @todist.auto_init_rpc(worker_init_fn) + def main(): + args = parse_arguments() + train(args) + + if __name__ == '__main__': + main() + +Then, users can use |torchrun|_ to launch the program: + +.. code-block:: bash + + torchrun --nnodes=1 --nproc_per_node=8 YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) diff --git a/docs/source/explicit_diff/explicit_diff.rst b/docs/source/explicit_diff/explicit_diff.rst new file mode 100644 index 00000000..89c38df6 --- /dev/null +++ b/docs/source/explicit_diff/explicit_diff.rst @@ -0,0 +1,162 @@ +Explicit Gradient Differentiation +================================= + +.. currentmodule:: torchopt + +Explicit Gradient +----------------- + +.. image:: /_static/images/explicit-gradient.png + :width: 80% + :align: center + +The idea of explicit gradient is to treat the gradient step as a differentiable function and try to backpropagate through the unrolled optimization path. +Namely, given + +.. math:: + + \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}) \triangleq \boldsymbol{\theta}_0 - \alpha \sum_{i=0}^{K-1} \nabla_{\boldsymbol{\theta}_i} \mathcal{L}^{\text{in}} (\boldsymbol{\phi},\boldsymbol{\theta}_i), + +we would like to compute the gradient :math:`\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})`. +This is usually done by AutoDiff through an inner optimization's unrolled iterates. + +Differentiable Functional Optimizers +------------------------------------ + +By passing the argument ``inplace`` as :data:`False` to the ``update`` functions, we can make the optimization differentiable. +Here is an example of making :func:`torchopt.adam` differentiable. + +.. code-block:: python + + opt = torchopt.adam() + # Define meta and inner parameters + meta_params = ... + fmodel, params = make_functional(model) + # Initialize optimizer state + state = opt.init(params) + + for iter in range(iter_times): + loss = inner_loss(fmodel, params, meta_params) + grads = torch.autograd.grad(loss, params) + # Apply non-inplace parameter update + updates, state = opt.update(grads, state, inplace=False) + params = torchopt.apply_updates(params, updates) + + loss = outer_loss(fmodel, params, meta_params) + meta_grads = torch.autograd.grad(loss, meta_params) + +Differentiable OOP Meta-Optimizers +---------------------------------- + +For PyTorch-like API (e.g., ``step()``), we designed a base class :class:`torchopt.MetaOptimizer` to wrap our functional optimizers to become differentiable OOP meta-optimizers. + +.. autosummary:: + + torchopt.MetaOptimizer + torchopt.MetaAdam + torchopt.MetaSGD + torchopt.MetaRMSProp + torchopt.MetaAdamW + +By combining low-level API :class:`torchopt.MetaOptimizer` with the previous functional optimizer, we can achieve high-level API: + +.. code-block:: python + + # Low-level API + optim = torchopt.MetaOptimizer(net, torchopt.sgd(lr=1.0)) + + # High-level API + optim = torchopt.MetaSGD(net, lr=1.0) + +Here is an example of using the OOP API :class:`torchopt.MetaAdam` to conduct meta-gradient calculation. + +.. code-block:: python + + # Define meta and inner parameters + meta_params = ... + model = ... + # Define differentiable optimizer + opt = torchopt.MetaAdam(model) + + for iter in range(iter_times): + # Perform the inner update + loss = inner_loss(model, meta_params) + opt.step(loss) + + loss = outer_loss(model, meta_params) + loss.backward() + +CPU/GPU Accelerated Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +TorchOpt performs the symbolic reduction by manually writing the forward and backward functions using C++ OpenMP (CPU) and CUDA (GPU), which largely increase meta-gradient computational efficiency. +Users can use accelerated optimizer by setting the ``use_accelerated_op`` as :data:`True`. +TorchOpt will automatically detect the device and allocate the corresponding accelerated optimizer. + +.. code-block:: python + + # Check whether the `accelerated_op` is available: + torchopt.accelerated_op_available(torch.device('cpu')) + + torchopt.accelerated_op_available(torch.device('cuda')) + + net = Net(1).cuda() + optim = torchopt.Adam(net.parameters(), lr=1.0, use_accelerated_op=True) + +General Utilities +----------------- + +We provide the :func:`torchopt.extract_state_dict` and :func:`torchopt.recover_state_dict` functions to extract and restore the state of network and optimizer. +By default, the extracted state dictionary is a reference (this design is for accumulating gradient of multi-task batch training, MAML for example). +You can also set ``by='copy'`` to extract the copy of the state dictionary or set ``by='deepcopy'`` to have a detached copy. + +.. autosummary:: + + torchopt.extract_state_dict + torchopt.recover_state_dict + torchopt.stop_gradient + +Here is an usage example. + +.. code-block:: python + + net = Net() + x = nn.Parameter(torch.tensor(2.0), requires_grad=True) + + optim = torchopt.MetaAdam(net, lr=1.0) + + # Get the reference of state dictionary + init_net_state = torchopt.extract_state_dict(net, by='reference') + init_optim_state = torchopt.extract_state_dict(optim, by='reference') + # If set `detach_buffers=True`, the parameters are referenced as references while buffers are detached copies + init_net_state = torchopt.extract_state_dict(net, by='reference', detach_buffers=True) + + # Set `copy` to get the copy of the state dictionary + init_net_state_copy = torchopt.extract_state_dict(net, by='copy') + init_optim_state_copy = torchopt.extract_state_dict(optim, by='copy') + + # Set `deepcopy` to get the detached copy of state dictionary + init_net_state_deepcopy = torchopt.extract_state_dict(net, by='deepcopy') + init_optim_state_deepcopy = torchopt.extract_state_dict(optim, by='deepcopy') + + # Conduct 2 inner-loop optimization + for i in range(2): + inner_loss = net(x) + optim.step(inner_loss) + + print(f'a = {net.a!r}') + + # Recover and reconduct 2 inner-loop optimization + torchopt.recover_state_dict(net, init_net_state) + torchopt.recover_state_dict(optim, init_optim_state) + + for i in range(2): + inner_loss = net(x) + optim.step(inner_loss) + + print(f'a = {net.a!r}') # the same result + +Notebook Tutorial +----------------- + +Check the notebook tutorials at `Meta Optimizer `_ and `Stop Gradient `_. diff --git a/docs/source/implicit_diff/implicit_diff.rst b/docs/source/implicit_diff/implicit_diff.rst new file mode 100644 index 00000000..df0927c9 --- /dev/null +++ b/docs/source/implicit_diff/implicit_diff.rst @@ -0,0 +1,178 @@ +Implicit Gradient Differentiation +================================= + +.. currentmodule:: torchopt.diff.implicit + +Implicit Differentiation +------------------------ + +.. image:: /_static/images/implicit-gradient.png + :width: 80% + :align: center + +Implicit differentiation is the task of differentiating the solution of a minimization problem with respect to its inputs. +Namely, given + +.. math:: + + \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}) \triangleq \underset{\boldsymbol{\theta}}{\mathop{\operatorname{argmin}}} ~ \mathcal{L}^{\text{in}} (\boldsymbol{\phi},\boldsymbol{\theta}). + +By treating the solution :math:`\boldsymbol{\theta}^{\prime}` as an implicit function of :math:`\boldsymbol{\phi}`, the idea of implicit differentiation is to directly get analytical best-response derivatives :math:`\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})` by the implicit function theorem. +This is suitable for algorithms when the inner-level optimal solution is achieved :math:`\left. \frac{\partial \mathcal{L}^{\text{in}} (\boldsymbol{\phi}, \boldsymbol{\theta})}{\partial \boldsymbol{\theta}} \right\rvert_{\boldsymbol{\theta} = \boldsymbol{\theta}^{\prime}} = 0` (e.g., the function :math:`F` in the figure means the solution is obtained by unrolled gradient steps) or reaches some stationary conditions :math:`F (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}) = 0`, such as `IMAML `_ and `DEQ `_. + +Custom Solvers +-------------- + +.. autosummary:: + + torchopt.diff.implicit.custom_root + +TorchOpt provides the :func:`custom_root` decorators, for easily adding implicit differentiation on top of any existing solver (also called forward optimization). +:func:`custom_root` requires users to define the stationary conditions for the problem solution (e.g., KKT conditions) and will automatically calculate the gradient for backward gradient computation. + +Here is an example of the :func:`custom_root` decorators, which is also the **functional API** for implicit gradient. + +.. code-block:: python + + # Functional API for implicit gradient + def stationary(params, meta_params, data): + # stationary condition construction + return stationary condition + + # Decorator that wraps the function + # Optionally specify the linear solver (conjugate gradient or Neumann series) + @torchopt.diff.implicit.custom_root(stationary) + def solve(params, meta_params, data): + # Forward optimization process for params + return optimal_params + + # Define params, meta_params and get data + params, meta_prams, data = ..., ..., ... + optimal_params = solve(params, meta_params, data) + loss = outer_loss(optimal_params) + + meta_grads = torch.autograd.grad(loss, meta_params) + +OOP API +~~~~~~~ + +.. autosummary:: + + torchopt.nn.ImplicitMetaGradientModule + +Coupled with PyTorch |torch.nn.Module|_, we also design the OOP API :class:`nn.ImplicitMetaGradientModule` for implicit gradient. +The core idea of :class:`nn.ImplicitMetaGradientModule` is to enable the gradient flow from ``self.parameters()`` (usually lower-level parameters) to ``self.meta_parameters()`` (usually the high-level parameters). +Users need to define the forward process ``forward()``, a stationary function ``optimality()`` (or ``objective()``), and inner-loop optimization ``solve``. + +.. |torch.nn.Module| replace:: ``torch.nn.Module`` +.. _torch.nn.Module: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module + +Here is an example of the OOP API. + +.. code-block:: python + + from torchopt.nn import ImplicitMetaGradientModule + + # Inherited from the class ImplicitMetaGradientModule + class InnerNet(ImplicitMetaGradientModule): + def __init__(self, meta_module): + ... + + def forward(self, batch): + # Forward process + ... + + def optimality(self, batch, labels): + # Stationary condition construction for calculating implicit gradient + # NOTE: If this method is not implemented, it will be automatically derived from the + # gradient of the `objective` function. + ... + + def objective(self, batch, labels): + # Define the inner-loop optimization objective + # NOTE: This method is optional if method `optimality` is implemented. + ... + + def solve(self, batch, labels): + # Conduct the inner-loop optimization + ... + return self # optimized module + + # Get meta_params and data + meta_params, data = ..., ... + inner_net = InnerNet() + + # Solve for inner-loop process related to the meta-parameters + optimal_inner_net = inner_net.solve(meta_params, *data) + + # Get outer-loss and solve for meta-gradient + loss = outer_loss(optimal_inner_net) + meta_grad = torch.autograd.grad(loss, meta_params) + +If the optimization objective is to minimize/maximize an objective function, we offer an ``objective`` method interface to simplify the implementation. +Users only need to define the ``objective`` method, while TorchOpt will automatically analyze it for the stationary (optimality) condition from the KKT condition. + +.. note:: + + In ``__init__`` method, users need to define the inner parameters and meta-parameters. + By default, :class:`nn.ImplicitMetaGradientModule` treats all tensors and modules from the method inputs as ``self.meta_parameters()`` / ``self.meta_modules()``. + For example, statement ``self.yyy = xxx`` will assign ``xxx`` as a meta-parameter with name ``'yyy'`` if ``xxx`` is present in the method inputs (e.g., ``def __init__(self, xxx, ...): ...``). + All tensors and modules defined in the ``__init__`` are regarded as ``self.parameters()`` / ``self.modules()``. + Users can also register parameters and meta-parameters by calling ``self.register_parameter()`` and ``self.register_meta_parameter()`` respectively. + +Linear System Solvers +--------------------- + +.. autosummary:: + + torchopt.linear_solve.solve_cg + torchopt.linear_solve.solve_inv + torchopt.linear_solve.solve_normal_cg + +Usually, the computation of implicit gradient involves the computation of the inverse Hessian matrix. +However, the high-dimensional Hessian matrix also makes direct computation intractable, and this is where linear solver comes into play. +By iteratively solving the linear system problem, we can calculate the inverse Hessian matrix up to some precision. We offer the `conjugate-gradient `_ based solver and `neuman-series `_ based solver. + +Here is an example of the linear solver. + +.. code-block:: python + + from torchopt import linear_solve + + torch.random.seed(42) + A = torch.random.randn(3, 3) + b = torch.random.randn(3) + + def matvec_A(x): + return torch.dot(A, x) + + sol = linear_solve.solve_normal_cg(matvec_A, b, tol=1e-5) + print(sol) + + sol = linear_solve.solve_cg(matvec_A, b, tol=1e-5) + print(sol) + +Users can also select the corresponding solver in functional and OOP APIs. + +.. code-block:: python + + # For functional API + @torchopt.diff.implicit.custom_root( + functorch.grad(objective_fn, argnums=0), # optimality function + argnums=1, + solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0), + ) + def solve_fn(...): + ... + + # For OOP API + class InnerNet( + torchopt.nn.ImplicitMetaGradientModule, + linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0), + ): + ... + +Notebook Tutorial +----------------- + +Check the notebook tutorial at `Implicit Differentiation `_. diff --git a/docs/source/index.rst b/docs/source/index.rst index a4c20e22..02fab843 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -3,20 +3,23 @@ TorchOpt -------- -**TorchOpt** is a high-performance optimizer library built upon `PyTorch `_ for easy implementation of functional optimization and gradient-based meta-learning. It consists of two main features: +**TorchOpt** is an efficient library for differentiable optimization built upon `PyTorch `_. +Torchopt is -* TorchOpt provides functional optimizer which enables `JAX-like `_ composable functional optimizer for PyTorch. With TorchOpt, one can easily conduct neural network optimization in PyTorch with functional style optimizer, similar to `Optax `_ in JAX. -* With the design of functional programming, TorchOpt provides efficient, flexible, and easy-to-implement differentiable optimizer for gradient-based meta-learning research. It largely reduces the efforts required to implement sophisticated meta-learning algorithms. +- **Comprehensive**: TorchOpt provides three differentiation modes - explicit differentiation, implicit differentiation, and zero-order differentiation for handling different differentiable optimization situations. +- **Flexible**: TorchOpt provides both functional and objective-oriented API for users different preferences. Users can implement differentiable optimization in JAX-like or PyTorch-like style. +- **Efficient**: TorchOpt provides (1) CPU/GPU acceleration differentiable optimizer (2) RPC-based distributed training framework (3) Fast Tree Operations, to largely increase the training efficiency for bi-level optimization problems. Installation ------------ Requirements: -* `PyTorch `_ -* (Optional) `Graphviz `_ +- `PyTorch `_ +- (Optional) `Graphviz `_ -Please follow the instructions at https://pytorch.org to install PyTorch in your Python environment first. Then run the following command to install TorchOpt from PyPI: +Please follow the instructions at https://pytorch.org to install PyTorch in your Python environment first. +Then run the following command to install TorchOpt from PyPI: .. code-block:: bash @@ -30,7 +33,8 @@ You can also build shared libraries from source, use: cd torchopt pip3 install . -We provide a `conda `_ environment recipe to install the build toolchain such as `cmake`, `g++`, and `nvcc`: +We provide a `conda `_ environment recipe to install the build toolchain such as ``cmake``, ``g++``, and ``nvcc``. +You can use the following commands with `conda `_ / `mamba `_ to create a new isolated environment. .. code-block:: bash @@ -42,21 +46,30 @@ We provide a `conda `_ environment recipe to ins conda activate torchopt +.. toctree:: + :maxdepth: 1 + :caption: Documentation + + basics/basics.rst + optimizer/optim.rst + explicit_diff/explicit_diff.rst + implicit_diff/implicit_diff.rst + zero_order_diff/zero_order_diff.rst + distributed/distributed.rst + visualization/visualization.rst .. toctree:: - :caption: Getting Started + :caption: Tutorial Notebooks :maxdepth: 1 torchopt101/torchopt-101.rst - .. toctree:: :caption: Examples :maxdepth: 1 examples/MAML.rst - .. toctree:: :caption: Developer Documentation :maxdepth: 1 @@ -75,12 +88,12 @@ The Team TorchOpt is a work by -* Jie Ren (`JieRen98 `_) -* Xidong Feng (`waterhorse1 `_) -* Bo Liu (`Benjamin-eecs `_) -* Xuehai Pan (`XuehaiPan `_) -* Luo Mai (`luomai `_) -* Yaodong Yang (`PKU-YYang `_). +- Jie Ren (`JieRen98 `_) +- Xidong Feng (`waterhorse1 `_) +- Bo Liu (`Benjamin-eecs `_) +- Xuehai Pan (`XuehaiPan `_) +- Luo Mai (`luomai `_) +- Yaodong Yang (`PKU-YYang `_). Support ------- @@ -114,6 +127,6 @@ If you find TorchOpt useful, please cite it in your publications. Indices and tables -================== +------------------ -* :ref:`genindex` +- :ref:`genindex` diff --git a/docs/source/optimizer/optim.rst b/docs/source/optimizer/optim.rst new file mode 100644 index 00000000..850bc8c7 --- /dev/null +++ b/docs/source/optimizer/optim.rst @@ -0,0 +1,193 @@ +Optimizers +========== + +.. currentmodule:: torchopt + +The core design of TorchOpt follows the philosophy of functional programming. +Aligned with |functorch|_, users can conduct functional-style programming with models, optimizers, and training in PyTorch. +We first introduce our functional optimizers, which treat the optimization process as a functional transformation. + +.. |functorch| replace:: ``functorch`` +.. _functorch: https://pytorch.org/functorch + +Functional Optimizers +--------------------- + +Currently, TorchOpt supports 4 functional optimizers: :func:`sgd`, :func:`adam`, :func:`rmsprop`, and :func:`adamw`. + +.. autosummary:: + + torchopt.FuncOptimizer + torchopt.adam + torchopt.sgd + torchopt.rmsprop + torchopt.adamw + +Apply Parameter Updates +----------------------- + +TorchOpt offers functional API by passing gradients and optimizer states to the optimizer function to apply updates. + +.. autosummary:: + + torchopt.apply_updates + +Here is an example of functional optimization coupled with |functorch|_: + +.. code-block:: python + + class Net(nn.Module): ... + + class Loader(DataLoader): ... + + net = Net() # init + loader = Loader() + optimizer = torchopt.adam(lr) + + model, params = functorch.make_functional(net) # use functorch extract network parameters + opt_state = optimizer.init(params) # init optimizer + + xs, ys = next(loader) # get data + pred = model(params, xs) # forward + loss = F.cross_entropy(pred, ys) # compute loss + + grads = torch.autograd.grad(loss, params) # compute gradients + updates, opt_state = optimizer.update(grads, opt_state) # get updates + params = torchopt.apply_updates(params, updates) # update network parameters + +We also provide a wrapper :class:`torchopt.FuncOptimizer` to make maintaining the optimizer state easier: + +.. code-block:: python + + net = Net() # init + loader = Loader() + optimizer = torchopt.FuncOptimizer(torchopt.adam()) # wrap with `torchopt.FuncOptimizer` + + model, params = functorch.make_functional(net) # use functorch extract network parameters + + for xs, ys in loader: # get data + pred = model(params, xs) # forward + loss = F.cross_entropy(pred, ys) # compute loss + + params = optimizer.step(loss, params) # update network parameters + +Classic OOP Optimizers +---------------------- + +Combined with the functional optimizers above, we can define our classic OOP optimizers. +We designed a base class :class:`torchopt.Optimizer` that has the same interface as |torch.optim.Optimizer|_. +We offer original PyTorch APIs (e.g., ``zero_grad()`` or ``step()``) for traditional PyTorch-like (OOP) parameter update. + +.. |torch.optim.Optimizer| replace:: ``torch.optim.Optimizer`` +.. _torch.optim.Optimizer: https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer + +.. autosummary:: + + torchopt.Optimizer + torchopt.Adam + torchopt.SGD + torchopt.RMSProp + torchopt.AdamW + +By combining low-level API :class:`torchopt.Optimizer` with the previous functional optimizer, we can achieve high-level API: + +.. code-block:: python + + learning_rate = 1.0 + # High-level API + optim = torchopt.Adam(net.parameters(), lr=learning_rate) + # which can be achieved by low-level API: + optim = torchopt.Optimizer(net.parameters(), torchopt.adam(lr=learning_rate)) + +Here is an example of PyTorch-like APIs: + +.. code-block:: python + + net = Net() # init + loader = Loader() + optimizer = torchopt.Adam(net.parameters()) + + xs, ys = next(loader) # get data + pred = net(xs) # forward + loss = F.cross_entropy(pred, ys) # compute loss + + optimizer.zero_grad() # zero gradients + loss.backward() # backward + optimizer.step() # step updates + +Combining Transformation +------------------------ + +Users always need to conduct multiple gradient transformations (functions) before the final update. +In the designing of TorchOpt, we treat these functions as derivations of :func:`torchopt.chain`. +So we can build our own chain like ``torchopt.chain(torchopt.clip_grad_norm(max_norm=1.), torchopt.sgd(lr=1., moment_requires_grad=True))`` to clip the gradient and update parameters using :func:`sgd`. + +.. autosummary:: + + torchopt.chain + +.. note:: + + :func:`torchopt.chain` will sequentially conduct transformations, so the order matters. + For example, we need to first conduct gradient normalization and then conduct the optimizer step. + The order should be (clip, sgd) in :func:`torchopt.chain` function. + + +Here is an example of chaining :func:`torchopt.clip_grad_norm` and :func:`torchopt.adam` for functional optimizer and OOP optimizer. + +.. code-block:: python + + func_optimizer = torchopt.chain(torchopt.clip_grad_norm(max_norm=2.0), torchopt.adam(1e-1)) + oop_optimizer = torchopt.Optimizer(net.parameters() func_optimizer) + +Optimizer Hooks +--------------- + +Users can also add optimizer hook to control the gradient flow. + +.. autosummary:: + + torchopt.hook.register_hook + torchopt.hook.zero_nan_hook + torchopt.hook.nan_to_num_hook + +For example, :func:`torchopt.hook.zero_nan_hook` registers hook to the first-order gradients. +During the backpropagation, the **NaN** gradients will be set to 0. +Here is an example of such operation coupled with :func:`torchopt.chain`. + +.. code-block:: python + + impl = torchopt.chain(torchopt.hook.register_hook(torchopt.hook.zero_nan_hook), torchopt.adam(1e-1)) + +Optimizer Schedules +------------------- + +TorchOpt also provides implementations of learning rate schedulers, which can be used to control the learning rate during the training process. +TorchOpt mainly offers the linear learning rate scheduler and the polynomial learning rate scheduler. + +.. autosummary:: + + torchopt.schedule.linear_schedule + torchopt.schedule.polynomial_schedule + +Here is an example of combining optimizer with learning rate scheduler. + +.. code-block:: python + + functional_adam = torchopt.adam( + lr=torchopt.schedule.linear_schedule( + init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000 + ) + ) + + adam = torchopt.Adam( + net.parameters(), + lr=torchopt.schedule.linear_schedule( + init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000 + ), + ) + +Notebook Tutorial +----------------- + +Check the notebook tutorial at `Functional Optimizer `_. diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt index e76966ef..aac17046 100644 --- a/docs/source/spelling_wordlist.txt +++ b/docs/source/spelling_wordlist.txt @@ -59,9 +59,11 @@ Graphviz Autograd autograd attrs +GradientTransformation GradientTransformations args kwargs +kwds chainable adam Adam @@ -78,6 +80,7 @@ Moens AdamW Loshchilov pytree +pytrees booleans subtrees optimality @@ -107,7 +110,13 @@ broadcasted keepdim ndim partitioner +partitioners RPC +rpc +MPI +async +parallelization +unaggregated maxiter str bool @@ -137,6 +146,7 @@ pre numerics parallelize parallelizing +JAX Optax func subfn @@ -144,3 +154,21 @@ vjp jvp ATen samplable +conj +TransformInitFn +TransformUpdateFn +argmin +Jacobian +autodiff +backend +reparametrize +reparameterize +rtype +backpropagate +NaN +iteratively +issubclass +abc +ABCMeta +subclasscheck +ctx diff --git a/docs/source/visualization/visualization.rst b/docs/source/visualization/visualization.rst new file mode 100644 index 00000000..718c6725 --- /dev/null +++ b/docs/source/visualization/visualization.rst @@ -0,0 +1,146 @@ +Visualization +============= + +.. currentmodule:: torchopt.visual + +In `PyTorch `_, if the attribute ``requires_grad`` of a tensor is :data:`True`, the computation graph will be created if we use the tensor to do any operations. +The computation graph is implemented like a link list -- ``Tensors`` are nodes and they are linked by their attribute ``gran_fn``. +`PyTorchViz `_ is a Python package that uses `Graphviz `_ as a backend for plotting computation graphs. +TorchOpt uses PyTorchViz as the blueprint and provides more easy-to-use visualization functions on the premise of supporting all its functions. + +------ + +Usage +----- + +Let's start with a simple multiplication computation graph. +We declared the variable ``x`` with the flag ``requires_grad=True`` and compute ``y = 2 * x``. Then we visualize the computation graph of ``y``. + +We provide the function :func:`make_dot` which takes a tensor as input. +The visualization code is shown as follows: + +.. code-block:: python + + from IPython.display import display + import torch + import torchopt + + + x = torch.tensor(1.0, requires_grad=True) + y = 2 * x + display(torchopt.visual.make_dot(y)) + +.. image:: /_static/images/visualization-fig1.svg + :width: 20% + :align: center + +The figure shows ``y`` is connected by the multiplication edge. +The gradient of ``y`` will flow through the multiplication backward function and then accumulate on ``x``. +Note that we pass a dictionary for adding node labels. + +To add auxiliary notes to the computation graph, we can pass a dictionary as argument ``params`` to :func:`make_dot`. +The keys are the notes which would be shown in the computation figure and the values are the tensors that need to be noted. +So the code above can be modified as follows: + +.. code-block:: python + + from IPython.display import display + import torch + import torchopt + + + x = torch.tensor(1.0, requires_grad=True) + y = 2 * x + display(torchopt.visual.make_dot(y, params={'x': x, 'y': y})) + +Then let's plot a neural network. +Note that we can pass the generator returned by the method ``named_parameters`` for adding node labels. + +.. code-block:: python + + from IPython.display import display + import torch + from torch import nn + import torchopt + + + class Net(nn.Module): + def __init__(self, dim): + super().__init__() + self.fc = nn.Linear(dim, 1, bias=True) + + def forward(self, x): + return self.fc(x) + + + dim = 5 + batch_size = 2 + net = Net(dim) + xs = torch.ones((batch_size, dim)) + ys = torch.ones((batch_size, 1)) + pred = net(xs) + loss = F.mse_loss(pred, ys) + + display(torchopt.visual.make_dot(loss, params=(net.named_parameters(), {'loss': loss}))) + +.. image:: /_static/images/visualization-fig2.svg + :width: 45% + :align: center + +The computation graph of meta-learning algorithms will be much more complex. +Our visualization tool allows users to take as input the extracted network state for better visualization. + +.. code-block:: python + + from IPython.display import display + import torch + from torch import nn + import torchopt + + class MetaNet(nn.Module): + def __init__(self, dim): + super().__init__() + self.fc = nn.Linear(dim, 1, bias=True) + + def forward(self, x, meta_param): + return self.fc(x) + meta_param + + + dim = 5 + batch_size = 2 + net = MetaNet(dim) + + xs = torch.ones((batch_size, dim)) + ys = torch.ones((batch_size, 1)) + + optimizer = torchopt.MetaSGD(net, lr=1e-3) + meta_param = torch.tensor(1.0, requires_grad=True) + + # Set enable_visual + net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.') + + pred = net(xs, meta_param) + loss = F.mse_loss(pred, ys) + optimizer.step(loss) + + # Set enable_visual + net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.') + + pred = net(xs, meta_param) + loss = F.mse_loss(pred, torch.ones_like(pred)) + + # Draw computation graph + display( + torchopt.visual.make_dot( + loss, [net_state_0, net_state_1, {'meta_param': meta_param, 'loss': loss}] + ) + ) + +.. image:: /_static/images/visualization-fig3.svg + :width: 65% + :align: center + +Notebook Tutorial +----------------- + +Check the notebook tutorial at `Visualization `_. diff --git a/docs/source/zero_order_diff/zero_order_diff.rst b/docs/source/zero_order_diff/zero_order_diff.rst new file mode 100644 index 00000000..11232c85 --- /dev/null +++ b/docs/source/zero_order_diff/zero_order_diff.rst @@ -0,0 +1,146 @@ +Zero-order Gradient Differentiation +=================================== + +.. currentmodule:: torchopt.diff.zero_order + +Evolutionary Strategy +--------------------- + +.. image:: /_static/images/zero-order.png + :width: 80% + :align: center + +When the inner-loop process is non-differentiable or one wants to eliminate the heavy computation burdens in the previous two modes (brought by Hessian), one can choose Zeroth-order differentiation。 +Zero-order differentiation typically gets gradients based on zero-order estimation, such as finite-difference, or `Evolutionary Strategy `_ (ES). +`ES-MAML `_ and `NAC `_ successfully solve the non-differentiable optimization problem based on ES. + +TorchOpt offers API for ES-based differentiation. +Instead of optimizing the objective :math:`f (\boldsymbol{\theta}): \mathbb{R}^n \to \mathbb{R}`, ES optimizes a Gaussian smoothing objective defined as :math:`\tilde{f}_{\sigma} (\boldsymbol{\theta}) = \mathbb{E}_{\boldsymbol{z} \sim \mathcal{N}( 0, {I}_d )} [ f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) ]`, where :math:`\sigma` denotes the precision. +The gradient of such objective is :math:`\nabla_{\boldsymbol{\theta}} \tilde{f}_{\sigma} (\boldsymbol{\theta}) = \frac{1}{\sigma} \mathbb{E}_{\boldsymbol{z} \sim \mathcal{N}( 0, {I}_d )} [ f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) \cdot \boldsymbol{z} ]`. +Based on such technique, one can treat the bi-level process as a whole to calculate the meta-gradient based on pure forward process. +Refer to `ES-MAML `_ for more explanations. + +Decorators +---------- + +.. autosummary:: + + torchopt.diff.zero_order.zero_order + +Similar to the implicit gradient, we also use the decorator for ES methods. + +Functional API +~~~~~~~~~~~~~~ + +The basic functional API is :func:`torchopt.diff.zero_order.zero_order`, which is used as the decorator for the forward process zero-order gradient procedures. +Users are required to implement the noise sampling function, which will be used as the input of the zero_order decorator. +Here we show the specific meaning for each parameter used in the decorator. + +- ``distribution`` for noise sampling distribution. The distribution :math:`\lambda` should be spherical symmetric and with a constant variance of :math:`1` for each element. I.e.: + + - Spherical symmetric: :math:`\mathbb{E}_{\boldsymbol{z} \sim \lambda} [ \boldsymbol{z} ] = \boldsymbol{0}`. + - Constant variance of :math:`1` for each element: :math:`\mathbb{E}_{\boldsymbol{z} \sim \lambda} [ {\lvert z_i \rvert}^2 ] = 1`. + - For example, the standard multi-dimensional normal distribution :math:`\mathcal{N} (\boldsymbol{0}, \boldsymbol{1})`. + +- ``method`` for different kind of algorithms, we support ``'naive'`` (`ES RL `_), ``'forward'`` (`Forward-FD `_), and ``'antithetic'`` (`antithetic `_). + + .. math:: + + \begin{align*} + \text{naive} \qquad & \nabla_{\boldsymbol{\theta}} \tilde{f}_{\sigma} (\boldsymbol{\theta}) = \frac{1}{\sigma} \mathbb{E}_{\boldsymbol{z} \sim \lambda} [ f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) \cdot \boldsymbol{z} ] \\ + \text{forward} \qquad & \nabla_{\boldsymbol{\theta}} \tilde{f}_{\sigma} (\boldsymbol{\theta}) = \frac{1}{\sigma} \mathbb{E}_{\boldsymbol{z} \sim \lambda} [ ( f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) - f (\boldsymbol{\theta}) ) \cdot \boldsymbol{z} ] \\ + \text{antithetic} \qquad & \nabla_{\boldsymbol{\theta}} \tilde{f}_{\sigma} (\boldsymbol{\theta}) = \frac{1}{2 \sigma} \mathbb{E}_{\boldsymbol{z} \sim \lambda} [ (f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) - f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) ) \cdot \boldsymbol{z} ] + \end{align*} + +- ``argnums`` specifies which parameter we want to trace the meta-gradient. +- ``num_samples`` specifies how many times we want to conduct the sampling. +- ``sigma`` is for precision. This is the scaling factor for the sampling distribution. + +We show the pseudo code in the following part. + +.. code-block:: python + + # Functional API for zero-order differentiation + # 1. Customize the noise distribution via a distribution class + class Distribution: + def sample(self, sample_shape=torch.Size()): + # Sampling function for noise + # NOTE: The distribution should be spherical symmetric and with a constant variance of 1. + ... + return noise_batch + + distribution = Distribution() + + # 2. Customize the noise distribution via a sampling function + def distribution(sample_shape=torch.Size()): + # Sampling function for noise + # NOTE: The distribution should be spherical symmetric and with a constant variance of 1. + ... + return noise_batch + + # 3. Distribution can also be an instance of `torch.distributions.Distribution`, e.g., `torch.distributions.Normal(...)` + distribution = torch.distributions.Normal(loc=0, scale=1) + + # Decorator that wraps the function + @torchopt.diff.zero_order(distribution=distribution, method='naive', argnums=0, num_samples=100, sigma=0.01) + def forward(params, data): + # Forward optimization process for params + ... + return objective # the returned tensor should be a scalar tensor + + # Define params and get data + params, data = ..., ... + + # Forward pass + loss = forward(params, data) + # Backward pass using zero-order differentiation + grads = torch.autograd.grad(loss, params) + +OOP API +~~~~~~~ + +.. autosummary:: + + torchopt.nn.ZeroOrderGradientModule + +Coupled with PyTorch |torch.nn.Module|_, we also design the OOP API :class:`nn.ZeroOrderGradientModule` for ES. +The core idea of :class:`nn.ZeroOrderGradientModule` is to enable the gradient flow forward process to `self.parameters()` (can be the meta-parameters when calculating meta-gradient). +Users need to define the forward process zero-order gradient procedures ``forward()`` and a noise sampling function ``sample()``. + +.. |torch.nn.Module| replace:: ``torch.nn.Module`` +.. _torch.nn.Module: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module + +.. code-block:: python + + from torchopt.nn import ZeroOrderGradientModule + + # Inherited from the class ZeroOrderGradientModule + # Optionally specify the `method` and/or `num_samples` and/or `sigma` used for sampling + class Net(ZeroOrderGradientModule, method='naive', num_samples=100, sigma=0.01): + def __init__(self, ...): + ... + + def forward(self, batch): + # Forward process + ... + return objective # the returned tensor should be a scalar tensor + + def sample(self, sample_shape=torch.Size()): + # Generate a batch of noise samples + # NOTE: The distribution should be spherical symmetric and with a constant variance of 1. + ... + return noise_batch + + # Get model and data + net = Net(...) + data = ... + + # Forward pass + loss = Net(data) + # Backward pass using zero-order differentiation + grads = torch.autograd.grad(loss, net.parameters()) + +Notebook Tutorial +----------------- + +For more details, check the notebook tutorial at `zero-order `_. diff --git a/examples/FuncTorch/maml_omniglot_vmap.py b/examples/FuncTorch/maml_omniglot_vmap.py index 41c17db8..0933b44d 100644 --- a/examples/FuncTorch/maml_omniglot_vmap.py +++ b/examples/FuncTorch/maml_omniglot_vmap.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -196,7 +196,6 @@ def train(db, net, device, meta_opt, epoch, log): qry_accs = 100.0 * torch.mean(torch.stack(qry_accs)).item() i = epoch + float(batch_idx) / n_train_iter iter_time = time.time() - start_time - torch.cuda.empty_cache() if batch_idx % 4 == 0: print( @@ -249,7 +248,6 @@ def test(db, net, device, epoch, log): qry_losses = torch.mean(torch.stack(qry_losses)).item() qry_accs = 100.0 * torch.mean(torch.stack(qry_accs)).item() - torch.cuda.empty_cache() print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}') log.append( diff --git a/examples/L2R/helpers/utils.py b/examples/L2R/helpers/utils.py index 954b27b2..fe923860 100644 --- a/examples/L2R/helpers/utils.py +++ b/examples/L2R/helpers/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,7 +33,6 @@ def get_imbalance_dataset( class_0=4, class_1=9, ): - ratio = 1 - pos_ratio ratio_test = 0.5 @@ -116,7 +115,7 @@ def get_imbalance_dataset( x_test_subset = x_test_subset[idx].astype(np.float32) y_test_subset = y_test_subset[idx].astype(np.float32) - (x_train_subset, y_train_subset, x_val_subset, y_val_subset, x_test_subset, y_test_subset,) = ( + x_train_subset, y_train_subset, x_val_subset, y_val_subset, x_test_subset, y_test_subset = ( torch.tensor(x_train_subset), torch.tensor(y_train_subset), torch.tensor(x_val_subset), diff --git a/examples/L2R/l2r.py b/examples/L2R/l2r.py index e77faa14..5ce4839d 100644 --- a/examples/L2R/l2r.py +++ b/examples/L2R/l2r.py @@ -36,9 +36,6 @@ from torchvision.datasets import MNIST import torchopt - - -# isort: off from helpers.argument import parse_args from helpers.model import LeNet5 from helpers.utils import get_imbalance_dataset, plot, set_seed diff --git a/examples/LOLA/helpers/agent.py b/examples/LOLA/helpers/agent.py index 3b37daf2..a8f8ee31 100644 --- a/examples/LOLA/helpers/agent.py +++ b/examples/LOLA/helpers/agent.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -30,7 +30,6 @@ def __init__(self, theta): class Agent: def __init__(self, args): - self.args = args # init theta and its optimizer self.theta = nn.Parameter(torch.zeros(5, requires_grad=True)) diff --git a/examples/LOLA/lola_dice.py b/examples/LOLA/lola_dice.py index 4b6b2567..20c0ff0e 100644 --- a/examples/LOLA/lola_dice.py +++ b/examples/LOLA/lola_dice.py @@ -19,8 +19,6 @@ import numpy as np import torch - -# isort: off from helpers.agent import Agent from helpers.argument import parse_args from helpers.env import IPD diff --git a/examples/MAML-RL/func_maml.py b/examples/MAML-RL/func_maml.py index 6413cc71..2534caeb 100644 --- a/examples/MAML-RL/func_maml.py +++ b/examples/MAML-RL/func_maml.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -156,7 +156,6 @@ def main(args): param_orig = [p.detach().clone().requires_grad_() for p in params] _params = list(params) for idx in range(TASK_NUM): - for _ in range(inner_iters): pre_trajs = sample_traj(env, tasks[idx], fpolicy, _params) inner_loss = a2c_loss(pre_trajs, fpolicy, _params, value_coef=0.5) diff --git a/examples/MAML-RL/maml.py b/examples/MAML-RL/maml.py index 447f540e..d4aa8c3c 100644 --- a/examples/MAML-RL/maml.py +++ b/examples/MAML-RL/maml.py @@ -22,9 +22,7 @@ import torch.optim as optim import torchopt - - -from helpers.policy import CategoricalMLPPolicy # isort: skip +from helpers.policy import CategoricalMLPPolicy TASK_NUM = 40 diff --git a/examples/MAML-RL/maml_torchrl.py b/examples/MAML-RL/maml_torchrl.py index 9d1bfe56..3cb72b49 100644 --- a/examples/MAML-RL/maml_torchrl.py +++ b/examples/MAML-RL/maml_torchrl.py @@ -25,9 +25,7 @@ from torchrl.objectives.returns.functional import td_lambda_advantage_estimate import torchopt - - -from helpers.policy_torchrl import ActorCritic # isort: skip +from helpers.policy_torchrl import ActorCritic TASK_NUM = 40 diff --git a/examples/distributed/few-shot/helpers/omniglot_loaders.py b/examples/distributed/few-shot/helpers/omniglot_loaders.py index d857d386..e8f02042 100644 --- a/examples/distributed/few-shot/helpers/omniglot_loaders.py +++ b/examples/distributed/few-shot/helpers/omniglot_loaders.py @@ -118,7 +118,7 @@ def download(self): def find_classes(root_dir): retour = [] - for (root, dirs, files) in os.walk(root_dir): + for root, dirs, files in os.walk(root_dir): for f in files: if f.endswith('png'): r = root.split('/') @@ -170,7 +170,7 @@ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=Non # {label: [img1, img2..., img20], label2: [img1, img2, ...], ... 1623 labels in total} temp = {} - for (img, label) in self.x: + for img, label in self.x: if label in temp.keys(): temp[label].append(img) else: @@ -255,15 +255,12 @@ def load_data_cache(self, data_pack): # print('preload next 50 caches of batchsz of batch.') for sample in range(10): # num of episodes - x_spts, y_spts, x_qrys, y_qrys = [], [], [], [] for i in range(self.batchsz): # one batch means one set - x_spt, y_spt, x_qry, y_qry = [], [], [], [] selected_cls = self.rng.choice(data_pack.shape[0], self.n_way, False) for j, cur_class in enumerate(selected_cls): - selected_img = self.rng.choice(20, self.k_shot + self.k_query, False) # meta-training and meta-test diff --git a/examples/distributed/few-shot/maml_omniglot.py b/examples/distributed/few-shot/maml_omniglot.py index 879792ff..867caf43 100644 --- a/examples/distributed/few-shot/maml_omniglot.py +++ b/examples/distributed/few-shot/maml_omniglot.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -56,9 +56,7 @@ import torchopt import torchopt.distributed as todist - - -from helpers.omniglot_loaders import OmniglotNShot # isort: skip +from helpers.omniglot_loaders import OmniglotNShot mpl.use('Agg') @@ -231,7 +229,6 @@ def train(db: OmniglotNShot, net: nn.Module, meta_opt: optim.Adam, epoch: int, l qry_acc = 100.0 * qry_acc i = epoch + float(batch_idx) / n_train_iter iter_time = time.time() - start_time - torch.cuda.empty_cache() print( f'[Epoch {i:.2f}] Train Loss: {qry_loss:.2f} | Acc: {qry_acc:.2f} | Time: {iter_time:.2f}' @@ -274,7 +271,6 @@ def test(db, net, epoch, log): qry_losses = np.mean(qry_losses) qry_accs = 100.0 * np.mean(qry_accs) - torch.cuda.empty_cache() print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}') log.append( diff --git a/examples/distributed/few-shot/maml_omniglot_local_loader.py b/examples/distributed/few-shot/maml_omniglot_local_loader.py index f7f9e4f0..7f042854 100644 --- a/examples/distributed/few-shot/maml_omniglot_local_loader.py +++ b/examples/distributed/few-shot/maml_omniglot_local_loader.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -58,9 +58,7 @@ import torchopt import torchopt.distributed as todist - - -from helpers.omniglot_loaders import OmniglotNShot # isort: skip +from helpers.omniglot_loaders import OmniglotNShot mpl.use('Agg') @@ -274,7 +272,6 @@ def train(net: nn.Module, meta_opt: optim.Adam, epoch: int, log: list): qry_acc = 100.0 * qry_acc i = epoch + float(batch_idx) / n_train_iter iter_time = time.time() - start_time - torch.cuda.empty_cache() print( f'[Epoch {i:.2f}] Train Loss: {qry_loss:.2f} | Acc: {qry_acc:.2f} | Time: {iter_time:.2f}' @@ -318,7 +315,6 @@ def test(net, epoch, log): qry_losses = np.mean(qry_losses) qry_accs = 100.0 * np.mean(qry_accs) - torch.cuda.empty_cache() print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}') log.append( diff --git a/examples/few-shot/helpers/omniglot_loaders.py b/examples/few-shot/helpers/omniglot_loaders.py index d857d386..e8f02042 100644 --- a/examples/few-shot/helpers/omniglot_loaders.py +++ b/examples/few-shot/helpers/omniglot_loaders.py @@ -118,7 +118,7 @@ def download(self): def find_classes(root_dir): retour = [] - for (root, dirs, files) in os.walk(root_dir): + for root, dirs, files in os.walk(root_dir): for f in files: if f.endswith('png'): r = root.split('/') @@ -170,7 +170,7 @@ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=Non # {label: [img1, img2..., img20], label2: [img1, img2, ...], ... 1623 labels in total} temp = {} - for (img, label) in self.x: + for img, label in self.x: if label in temp.keys(): temp[label].append(img) else: @@ -255,15 +255,12 @@ def load_data_cache(self, data_pack): # print('preload next 50 caches of batchsz of batch.') for sample in range(10): # num of episodes - x_spts, y_spts, x_qrys, y_qrys = [], [], [], [] for i in range(self.batchsz): # one batch means one set - x_spt, y_spt, x_qry, y_qry = [], [], [], [] selected_cls = self.rng.choice(data_pack.shape[0], self.n_way, False) for j, cur_class in enumerate(selected_cls): - selected_img = self.rng.choice(20, self.k_shot + self.k_query, False) # meta-training and meta-test diff --git a/examples/few-shot/maml_omniglot.py b/examples/few-shot/maml_omniglot.py index 879a235a..17172bdd 100644 --- a/examples/few-shot/maml_omniglot.py +++ b/examples/few-shot/maml_omniglot.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -52,9 +52,7 @@ import torch.optim as optim import torchopt - - -from helpers.omniglot_loaders import OmniglotNShot # isort: skip +from helpers.omniglot_loaders import OmniglotNShot mpl.use('Agg') @@ -178,7 +176,6 @@ def train(db, net, meta_opt, epoch, log): qry_accs = 100.0 * np.mean(qry_accs) i = epoch + float(batch_idx) / n_train_iter iter_time = time.time() - start_time - torch.cuda.empty_cache() print( f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}' @@ -239,7 +236,6 @@ def test(db, net, epoch, log): qry_losses = np.mean(qry_losses) qry_accs = 100.0 * np.mean(qry_accs) - torch.cuda.empty_cache() print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}') log.append( diff --git a/examples/iMAML/helpers/omniglot_loaders.py b/examples/iMAML/helpers/omniglot_loaders.py index d857d386..e8f02042 100644 --- a/examples/iMAML/helpers/omniglot_loaders.py +++ b/examples/iMAML/helpers/omniglot_loaders.py @@ -118,7 +118,7 @@ def download(self): def find_classes(root_dir): retour = [] - for (root, dirs, files) in os.walk(root_dir): + for root, dirs, files in os.walk(root_dir): for f in files: if f.endswith('png'): r = root.split('/') @@ -170,7 +170,7 @@ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=Non # {label: [img1, img2..., img20], label2: [img1, img2, ...], ... 1623 labels in total} temp = {} - for (img, label) in self.x: + for img, label in self.x: if label in temp.keys(): temp[label].append(img) else: @@ -255,15 +255,12 @@ def load_data_cache(self, data_pack): # print('preload next 50 caches of batchsz of batch.') for sample in range(10): # num of episodes - x_spts, y_spts, x_qrys, y_qrys = [], [], [], [] for i in range(self.batchsz): # one batch means one set - x_spt, y_spt, x_qry, y_qry = [], [], [], [] selected_cls = self.rng.choice(data_pack.shape[0], self.n_way, False) for j, cur_class in enumerate(selected_cls): - selected_img = self.rng.choice(20, self.k_shot + self.k_query, False) # meta-training and meta-test diff --git a/examples/iMAML/imaml_omniglot.py b/examples/iMAML/imaml_omniglot.py index 2b0c9738..09344900 100644 --- a/examples/iMAML/imaml_omniglot.py +++ b/examples/iMAML/imaml_omniglot.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,12 +33,10 @@ import torch.nn.functional as F import torchopt +from helpers.omniglot_loaders import OmniglotNShot from torchopt.diff.implicit import ImplicitMetaGradientModule -from helpers.omniglot_loaders import OmniglotNShot # isort: skip - - mpl.use('Agg') plt.style.use('bmh') @@ -53,6 +51,13 @@ def __init__(self, meta_net, n_inner_iter, reg_param): self.net = torchopt.module_clone(meta_net, by='deepcopy', detach_buffers=True) self.n_inner_iter = n_inner_iter self.reg_param = reg_param + self.reset_parameters() + + def reset_parameters(self): + with torch.no_grad(): + for p1, p2 in zip(self.parameters(), self.meta_parameters()): + p1.data.copy_(p2.data) + p1.detach_().requires_grad_() def forward(self, x): return self.net(x) @@ -147,21 +152,16 @@ def main(): def train(db, net, meta_opt, epoch, log, args): n_train_iter = db.x_train.shape[0] // db.batchsz - # Given this module we've created, rip out the parameters and buffers - # and return a functional version of the module. `fnet` is stateless - # and can be called with `fnet(params, buffers, args, kwargs)` - # fnet, params, buffers = functorch.make_functional_with_buffers(net) + n_inner_iter = args.inner_steps + reg_param = args.reg_params + task_num = args.task_num + inner_nets = [InnerNet(net, n_inner_iter, reg_param) for _ in range(task_num)] for batch_idx in range(n_train_iter): start_time = time.time() # Sample a batch of support and query images and labels. x_spt, y_spt, x_qry, y_qry = db.next() - task_num = x_spt.size(0) - - n_inner_iter = args.inner_steps - reg_param = args.reg_params - qry_losses = [] qry_accs = [] meta_opt.zero_grad() @@ -171,7 +171,8 @@ def train(db, net, meta_opt, epoch, log, args): # gradient steps w.r.t. the model's parameters. # This adapts the model's meta-parameters to the task. - inner_net = InnerNet(net, n_inner_iter, reg_param) + inner_net = inner_nets[i] + inner_net.reset_parameters() optimal_inner_net = inner_net.solve(x_spt[i], y_spt[i]) # The final set of adapted parameters will induce some @@ -190,7 +191,6 @@ def train(db, net, meta_opt, epoch, log, args): qry_accs = 100.0 * np.mean(qry_accs) i = epoch + float(batch_idx) / n_train_iter iter_time = time.time() - start_time - torch.cuda.empty_cache() print( f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}' @@ -245,7 +245,6 @@ def test(db, net, epoch, log, args): qry_losses = np.mean(qry_losses) qry_accs = 100.0 * np.mean(qry_accs) - torch.cuda.empty_cache() print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}') log.append( diff --git a/examples/iMAML/imaml_omniglot_functional.py b/examples/iMAML/imaml_omniglot_functional.py index 88314366..1c0a089a 100644 --- a/examples/iMAML/imaml_omniglot_functional.py +++ b/examples/iMAML/imaml_omniglot_functional.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -34,12 +34,10 @@ import torch.nn.functional as F import torchopt +from helpers.omniglot_loaders import OmniglotNShot from torchopt import pytree -from helpers.omniglot_loaders import OmniglotNShot # isort: skip - - mpl.use('Agg') plt.style.use('bmh') @@ -167,7 +165,6 @@ def train(db, model, meta_opt_and_state, epoch, log, args): qry_accs = 100.0 * np.mean(qry_accs) i = epoch + float(batch_idx) / n_train_iter iter_time = time.time() - start_time - torch.cuda.empty_cache() print( f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}' @@ -229,7 +226,6 @@ def test(db, model, epoch, log, args): qry_losses = np.mean(qry_losses) qry_accs = 100.0 * np.mean(qry_accs) - torch.cuda.empty_cache() print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}') log.append( diff --git a/include/adam_op/adam_op.h b/include/adam_op/adam_op.h index 8b7ae2bf..a49b0a06 100644 --- a/include/adam_op/adam_op.h +++ b/include/adam_op/adam_op.h @@ -1,4 +1,4 @@ -// Copyright 2022 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2023 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ // ============================================================================= #pragma once + #include #include @@ -67,9 +68,10 @@ TensorArray<2> adamBackwardUpdates(const torch::Tensor &dupdates, const torch::Tensor &new_nu, const pyfloat_t b1, const pyfloat_t b2, + const pyfloat_t eps_root, const pyuint_t count); -void buildSubmodule(py::module &mod); // NOLINT +void buildSubmodule(py::module &mod); // NOLINT[runtime/references] } // namespace adam_op } // namespace torchopt diff --git a/include/adam_op/adam_op_impl_cpu.h b/include/adam_op/adam_op_impl_cpu.h index 3e8da376..37aba528 100644 --- a/include/adam_op/adam_op_impl_cpu.h +++ b/include/adam_op/adam_op_impl_cpu.h @@ -1,4 +1,4 @@ -// Copyright 2022 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2023 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ // ============================================================================= #pragma once + #include #include @@ -63,6 +64,7 @@ TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor &dupdates, const torch::Tensor &new_nu, const pyfloat_t b1, const pyfloat_t b2, + const pyfloat_t eps_root, const pyuint_t count); } // namespace adam_op } // namespace torchopt diff --git a/include/adam_op/adam_op_impl_cuda.cuh b/include/adam_op/adam_op_impl_cuda.cuh index a7ddb937..6e661564 100644 --- a/include/adam_op/adam_op_impl_cuda.cuh +++ b/include/adam_op/adam_op_impl_cuda.cuh @@ -1,4 +1,4 @@ -// Copyright 2022 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2023 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ // ============================================================================= #pragma once + #include #include @@ -63,6 +64,7 @@ TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates, const torch::Tensor &new_nu, const pyfloat_t b1, const pyfloat_t b2, + const pyfloat_t eps_root, const pyuint_t count); } // namespace adam_op } // namespace torchopt diff --git a/include/common.h b/include/common.h index 5353e48e..65f9ef33 100644 --- a/include/common.h +++ b/include/common.h @@ -1,4 +1,4 @@ -// Copyright 2022 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2023 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ // ============================================================================= #pragma once + #include #include diff --git a/include/utils.h b/include/utils.h index 714f98d4..0ef98539 100644 --- a/include/utils.h +++ b/include/utils.h @@ -1,4 +1,4 @@ -// Copyright 2022 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2023 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ // ============================================================================= #pragma once + #include #include diff --git a/pyproject.toml b/pyproject.toml index f3e917af..12fd6fe3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,11 +72,10 @@ lint = [ "black[jupyter] >= 22.6.0", "pylint[spelling] >= 2.15.0", "mypy >= 0.990", - "types-setuptools", "flake8", "flake8-bugbear", - "doc8 < 1.0.0a0", - "pydocstyle", + "doc8 < 1.0.0a0", # unpin this when we drop support for Python 3.7 + "pydocstyle[toml]", "pyenchant", "cpplint", "pre-commit", @@ -209,3 +208,9 @@ convention = "google" [tool.doc8] max-line-length = 500 + +[tool.pytest.ini_options] +filterwarnings = [ + "error", + 'ignore:Explicitly requested dtype float64 requested in .* is not available, and will be truncated to dtype float32\.:UserWarning', +] diff --git a/setup.py b/setup.py index 75f32750..0297d43e 100644 --- a/setup.py +++ b/setup.py @@ -77,17 +77,17 @@ def build_extension(self, ext): and hasattr(self, 'parallel') and self.parallel ): - build_args.append(f'--parallel={self.parallel}') + build_args.extend(['--parallel', str(self.parallel)]) else: build_args.append('--parallel') - build_args.extend([f'--target={ext.target}', '--']) + build_args.extend(['--target', ext.target, '--']) try: os.chdir(build_temp) - self.spawn(['cmake', ext.source_dir] + cmake_args) + self.spawn([cmake, ext.source_dir] + cmake_args) if not self.dry_run: - self.spawn(['cmake', '--build', '.'] + build_args) + self.spawn([cmake, '--build', '.'] + build_args) finally: os.chdir(HERE) diff --git a/src/adam_op/adam_op.cpp b/src/adam_op/adam_op.cpp index 18bb5d27..08c9fb74 100644 --- a/src/adam_op/adam_op.cpp +++ b/src/adam_op/adam_op.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2023 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -104,11 +104,11 @@ TensorArray<2> adamBackwardMu(const torch::Tensor &dmu, const pyfloat_t b1) { #if defined(__USE_CUDA__) if (dmu.device().is_cuda()) { - return adamBackwardMuCUDA(dmu, updates, mu, b1); + return adamBackwardMuCUDA(dmu.contiguous(), updates, mu, b1); } #endif if (dmu.device().is_cpu()) { - return adamBackwardMuCPU(dmu, updates, mu, b1); + return adamBackwardMuCPU(dmu.contiguous(), updates, mu, b1); } else { throw std::runtime_error("Not implemented"); } @@ -120,11 +120,11 @@ TensorArray<2> adamBackwardNu(const torch::Tensor &dnu, const pyfloat_t b2) { #if defined(__USE_CUDA__) if (dnu.device().is_cuda()) { - return adamBackwardNuCUDA(dnu, updates, nu, b2); + return adamBackwardNuCUDA(dnu.contiguous(), updates, nu, b2); } #endif if (dnu.device().is_cpu()) { - return adamBackwardNuCPU(dnu, updates, nu, b2); + return adamBackwardNuCPU(dnu.contiguous(), updates, nu, b2); } else { throw std::runtime_error("Not implemented"); } @@ -136,20 +136,23 @@ TensorArray<2> adamBackwardUpdates(const torch::Tensor &dupdates, const torch::Tensor &new_nu, const pyfloat_t b1, const pyfloat_t b2, + const pyfloat_t eps_root, const pyuint_t count) { #if defined(__USE_CUDA__) if (dupdates.device().is_cuda()) { - return adamBackwardUpdatesCUDA(dupdates, updates, new_mu, new_nu, b1, b2, count); + return adamBackwardUpdatesCUDA( + dupdates.contiguous(), updates, new_mu, new_nu, b1, b2, eps_root, count); } #endif if (dupdates.device().is_cpu()) { - return adamBackwardUpdatesCPU(dupdates, updates, new_mu, new_nu, b1, b2, count); + return adamBackwardUpdatesCPU( + dupdates.contiguous(), updates, new_mu, new_nu, b1, b2, eps_root, count); } else { throw std::runtime_error("Not implemented"); } } -void buildSubmodule(py::module &mod) { // NOLINT +void buildSubmodule(py::module &mod) { // NOLINT[runtime/references] py::module m = mod.def_submodule("adam_op", "Adam Ops"); m.def("forward_", &adamForwardInplace, @@ -207,6 +210,7 @@ void buildSubmodule(py::module &mod) { // NOLINT py::arg("new_nu"), py::arg("b1"), py::arg("b2"), + py::arg("eps_root"), py::arg("count")); } diff --git a/src/adam_op/adam_op_impl_cpu.cpp b/src/adam_op/adam_op_impl_cpu.cpp index cf734c4f..b9c14e49 100644 --- a/src/adam_op/adam_op_impl_cpu.cpp +++ b/src/adam_op/adam_op_impl_cpu.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2023 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -40,9 +40,8 @@ void adamForwardInplaceCPUKernel(const other_t b1, scalar_t *__restrict__ updates_ptr, scalar_t *__restrict__ mu_ptr, scalar_t *__restrict__ nu_ptr) { -#pragma omp parallel for num_threads( \ - std::min(n / MIN_NUMEL_USE_OMP, \ - static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT +#pragma omp parallel for num_threads(std::min( \ + n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t mu = mu_ptr[tid]; @@ -50,8 +49,10 @@ void adamForwardInplaceCPUKernel(const other_t b1, const scalar_t mu_out = b1 * mu + (1 - b1) * updates; const scalar_t nu_out = b2 * nu + (1 - b2) * updates * updates; - const scalar_t updates_out = - mu_out * inv_one_minus_pow_b1 / (sqrt(nu_out * inv_one_minus_pow_b2 + eps_root) + eps); + const scalar_t mu_hat = mu_out * inv_one_minus_pow_b1; + const scalar_t nu_hat = nu_out * inv_one_minus_pow_b2; + + const scalar_t updates_out = mu_hat / (sqrt(nu_hat + eps_root) + eps); mu_ptr[tid] = mu_out; nu_ptr[tid] = nu_out; @@ -94,9 +95,8 @@ void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr, const other_t b1, const size_t n, scalar_t *__restrict__ mu_out_ptr) { -#pragma omp parallel for num_threads( \ - std::min(n / MIN_NUMEL_USE_OMP, \ - static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT +#pragma omp parallel for num_threads(std::min( \ + n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t mu = mu_ptr[tid]; @@ -128,9 +128,8 @@ void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr, const other_t b2, const size_t n, scalar_t *__restrict__ nu_out_ptr) { -#pragma omp parallel for num_threads( \ - std::min(n / MIN_NUMEL_USE_OMP, \ - static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT +#pragma omp parallel for num_threads(std::min( \ + n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t nu = nu_ptr[tid]; @@ -166,9 +165,8 @@ void adamForwardUpdatesCPUKernel(const scalar_t *__restrict__ new_mu_ptr, const other_t eps_root, const size_t n, scalar_t *__restrict__ updates_out_ptr) { -#pragma omp parallel for num_threads( \ - std::min(n / MIN_NUMEL_USE_OMP, \ - static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT +#pragma omp parallel for num_threads(std::min( \ + n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t new_mu = new_mu_ptr[tid]; const scalar_t new_nu = new_nu_ptr[tid]; @@ -212,9 +210,8 @@ void adamBackwardMuCPUKernel(const scalar_t *__restrict__ dmu_ptr, const size_t n, scalar_t *__restrict__ dupdates_out_ptr, scalar_t *__restrict__ dmu_out_ptr) { -#pragma omp parallel for num_threads( \ - std::min(n / MIN_NUMEL_USE_OMP, \ - static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT +#pragma omp parallel for num_threads(std::min( \ + n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t dmu = dmu_ptr[tid]; @@ -249,9 +246,8 @@ void adamBackwardNuCPUKernel(const scalar_t *__restrict__ dnu_ptr, const size_t n, scalar_t *__restrict__ dupdates_out_ptr, scalar_t *__restrict__ dnu_out_ptr) { -#pragma omp parallel for num_threads( \ - std::min(n / MIN_NUMEL_USE_OMP, \ - static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT +#pragma omp parallel for num_threads(std::min( \ + n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t dnu = dnu_ptr[tid]; const scalar_t updates = updates_ptr[tid]; @@ -290,9 +286,8 @@ void adamBackwardUpdatesCPUKernel(const scalar_t *__restrict__ dupdates_ptr, const size_t n, scalar_t *__restrict__ dnew_mu_out_ptr, scalar_t *__restrict__ dnew_nu_out_ptr) { -#pragma omp parallel for num_threads( \ - std::min(n / MIN_NUMEL_USE_OMP, \ - static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT +#pragma omp parallel for num_threads(std::min( \ + n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t dupdates = dupdates_ptr[tid]; const scalar_t updates = updates_ptr[tid]; @@ -320,10 +315,11 @@ TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor &dupdates, const torch::Tensor &new_nu, const pyfloat_t b1, const pyfloat_t b2, + const pyfloat_t eps_root, const pyuint_t count) { using other_t = pyfloat_t; const other_t one_minus_pow_b1 = 1 - std::pow(b1, count); - const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count)); + const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count) + eps_root); auto dmu_out = torch::empty_like(new_mu); auto dnu_out = torch::empty_like(new_nu); diff --git a/src/adam_op/adam_op_impl_cuda.cu b/src/adam_op/adam_op_impl_cuda.cu index 4b65869f..ea1526a6 100644 --- a/src/adam_op/adam_op_impl_cuda.cu +++ b/src/adam_op/adam_op_impl_cuda.cu @@ -1,4 +1,4 @@ -// Copyright 2022 MetaOPT Team. All Rights Reserved. +// Copyright 2022-2023 MetaOPT Team. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -51,8 +51,10 @@ __global__ void adamForwardInplaceCUDAKernel(const other_t b1, const scalar_t mu_out = b1 * mu + (1 - b1) * updates; const scalar_t nu_out = b2 * nu + (1 - b2) * updates * updates; - const scalar_t updates_out = - mu_out * inv_one_minus_pow_b1 / (sqrt(nu_out * inv_one_minus_pow_b2 + eps_root) + eps); + const scalar_t mu_hat = mu_out * inv_one_minus_pow_b1; + const scalar_t nu_hat = nu_out * inv_one_minus_pow_b2; + + const scalar_t updates_out = mu_hat / (sqrt(nu_hat + eps_root) + eps); mu_ptr[tid] = mu_out; nu_ptr[tid] = nu_out; @@ -445,10 +447,11 @@ TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates, const torch::Tensor &new_nu, const pyfloat_t b1, const pyfloat_t b2, + const pyfloat_t eps_root, const pyuint_t count) { using other_t = pyfloat_t; const other_t one_minus_pow_b1 = 1 - std::pow(b1, count); - const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count)); + const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count) + eps_root); auto dmu_out = torch::empty_like(new_mu); auto dnu_out = torch::empty_like(new_nu); diff --git a/tests/.coveragerc b/tests/.coveragerc new file mode 100644 index 00000000..462c4c3a --- /dev/null +++ b/tests/.coveragerc @@ -0,0 +1,8 @@ +[run] +omit = + ../torchopt/distributed/* + ../torchopt/visual.py + ../torchopt/version.py + ../docs/* + ../examples/* + ../tutorials/* diff --git a/tests/conftest.py b/tests/conftest.py index 41b7db0b..eaa734b2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/helpers.py b/tests/helpers.py index 6c7c4f01..23e178f0 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,11 +13,13 @@ # limitations under the License. # ============================================================================== +from __future__ import annotations + import copy import itertools import os import random -from typing import Iterable, Optional, Tuple, Union +from typing import Iterable import numpy as np import pytest @@ -26,6 +28,9 @@ import torch.types from torch.utils import data +from torchopt import pytree +from torchopt.typing import TensorTree + BATCH_SIZE = 64 NUM_UPDATES = 5 @@ -100,12 +105,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @torch.no_grad() -def get_models( - device: torch.types.Device = None, dtype: torch.dtype = torch.float32 -) -> Tuple[nn.Module, nn.Module, nn.Module, data.DataLoader]: - seed_everything(seed=42) - - model_base = nn.Sequential( +def get_model(): + return nn.Sequential( MyLinear( in_features=MODEL_NUM_INPUTS, out_features=MODEL_HIDDEN_SIZE, @@ -132,7 +133,16 @@ def get_models( bias=False, ), nn.Softmax(dim=-1), - ).to(dtype=dtype) + ) + + +@torch.no_grad() +def get_models( + device: torch.types.Device = None, dtype: torch.dtype = torch.float32 +) -> tuple[nn.Module, nn.Module, nn.Module, data.DataLoader]: + seed_everything(seed=42) + + model_base = get_model().to(dtype=dtype) for name, param in model_base.named_parameters(recurse=True): if name.endswith('weight') and param.ndim >= 2: nn.init.orthogonal_(param) @@ -158,15 +168,14 @@ def get_models( @torch.no_grad() def assert_model_all_close( - model: Union[nn.Module, Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]], + model: nn.Module | tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]], model_ref: nn.Module, model_base: nn.Module, dtype: torch.dtype = torch.float32, - rtol: Optional[float] = None, - atol: Optional[float] = None, + rtol: float | None = None, + atol: float | None = None, equal_nan: bool = False, -): - +) -> None: if isinstance(model, tuple): params, buffers = model elif isinstance(model, nn.Module): @@ -187,11 +196,10 @@ def assert_all_close( actual: torch.Tensor, expected: torch.Tensor, base: torch.Tensor = None, - rtol: Optional[float] = None, - atol: Optional[float] = None, + rtol: float | None = None, + atol: float | None = None, equal_nan: bool = False, ) -> None: - if base is not None: actual = actual - base expected = expected - base @@ -211,3 +219,32 @@ def assert_all_close( equal_nan=equal_nan, check_dtype=True, ) + + +@torch.no_grad() +def assert_pytree_all_close( + actual: TensorTree, + expected: TensorTree, + base: TensorTree | None = None, + rtol: float | None = None, + atol: float | None = None, + equal_nan: bool = False, +) -> None: + actual_leaves, actual_treespec = pytree.tree_flatten(actual) + expected_leaves, expected_treespec = pytree.tree_flatten(expected) + assert actual_treespec == expected_treespec + if base is not None: + base_leaves, base_treespec = pytree.tree_flatten(base) + assert base_treespec == expected_treespec + else: + base_leaves = [None] * len(actual_leaves) + + for actual_leaf, expected_leaf, base_leaf in zip(actual_leaves, expected_leaves, base_leaves): + assert_all_close( + actual_leaf, + expected_leaf, + base=base_leaf, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + ) diff --git a/tests/requirements.txt b/tests/requirements.txt index b8c70827..6706dca5 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -10,16 +10,15 @@ optax pytest pytest-cov pytest-xdist -isort +isort >= 5.11.0 black[jupyter] >= 22.6.0 pylint[spelling] >= 2.15.0 mypy >= 0.990 -types-setuptools flake8 flake8-bugbear # https://github.com/PyCQA/doc8/issues/112 doc8 < 1.0.0a0 -pydocstyle +pydocstyle[toml] pyenchant cpplint pre-commit diff --git a/tests/test_accelerated_op.py b/tests/test_accelerated_op.py new file mode 100644 index 00000000..4821a03d --- /dev/null +++ b/tests/test_accelerated_op.py @@ -0,0 +1,193 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import functorch +import torch +import torch.nn.functional as F + +import helpers +import torchopt + + +try: + import torchopt._C.adam_op +except ImportError: + CXX_ACCELERATED_OP_AVAILABLE = False +else: + CXX_ACCELERATED_OP_AVAILABLE = True + + +def test_accelerated_op_is_available() -> None: + assert torchopt.accelerated_op_available('cpu') + assert torchopt.accelerated_op_available(torch.device('cpu')) + + if CXX_ACCELERATED_OP_AVAILABLE: + assert not torchopt.accelerated_op_available('meta') + assert not torchopt.accelerated_op_available(torch.device('meta')) + assert not torchopt.accelerated_op_available(['cpu', 'meta']) + assert not torchopt.accelerated_op_available([torch.device('cpu'), torch.device('meta')]) + else: + assert torchopt.accelerated_op_available('meta') + assert torchopt.accelerated_op_available(torch.device('meta')) + assert torchopt.accelerated_op_available(['cpu', 'meta']) + assert torchopt.accelerated_op_available([torch.device('cpu'), torch.device('meta')]) + + if torch.cuda.is_available(): + assert torchopt.accelerated_op_available() + assert torchopt.accelerated_op_available('cuda') + assert torchopt.accelerated_op_available('cuda:0') + assert torchopt.accelerated_op_available(0) + assert torchopt.accelerated_op_available(['cpu', 'cuda']) + assert torchopt.accelerated_op_available(['cpu', 'cuda:0']) + assert torchopt.accelerated_op_available(['cpu', 0]) + else: + assert not torchopt.accelerated_op_available() + assert not torchopt.accelerated_op_available('cuda') + assert not torchopt.accelerated_op_available('cuda:0') + assert not torchopt.accelerated_op_available(0) + assert not torchopt.accelerated_op_available(['cpu', 'cuda']) + assert not torchopt.accelerated_op_available(['cpu', 'cuda:0']) + assert not torchopt.accelerated_op_available(['cpu', 0]) + + +@helpers.parametrize( + dtype=[torch.float64, torch.float32], + lr=[1e-2, 1e-3, 1e-4], + inplace=[True, False], +) +def test_accelerated_op( + dtype: torch.dtype, + lr: float, + inplace: bool, +) -> None: + if dtype is torch.float32 and inplace: + return + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + fmodel, params, buffers = functorch.make_functional_with_buffers(model) + optim = torchopt.adam( + lr, + use_accelerated_op=True, + ) + optim_state = optim.init(params) + + fmodel_ref, params_ref, buffers_ref = functorch.make_functional_with_buffers(model_ref) + optim_ref = torchopt.adam( + lr, + use_accelerated_op=False, + ) + optim_state_ref = optim_ref.init(params_ref) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = fmodel(params, buffers, xs) + pred_ref = fmodel_ref(params_ref, buffers_ref, xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + grads = torch.autograd.grad(loss, params, allow_unused=True) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + params = torchopt.apply_updates(params, updates, inplace=inplace) + + grads = torch.autograd.grad(loss_ref, params_ref, allow_unused=True) + updates, optim_state_ref = optim_ref.update( + grads, optim_state_ref, params=params, inplace=inplace + ) + params_ref = torchopt.apply_updates(params_ref, updates, inplace=inplace) + + helpers.assert_pytree_all_close(params, params_ref) + + +@helpers.parametrize( + dtype=[torch.float64, torch.float32], + outer_lr=[1e-2, 1e-3, 1e-4], + inner_lr=[1e-2, 1e-3, 1e-4], + inner_update=[2, 3, 5], + inplace=[True, False], +) +def test_maml_accelerated_op( + dtype: torch.dtype, + outer_lr: float, + inner_lr: float, + inner_update: int, + inplace: bool, +) -> None: + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + fmodel, params, buffers = functorch.make_functional_with_buffers(model) + outer_optim = torchopt.adam( + outer_lr, + use_accelerated_op=True, + ) + outer_optim_state = outer_optim.init(params) + + fmodel_ref, params_ref, buffers_ref = functorch.make_functional_with_buffers(model_ref) + outer_optim_ref = torchopt.adam( + outer_lr, + use_accelerated_op=False, + ) + outer_optim_state_ref = outer_optim_ref.init(params_ref) + + def maml_inner_solver(params, data, use_accelerated_op): + # Initial functional optimizer based on TorchOpt + x, y, f, b = data + inner_optimizer = torchopt.adam( + inner_lr, + use_accelerated_op=use_accelerated_op, + ) + inner_opt_state = inner_optimizer.init(params) + with torch.enable_grad(): + # Temporarily enable gradient computation for conducting the optimization + for _ in range(inner_update): + pred = f(params, b, x) + inner_loss = F.cross_entropy(pred, y) # compute loss + grads = torch.autograd.grad( + inner_loss, params, allow_unused=True + ) # compute gradients + updates, inner_opt_state = inner_optimizer.update( + grads, inner_opt_state, inplace=False + ) # get updates + params = torchopt.apply_updates(params, updates, inplace=False) + return (params, b) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + data = (xs, ys, fmodel, buffers) + data_ref = (xs, ys, fmodel_ref, buffers_ref) + + params_prime, buffers_prime = maml_inner_solver(params, data, use_accelerated_op=True) + params_prime_ref, buffers_prime_ref = maml_inner_solver( + params_ref, data_ref, use_accelerated_op=False + ) + + pred = fmodel(params_prime, buffers_prime, xs) + pred_ref = fmodel_ref(params_prime_ref, buffers_prime_ref, xs) + outer_loss = F.cross_entropy(pred, ys) + outer_loss_ref = F.cross_entropy(pred_ref, ys) + + grads = torch.autograd.grad(outer_loss, params, allow_unused=True) + updates, outer_optim_state = outer_optim.update( + grads, outer_optim_state, params=params, inplace=inplace + ) + params = torchopt.apply_updates(params, updates, inplace=inplace) + + grads = torch.autograd.grad(outer_loss_ref, params_ref, allow_unused=True) + updates, outer_optim_state_ref = outer_optim_ref.update( + grads, outer_optim_state_ref, params=params, inplace=inplace + ) + params_ref = torchopt.apply_updates(params_ref, updates, inplace=inplace) + + torchopt.stop_gradient(model) + torchopt.stop_gradient(model_ref) diff --git a/tests/test_alias.py b/tests/test_alias.py index 50b42835..b609cf58 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,9 @@ # limitations under the License. # ============================================================================== -from typing import Tuple +from __future__ import annotations + +from typing import Callable import functorch import pytest @@ -22,6 +24,7 @@ import helpers import torchopt +from torchopt.alias.utils import _set_use_chain_flat @helpers.parametrize( @@ -33,6 +36,7 @@ inplace=[True, False], weight_decay=[0.0, 1e-2], maximize=[False, True], + use_chain_flat=[True, False], ) def test_sgd( dtype: torch.dtype, @@ -43,10 +47,13 @@ def test_sgd( inplace: bool, weight_decay: float, maximize: bool, + use_chain_flat: bool, ) -> None: if nesterov and (momentum <= 0.0 or dampening != 0.0): pytest.skip('Nesterov momentum requires a momentum and zero dampening.') + _set_use_chain_flat(use_chain_flat) + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) fmodel, params, buffers = functorch.make_functional_with_buffers(model) @@ -85,6 +92,7 @@ def test_sgd( optim_ref.step() helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) @helpers.parametrize( @@ -95,16 +103,22 @@ def test_sgd( inplace=[True, False], weight_decay=[0.0, 1e-2], maximize=[False, True], + use_accelerated_op=[False, True], + use_chain_flat=[True, False], ) def test_adam( dtype: torch.dtype, lr: float, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, inplace: bool, weight_decay: float, maximize: bool, + use_accelerated_op: bool, + use_chain_flat: bool, ) -> None: + _set_use_chain_flat(use_chain_flat) + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) fmodel, params, buffers = functorch.make_functional_with_buffers(model) @@ -115,6 +129,7 @@ def test_adam( eps_root=0.0, weight_decay=weight_decay, maximize=maximize, + use_accelerated_op=use_accelerated_op, ) optim_state = optim.init(params) optim_ref = torch.optim.Adam( @@ -143,64 +158,97 @@ def test_adam( optim_ref.step() helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) @helpers.parametrize( dtype=[torch.float64], - lr=[1e-2, 1e-3, 1e-4], + outer_lr=[1e-2, 1e-3, 1e-4], + inner_lr=[1e-2, 1e-3, 1e-4], + inner_update=[2, 3, 5], betas=[(0.9, 0.999), (0.95, 0.9995)], eps=[1e-8], inplace=[True, False], weight_decay=[0.0, 1e-2], maximize=[False, True], + use_accelerated_op=[False, True], + use_chain_flat=[True, False], ) -def test_adamw( +def test_maml_adam( dtype: torch.dtype, - lr: float, - betas: Tuple[float, float], + outer_lr: float, + inner_lr: float, + inner_update: int, + betas: tuple[float, float], eps: float, inplace: bool, weight_decay: float, maximize: bool, + use_accelerated_op: bool, + use_chain_flat: bool, ) -> None: + _set_use_chain_flat(use_chain_flat) + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) fmodel, params, buffers = functorch.make_functional_with_buffers(model) - optim = torchopt.adamw( - lr, + outer_optim = torchopt.adam( + outer_lr, betas=betas, eps=eps, eps_root=0.0, weight_decay=weight_decay, maximize=maximize, + use_accelerated_op=use_accelerated_op, ) - optim_state = optim.init(params) - optim_ref = torch.optim.AdamW( - model_ref.parameters(), - lr, - betas=betas, - eps=eps, - amsgrad=False, - weight_decay=weight_decay, - maximize=maximize, - ) + outer_optim_state = outer_optim.init(params) + + def maml_inner_solver_torchopt(params, data, use_accelerated_op): + # Initial functional optimizer based on TorchOpt + x, y, f, b = data + inner_optimizer = torchopt.adam( + inner_lr, + betas=betas, + eps=eps, + eps_root=0.0, + weight_decay=weight_decay, + maximize=maximize, + use_accelerated_op=use_accelerated_op, + ) + inner_opt_state = inner_optimizer.init(params) + with torch.enable_grad(): + # Temporarily enable gradient computation for conducting the optimization + for _ in range(inner_update): + pred = f(params, b, x) + inner_loss = F.cross_entropy(pred, y) # compute loss + grads = torch.autograd.grad( + inner_loss, params, allow_unused=True + ) # compute gradients + updates, inner_opt_state = inner_optimizer.update( + grads, inner_opt_state, params=params, inplace=False + ) # get updates + params = torchopt.apply_updates(params, updates, inplace=False) + return (params, b) for xs, ys in loader: xs = xs.to(dtype=dtype) - pred = fmodel(params, buffers, xs) - pred_ref = model_ref(xs) - loss = F.cross_entropy(pred, ys) - loss_ref = F.cross_entropy(pred_ref, ys) - - grads = torch.autograd.grad(loss, params, allow_unused=True) - updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + data = (xs, ys, fmodel, buffers) + + params_prime, buffers_prime = maml_inner_solver_torchopt( + params, data, use_accelerated_op=True + ) + pred = fmodel(params_prime, buffers_prime, xs) + outer_loss = F.cross_entropy(pred, ys) + + grads = torch.autograd.grad(outer_loss, params, allow_unused=True) + updates, outer_optim_state = outer_optim.update( + grads, outer_optim_state, params=params, inplace=inplace + ) params = torchopt.apply_updates(params, updates, inplace=inplace) - optim_ref.zero_grad() - loss_ref.backward() - optim_ref.step() + torchopt.stop_gradient(model) - helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) @helpers.parametrize( @@ -209,32 +257,38 @@ def test_adamw( betas=[(0.9, 0.999), (0.95, 0.9995)], eps=[1e-8], inplace=[True, False], - weight_decay=[1e-2, 1e-1], + weight_decay=[0.0, 1e-2], maximize=[False, True], + use_accelerated_op=[False, True], + use_chain_flat=[True, False], ) -def test_adam_accelerated_cpu( +def test_adamw( dtype: torch.dtype, lr: float, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, inplace: bool, weight_decay: float, maximize: bool, + use_accelerated_op: bool, + use_chain_flat: bool, ) -> None: + _set_use_chain_flat(use_chain_flat) + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) fmodel, params, buffers = functorch.make_functional_with_buffers(model) - optim = torchopt.adam( + optim = torchopt.adamw( lr, betas=betas, eps=eps, eps_root=0.0, weight_decay=weight_decay, maximize=maximize, - use_accelerated_op=True, + use_accelerated_op=use_accelerated_op, ) optim_state = optim.init(params) - optim_ref = torch.optim.Adam( + optim_ref = torch.optim.AdamW( model_ref.parameters(), lr, betas=betas, @@ -260,32 +314,44 @@ def test_adam_accelerated_cpu( optim_ref.step() helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) @pytest.mark.skipif(not torch.cuda.is_available(), reason='No CUDA device available.') @helpers.parametrize( dtype=[torch.float64], lr=[1e-2, 1e-3, 1e-4], + optimizers=[ + (torchopt.adam, torch.optim.Adam), + (torchopt.adamw, torch.optim.AdamW), + ], betas=[(0.9, 0.999), (0.95, 0.9995)], eps=[1e-8], inplace=[True, False], weight_decay=[0.0, 1e-2], maximize=[False, True], + use_chain_flat=[True, False], ) def test_adam_accelerated_cuda( dtype: torch.dtype, lr: float, - betas: Tuple[float, float], + optimizers: tuple[Callable, torch.optim.Optimizer], + betas: tuple[float, float], eps: float, inplace: bool, weight_decay: float, maximize: bool, + use_chain_flat: bool, ) -> None: + _set_use_chain_flat(use_chain_flat) + device = 'cuda' model, model_ref, model_base, loader = helpers.get_models(device=device, dtype=dtype) + torchopt_optimizer, torch_optimizer = optimizers + fmodel, params, buffers = functorch.make_functional_with_buffers(model) - optim = torchopt.adam( + optim = torchopt_optimizer( lr, betas=betas, eps=eps, @@ -295,7 +361,7 @@ def test_adam_accelerated_cuda( use_accelerated_op=True, ) optim_state = optim.init(params) - optim_ref = torch.optim.Adam( + optim_ref = torch_optimizer( model_ref.parameters(), lr, betas=betas, @@ -322,6 +388,7 @@ def test_adam_accelerated_cuda( optim_ref.step() helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) @helpers.parametrize( @@ -333,6 +400,7 @@ def test_adam_accelerated_cuda( centered=[False, True], weight_decay=[0.0, 1e-2], inplace=[True, False], + use_chain_flat=[True, False], ) def test_rmsprop( dtype: torch.dtype, @@ -343,7 +411,10 @@ def test_rmsprop( centered: bool, weight_decay: float, inplace: bool, + use_chain_flat: bool, ) -> None: + _set_use_chain_flat(use_chain_flat) + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) fmodel, params, buffers = functorch.make_functional_with_buffers(model) @@ -383,3 +454,4 @@ def test_rmsprop( optim_ref.step() helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) diff --git a/tests/test_clip.py b/tests/test_clip.py index f8d3b289..0b191cfe 100644 --- a/tests/test_clip.py +++ b/tests/test_clip.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import helpers import torchopt +from torchopt.alias.utils import _set_use_chain_flat @helpers.parametrize( @@ -31,6 +32,7 @@ nesterov=[False, True], weight_decay=[0.0, 1e-2], maximize=[False, True], + use_chain_flat=[True, False], ) def test_sgd( dtype: torch.dtype, @@ -41,10 +43,13 @@ def test_sgd( nesterov: bool, weight_decay: float, maximize: bool, + use_chain_flat: bool, ) -> None: if nesterov and (momentum <= 0.0 or dampening != 0.0): pytest.skip('Nesterov momentum requires a momentum and zero dampening.') + _set_use_chain_flat(use_chain_flat) + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) chain = torchopt.chain( @@ -86,3 +91,4 @@ def test_sgd( optim_ref.step() helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) diff --git a/tests/test_combine.py b/tests/test_combine.py new file mode 100644 index 00000000..ad018d21 --- /dev/null +++ b/tests/test_combine.py @@ -0,0 +1,51 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import torchopt +from torchopt.alias.utils import _set_use_chain_flat + + +def test_chain() -> None: + assert torchopt.chain() == torchopt.base.identity() + assert torchopt.chain(torchopt.base.identity()) == torchopt.base.identity() + assert ( + torchopt.chain(torchopt.base.identity(), torchopt.base.identity()) + == torchopt.base.identity() + ) + assert torchopt.base.identity().chain(torchopt.base.identity()) == torchopt.base.identity() + assert isinstance(torchopt.base.identity(), torchopt.base.IdentityGradientTransformation) + assert isinstance( + torchopt.base.identity().chain(torchopt.base.identity()), + torchopt.base.ChainedGradientTransformation, + ) + + _set_use_chain_flat(False) + adam = torchopt.adam() + assert isinstance(adam, torchopt.base.ChainedGradientTransformation) + assert isinstance( + adam.chain(torchopt.base.identity()), torchopt.base.ChainedGradientTransformation + ) + assert adam.chain(torchopt.base.identity()) == adam + assert torchopt.base.identity().chain(adam) == adam + assert torchopt.chain(torchopt.base.identity(), adam, torchopt.base.identity()) == adam + _set_use_chain_flat(True) + + assert isinstance(adam, torchopt.base.GradientTransformation) + assert isinstance( + adam.chain(torchopt.base.identity()), torchopt.base.ChainedGradientTransformation + ) + assert adam.chain(torchopt.base.identity()) == adam + assert torchopt.base.identity().chain(adam) == adam + assert torchopt.chain(torchopt.base.identity(), adam, torchopt.base.identity()) == adam diff --git a/tests/test_hook.py b/tests/test_hook.py new file mode 100644 index 00000000..1f3024c7 --- /dev/null +++ b/tests/test_hook.py @@ -0,0 +1,38 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import torch + +import torchopt +from torchopt import pytree + + +def test_nan_to_num_hook() -> None: + nan = torch.tensor(torch.nan) + inf = torch.tensor(torch.inf) + ninf = torch.tensor(-torch.inf) + hook = torchopt.hook.nan_to_num_hook(0.0, 1.0, -1.0) + result = pytree.tree_map(hook, [nan, inf, ninf]) + assert torch.equal(result[0], torch.tensor(0.0)) + assert torch.equal(result[1], torch.tensor(1.0)) + assert torch.equal(result[2], torch.tensor(-1.0)) + + +def test_zero_nan_hook() -> None: + tensor = torch.tensor(1.0, requires_grad=True) + hook = torchopt.hook.zero_nan_hook + fn = torchopt.register_hook(hook) + fn.update(tensor, None) + assert tensor._backward_hooks[0] is hook diff --git a/tests/test_implicit.py b/tests/test_implicit.py index ac61b3be..9e3722d3 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,10 +13,11 @@ # limitations under the License. # ============================================================================== +from __future__ import annotations + import copy from collections import OrderedDict from types import FunctionType -from typing import Tuple import functorch import jax @@ -55,7 +56,7 @@ def forward(self, x): return self.fc(x) -def get_model_jax(dtype: np.dtype = np.float32) -> Tuple[FunctionType, OrderedDict]: +def get_model_jax(dtype: np.dtype = np.float32) -> tuple[FunctionType, OrderedDict]: helpers.seed_everything(seed=42) def func(params, x): @@ -73,7 +74,7 @@ def func(params, x): @torch.no_grad() def get_model_torch( device: torch.types.Device = None, dtype: torch.dtype = torch.float32 -) -> Tuple[nn.Module, data.DataLoader]: +) -> tuple[nn.Module, data.DataLoader]: helpers.seed_everything(seed=42) model = FcNet(MODEL_NUM_INPUTS, MODEL_NUM_CLASSES).to(dtype=dtype) @@ -230,8 +231,7 @@ def outer_level(p, xs, ys): nn.Parameter(torch.tensor(np.asarray(jax_params[j]), dtype=dtype)) for j in jax_params ) - for p, p_ref in zip(params, jax_params_as_tensor): - helpers.assert_all_close(p, p_ref) + helpers.assert_pytree_all_close(params, jax_params_as_tensor) @helpers.parametrize( @@ -358,8 +358,7 @@ def outer_level(p, xs, ys): nn.Parameter(torch.tensor(np.asarray(jax_params[j]), dtype=dtype)) for j in jax_params ) - for p, p_ref in zip(params, jax_params_as_tensor): - helpers.assert_all_close(p, p_ref) + helpers.assert_pytree_all_close(params, jax_params_as_tensor) @helpers.parametrize( @@ -470,8 +469,7 @@ def outer_level(p, xs, ys): nn.Parameter(torch.tensor(np.asarray(jax_params[j]), dtype=dtype)) for j in jax_params ) - for p, p_ref in zip(model.parameters(), jax_params_as_tensor): - helpers.assert_all_close(p, p_ref) + helpers.assert_pytree_all_close(tuple(model.parameters()), jax_params_as_tensor) @helpers.parametrize( diff --git a/tests/test_import.py b/tests/test_import.py new file mode 100644 index 00000000..30cf914e --- /dev/null +++ b/tests/test_import.py @@ -0,0 +1,365 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import torchopt + + +def test_accelerated_op_import() -> None: + torchopt.accelerated_op.adam_op.AdamOp + torchopt.accelerated_op.is_available + torchopt.accelerated_op_available + from torchopt.accelerated_op import is_available + from torchopt.accelerated_op.adam_op import AdamOp + + +def test_alias_import() -> None: + torchopt.adam + torchopt.adamw + torchopt.rmsprop + torchopt.sgd + torchopt.alias.adam + torchopt.alias.adamw + torchopt.alias.rmsprop + torchopt.alias.sgd + from torchopt import adam, adamw, rmsprop, sgd + from torchopt.alias import adam, adamw, rmsprop, sgd + + +def test_diff_import() -> None: + torchopt.diff.implicit + torchopt.diff.implicit.custom_root + torchopt.diff.implicit.ImplicitMetaGradientModule + torchopt.diff.implicit.nn.ImplicitMetaGradientModule + torchopt.diff.zero_order + torchopt.diff.zero_order.zero_order + torchopt.diff.zero_order.ZeroOrderGradientModule + torchopt.diff.zero_order.nn.ZeroOrderGradientModule + from torchopt.diff import implicit, zero_order + from torchopt.diff.implicit import ImplicitMetaGradientModule, custom_root + from torchopt.diff.zero_order import ZeroOrderGradientModule, zero_order + + +def test_distributed_import() -> None: + torchopt.distributed.api + torchopt.distributed.autograd + torchopt.distributed.world + torchopt.distributed.is_available + torchopt.distributed.TensorDimensionPartitioner + torchopt.distributed.dim_partitioner + torchopt.distributed.batch_partitioner + torchopt.distributed.mean_reducer + torchopt.distributed.sum_reducer + torchopt.distributed.remote_async_call + torchopt.distributed.remote_sync_call + torchopt.distributed.parallelize + torchopt.distributed.parallelize_async + torchopt.distributed.parallelize_sync + torchopt.distributed.get_world_info + torchopt.distributed.get_world_rank + torchopt.distributed.get_rank + torchopt.distributed.get_world_size + torchopt.distributed.get_local_rank + torchopt.distributed.get_local_world_size + torchopt.distributed.get_worker_id + torchopt.distributed.barrier + torchopt.distributed.auto_init_rpc + torchopt.distributed.on_rank + torchopt.distributed.not_on_rank + torchopt.distributed.rank_zero_only + torchopt.distributed.rank_non_zero_only + torchopt.distributed.autograd.is_available + torchopt.distributed.autograd.context + from torchopt.distributed import api, autograd, world + + +def test_linalg_import() -> None: + torchopt.linalg.cg + torchopt.linalg.ns + torchopt.linalg.ns_inv + from torchopt.linalg import cg, ns, ns_inv + + +def test_linear_solve_import() -> None: + torchopt.linear_solve.solve_cg + torchopt.linear_solve.solve_inv + torchopt.linear_solve.solve_normal_cg + from torchopt.linear_solve import solve_cg, solve_inv, solve_normal_cg + + +def test_nn_import() -> None: + torchopt.nn.MetaGradientModule + torchopt.nn.ImplicitMetaGradientModule + torchopt.nn.ZeroOrderGradientModule + from torchopt.nn import ImplicitMetaGradientModule, MetaGradientModule, ZeroOrderGradientModule + + +def test_optim_import() -> None: + torchopt.FuncOptimizer + torchopt.MetaAdam + torchopt.MetaAdamW + torchopt.MetaRMSProp + torchopt.MetaRMSprop + torchopt.MetaSGD + torchopt.Adam + torchopt.AdamW + torchopt.Optimizer + torchopt.RMSProp + torchopt.RMSprop + torchopt.SGD + torchopt.optim.meta.MetaAdam + torchopt.optim.meta.MetaAdamW + torchopt.optim.meta.MetaRMSProp + torchopt.optim.meta.MetaRMSprop + torchopt.optim.meta.MetaSGD + torchopt.optim.Adam + torchopt.optim.AdamW + torchopt.optim.Optimizer + torchopt.optim.RMSProp + torchopt.optim.RMSprop + torchopt.optim.SGD + torchopt.optim.func.FuncOptimizer + from torchopt import ( + SGD, + Adam, + AdamW, + FuncOptimizer, + MetaAdam, + MetaAdamW, + MetaOptimizer, + MetaRMSProp, + MetaRMSprop, + MetaSGD, + Optimizer, + RMSProp, + ) + from torchopt.optim import SGD, Adam, AdamW, FuncOptimizer, Optimizer, RMSProp + from torchopt.optim.func import FuncOptimizer + from torchopt.optim.meta import ( + MetaAdam, + MetaAdamW, + MetaOptimizer, + MetaRMSProp, + MetaRMSprop, + MetaSGD, + ) + + +def test_schedule_import() -> None: + torchopt.schedule.linear_schedule + torchopt.schedule.polynomial_schedule + from torchopt.schedule import linear_schedule, polynomial_schedule + + +def test_transform_import() -> None: + torchopt.transform.add_decayed_weights + torchopt.transform.scale + torchopt.transform.scale_by_accelerated_adam + torchopt.transform.scale_by_adam + torchopt.transform.scale_by_rms + torchopt.transform.scale_by_schedule + torchopt.transform.scale_by_stddev + torchopt.transform.trace + torchopt.transform.nan_to_num + torchopt.nan_to_num + from torchopt import nan_to_num + from torchopt.transform import ( + add_decayed_weights, + nan_to_num, + scale, + scale_by_accelerated_adam, + scale_by_adam, + scale_by_rms, + scale_by_schedule, + scale_by_stddev, + trace, + ) + + +def test_base_import() -> None: + torchopt.base.EmptyState + torchopt.base.GradientTransformation + torchopt.base.ChainedGradientTransformation + torchopt.base.identity + from torchopt.base import ( + ChainedGradientTransformation, + EmptyState, + GradientTransformation, + identity, + ) + + +def test_clip_import() -> None: + torchopt.clip_grad_norm + torchopt.clip.clip_grad_norm + from torchopt import clip_grad_norm + from torchopt.clip import clip_grad_norm + + +def test_combine_import() -> None: + torchopt.chain + torchopt.chain.flat + torchopt.combine.chain + torchopt.combine.chain.flat + torchopt.combine.chain_flat + from torchopt import chain + from torchopt.combine import chain, chain_flat + + +def test_hook_import() -> None: + torchopt.register_hook + torchopt.hook.register_hook + torchopt.hook.zero_nan_hook + torchopt.hook.nan_to_num_hook + from torchopt import register_hook + from torchopt.hook import nan_to_num_hook, register_hook, zero_nan_hook + + +def test_pytree_import() -> None: + torchopt.pytree.tree_flatten_as_tuple + torchopt.pytree.tree_pos + torchopt.pytree.tree_neg + torchopt.pytree.tree_add + torchopt.pytree.tree_add_scalar_mul + torchopt.pytree.tree_sub + torchopt.pytree.tree_sub_scalar_mul + torchopt.pytree.tree_mul + torchopt.pytree.tree_matmul + torchopt.pytree.tree_scalar_mul + torchopt.pytree.tree_truediv + torchopt.pytree.tree_vdot_real + torchopt.pytree.tree_wait + from torchopt.pytree import ( + tree_add, + tree_add_scalar_mul, + tree_flatten_as_tuple, + tree_matmul, + tree_mul, + tree_neg, + tree_pos, + tree_scalar_mul, + tree_sub, + tree_sub_scalar_mul, + tree_truediv, + tree_vdot_real, + tree_wait, + ) + + +def test_typing_import() -> None: + torchopt.typing.GradientTransformation + torchopt.typing.ChainedGradientTransformation + torchopt.typing.EmptyState + torchopt.typing.UninitializedState + torchopt.typing.Params + torchopt.typing.Updates + torchopt.typing.OptState + torchopt.typing.Scalar + torchopt.typing.Numeric + torchopt.typing.Schedule + torchopt.typing.ScalarOrSchedule + torchopt.typing.PyTree + torchopt.typing.Tensor + torchopt.typing.OptionalTensor + torchopt.typing.ListOfTensors + torchopt.typing.TupleOfTensors + torchopt.typing.SequenceOfTensors + torchopt.typing.TensorOrTensors + torchopt.typing.TensorTree + torchopt.typing.ListOfOptionalTensors + torchopt.typing.TupleOfOptionalTensors + torchopt.typing.SequenceOfOptionalTensors + torchopt.typing.OptionalTensorOrOptionalTensors + torchopt.typing.OptionalTensorTree + torchopt.typing.TensorContainer + torchopt.typing.ModuleTensorContainers + torchopt.typing.Future + torchopt.typing.LinearSolver + torchopt.typing.Device + torchopt.typing.Size + torchopt.typing.Distribution + torchopt.typing.SampleFunc + torchopt.typing.Samplable + from torchopt.typing import ( + ChainedGradientTransformation, + Device, + Distribution, + EmptyState, + Future, + GradientTransformation, + LinearSolver, + ListOfOptionalTensors, + ListOfTensors, + ModuleTensorContainers, + Numeric, + OptionalTensor, + OptionalTensorOrOptionalTensors, + OptionalTensorTree, + OptState, + Params, + PyTree, + Samplable, + SampleFunc, + Scalar, + ScalarOrSchedule, + Schedule, + SequenceOfOptionalTensors, + SequenceOfTensors, + Size, + Tensor, + TensorContainer, + TensorOrTensors, + TensorTree, + TupleOfOptionalTensors, + TupleOfTensors, + UninitializedState, + Updates, + ) + + +def test_update_import() -> None: + torchopt.apply_updates + torchopt.update.apply_updates + from torchopt import apply_updates + from torchopt.update import apply_updates + + +def test_utils_import() -> None: + torchopt.utils.ModuleState + torchopt.utils.stop_gradient + torchopt.utils.extract_state_dict + torchopt.utils.recover_state_dict + torchopt.utils.module_clone + torchopt.utils.module_detach_ + from torchopt.utils import ( + ModuleState, + extract_state_dict, + module_clone, + module_detach_, + recover_state_dict, + stop_gradient, + ) + + +def test_version_import() -> None: + torchopt.__version__ + torchopt.version.__version__ + from torchopt import __version__ + from torchopt.version import __version__ + + +def test_visual_import() -> None: + torchopt.visual.make_dot + torchopt.visual.resize_graph + from torchopt.visual import make_dot, resize_graph diff --git a/tests/test_linalg.py b/tests/test_linalg.py new file mode 100644 index 00000000..7758b7db --- /dev/null +++ b/tests/test_linalg.py @@ -0,0 +1,27 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import torch + +import torchopt + + +def test_normalize_matvec() -> None: + A = [torch.rand(10, 10) for _ in range(10)] + x = [torch.rand(10, 1) for _ in range(10)] + AxFn = torchopt.linalg.utils.normalize_matvec(A) + Ax = AxFn(x) + for Ax_item, A_item, x_item in zip(Ax, A, x): + assert torch.equal(Ax_item, A_item @ x_item) diff --git a/tests/test_meta_optim.py b/tests/test_meta_optim.py index 5916574e..61f8a7ad 100644 --- a/tests/test_meta_optim.py +++ b/tests/test_meta_optim.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,11 +13,78 @@ # limitations under the License. # ============================================================================== +from __future__ import annotations + +import torch +import torch.nn.functional as F + import helpers import torchopt -def test_filter_nones_in_params(): - model = helpers.get_models()[0] +@helpers.parametrize( + dtype=[torch.float64], + outer_lr=[1e-2, 1e-3, 1e-4], + inner_lr=[1e-2, 1e-3, 1e-4], + inner_update=[2, 3, 5], + betas=[(0.9, 0.999), (0.95, 0.9995)], + eps=[1e-8], + eps_root=[0.0, 1e-8], + weight_decay=[0.0, 1e-2], + maximize=[False, True], + use_accelerated_op=[False, True], + moment_requires_grad=[True, False], +) +def test_maml_meta_adam( + dtype: torch.dtype, + outer_lr: float, + inner_lr: float, + inner_update: int, + betas: tuple[float, float], + eps: float, + eps_root: float, + weight_decay: float, + maximize: bool, + use_accelerated_op: bool, + moment_requires_grad: bool, +) -> None: + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + outer_optim = torchopt.Adam( + model.parameters(), + outer_lr, + betas=betas, + eps=eps, + eps_root=0.0, + weight_decay=weight_decay, + maximize=maximize, + use_accelerated_op=use_accelerated_op, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + + inner_optim = torchopt.MetaAdam( + module=model, + lr=inner_lr, + betas=betas, + eps=eps, + eps_root=eps_root, + moment_requires_grad=moment_requires_grad, + weight_decay=weight_decay, + maximize=maximize, + use_accelerated_op=use_accelerated_op, + ) + + for _ in range(inner_update): + pred = model(xs) + inner_loss = F.cross_entropy(pred, ys) # compute loss + inner_optim.step(inner_loss) + + pred = model(xs) + outer_loss = F.cross_entropy(pred, ys) + outer_optim.zero_grad() + outer_loss.backward() + outer_optim.step() - meta_adam = torchopt.MetaAdam(model) + torchopt.stop_gradient(model) diff --git a/tests/test_nn.py b/tests/test_nn.py new file mode 100644 index 00000000..1b48c06b --- /dev/null +++ b/tests/test_nn.py @@ -0,0 +1,180 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import re + +import pytest +import torch +import torch.nn as nn + +import helpers +import torchopt + + +def test_property() -> None: + m = torchopt.nn.MetaGradientModule() + x = helpers.get_model() + m.add_module('x', x) + assert m.x is x + delattr(m, 'x') + assert not hasattr(m, 'x') + m.add_meta_module('x', x) + assert m.x is x + delattr(m, 'x') + assert not hasattr(m, 'x') + x = torch.tensor(1.0, requires_grad=True) + m.register_parameter('x', x) + assert m.x is x + delattr(m, 'x') + assert not hasattr(m, 'x') + x = torch.tensor(1.0, requires_grad=True) + m.register_meta_parameter('x', x) + assert m.x is x + delattr(m, 'x') + assert not hasattr(m, 'x') + m.register_buffer('x', x) + assert len(m._buffers) == 1 + assert m.x is x + delattr(m, 'x') + assert len(m._buffers) == 0 + assert not hasattr(m, 'x') + + +def test_register_tensors() -> None: + x = torch.tensor(1.0, requires_grad=True) + y = torch.tensor(1.0, requires_grad=True) + z = torch.tensor(1.0, requires_grad=False) + b = torch.tensor(1.0, requires_grad=False) + + m = torchopt.nn.MetaGradientModule() + m.register_meta_parameter('x', x) + assert m.x is x + + m = torchopt.nn.MetaGradientModule(x) + m.x = x + m.y = y + m.z = z + + assert m._meta_parameters['x'] is x + assert m._parameters['y'] is y + assert hasattr(m, 'z') and m.z is z and 'z' not in m._buffers + + del m.x + object.__setattr__(m, 'x', x) + assert hasattr(m, 'x') and m.x is x and 'x' not in m._meta_parameters + m.x = x + assert m._meta_parameters['x'] is x + + m.register_buffer('b', None) + assert m.b is None + m.b = b + assert m.b is b and 'b' in m._buffers + + +def test_no_super_init() -> None: + class NoSuper1(torchopt.nn.MetaGradientModule): + def __init__(self, x): + self.x = x + + with pytest.raises( + AttributeError, match=re.escape('cannot assign parameters before Module.__init__() call') + ): + NoSuper1(torch.tensor(1.0, requires_grad=True)) + + class NoSuper2(torchopt.nn.MetaGradientModule): + def __init__(self): + self.x = torch.tensor(1.0, requires_grad=True) + + with pytest.raises( + AttributeError, match=re.escape('cannot assign parameters before Module.__init__() call') + ): + NoSuper2() + + class NoSuper3(torchopt.nn.MetaGradientModule): + def __init__(self): + self.register_buffer('x', torch.tensor(1.0)) + + with pytest.raises( + AttributeError, match=re.escape('cannot assign buffer before Module.__init__() call') + ): + NoSuper3() + + class NoSuper4(torchopt.nn.MetaGradientModule): + def __init__(self): + self.x = torch.tensor(1.0, requires_grad=False) + + NoSuper4() # no error + + class NoSuper5(torchopt.nn.MetaGradientModule): + def __init__(self, x): + self.x = x + + with pytest.raises( + AttributeError, match=re.escape('cannot assign module before Module.__init__() call') + ): + NoSuper5(nn.Linear(1, 1)) + + class NoSuper6(torchopt.nn.MetaGradientModule): + def __init__(self): + self.x = nn.Linear(1, 1) + + with pytest.raises( + AttributeError, match=re.escape('cannot assign module before Module.__init__() call') + ): + NoSuper6() + + +def test_add_meta_module() -> None: + meta_module = helpers.get_model() + fc = nn.Linear(1, 1) + + m = torchopt.nn.MetaGradientModule(meta_module) + m.fc = fc + assert m.fc is fc + assert m._modules['fc'] is fc + + m.meta = meta_module + assert m.meta is meta_module + assert m._meta_modules['meta'] is meta_module + + assert all(p1 is p2 for p1, p2 in zip(m.parameters(), fc.parameters())) + assert all(p1 is p2 for p1, p2 in zip(m.meta_parameters(), meta_module.parameters())) + + m = torchopt.nn.MetaGradientModule(meta_module) + m.add_meta_module('fc', fc) + assert m.fc is fc + assert all(p1 is p2 for p1, p2 in zip(m.meta_parameters(), fc.parameters())) + + +def test_meta_module() -> None: + m = torchopt.nn.MetaGradientModule() + meta_module = torch.nn.Linear(1, 1) + m.add_meta_module('m', meta_module) + assert next(m.named_meta_modules())[1] is meta_module + assert next(m.named_meta_children())[1] is meta_module + assert next(m.meta_children()) is meta_module + assert next(m.meta_modules()) is meta_module + + +def test_add_meta_parameters() -> None: + m = torchopt.nn.MetaGradientModule() + x = torch.tensor(1.0, requires_grad=True) + m.register_meta_parameter('x', x) + assert next(m.named_meta_parameters())[1] is x + + +def test_named_modules() -> None: + m = torchopt.nn.MetaGradientModule() + assert next(m.named_modules())[1] is m diff --git a/tests/test_optim.py b/tests/test_optim.py index fe1697c9..b2be7500 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,9 @@ # limitations under the License. # ============================================================================== -from typing import Callable, Tuple +from __future__ import annotations + +from typing import Callable import functorch import pytest @@ -91,14 +93,16 @@ def test_SGD( eps=[1e-8], weight_decay=[0.0, 1e-2], maximize=[False, True], + use_accelerated_op=[False, True], ) def test_Adam( dtype: torch.dtype, lr: float, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, weight_decay: float, maximize: bool, + use_accelerated_op: bool, ) -> None: model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) @@ -110,6 +114,7 @@ def test_Adam( eps_root=0.0, weight_decay=weight_decay, maximize=maximize, + use_accelerated_op=use_accelerated_op, ) optim_ref = torch.optim.Adam( model_ref.parameters(), @@ -146,14 +151,16 @@ def test_Adam( eps=[1e-8], weight_decay=[1e-2, 1e-1], maximize=[False, True], + use_accelerated_op=[False, True], ) def test_AdamW( dtype: torch.dtype, lr: float, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, weight_decay: float, maximize: bool, + use_accelerated_op: bool, ) -> None: model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) @@ -165,6 +172,7 @@ def test_AdamW( eps_root=0.0, weight_decay=weight_decay, maximize=maximize, + use_accelerated_op=use_accelerated_op, ) optim_ref = torch.optim.AdamW( model_ref.parameters(), @@ -194,66 +202,14 @@ def test_AdamW( helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) -@helpers.parametrize( - dtype=[torch.float64], - lr=[1e-2, 1e-3, 1e-4], - betas=[(0.9, 0.999), (0.95, 0.9995)], - eps=[1e-8], - weight_decay=[0.0, 1e-2], - maximize=[False, True], -) -def test_Adam_accelerated_cpu( - dtype: torch.dtype, - lr: float, - betas: Tuple[float, float], - eps: float, - weight_decay: float, - maximize: bool, -) -> None: - model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) - - optim = torchopt.Adam( - model.parameters(), - lr, - betas=betas, - eps=eps, - eps_root=0.0, - weight_decay=weight_decay, - maximize=maximize, - use_accelerated_op=True, - ) - optim_ref = torch.optim.Adam( - model_ref.parameters(), - lr, - betas=betas, - eps=eps, - amsgrad=False, - weight_decay=weight_decay, - maximize=maximize, - ) - - for xs, ys in loader: - xs = xs.to(dtype=dtype) - pred = model(xs) - pred_ref = model_ref(xs) - loss = F.cross_entropy(pred, ys) - loss_ref = F.cross_entropy(pred_ref, ys) - - optim.zero_grad() - loss.backward() - optim.step() - - optim_ref.zero_grad() - loss_ref.backward() - optim_ref.step() - - helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason='No CUDA device available.') @helpers.parametrize( dtype=[torch.float64], lr=[1e-2, 1e-3, 1e-4], + optimizers=[ + (torchopt.Adam, torch.optim.Adam), + (torchopt.AdamW, torch.optim.AdamW), + ], betas=[(0.9, 0.999), (0.95, 0.9995)], eps=[1e-8], weight_decay=[0.0, 1e-2], @@ -262,7 +218,8 @@ def test_Adam_accelerated_cpu( def test_Adam_accelerated_cuda( dtype: torch.dtype, lr: float, - betas: Tuple[float, float], + optimizers: tuple[torchopt.Optimizer, torch.optim.Optimizer], + betas: tuple[float, float], eps: float, weight_decay: float, maximize: bool, @@ -270,7 +227,9 @@ def test_Adam_accelerated_cuda( device = 'cuda' model, model_ref, model_base, loader = helpers.get_models(device=device, dtype=dtype) - optim = torchopt.Adam( + torchopt_optimizer, torch_optimizer = optimizers + + optim = torchopt_optimizer( model.parameters(), lr, betas=betas, @@ -280,7 +239,7 @@ def test_Adam_accelerated_cuda( maximize=maximize, use_accelerated_op=True, ) - optim_ref = torch.optim.Adam( + optim_ref = torch_optimizer( model_ref.parameters(), lr, betas=betas, @@ -382,7 +341,7 @@ def test_RMSProp( def test_FuncOptimizer( dtype: torch.dtype, lr: float, - optimizers: Tuple[Callable, torch.optim.Optimizer], + optimizers: tuple[Callable, torch.optim.Optimizer], inplace: bool, weight_decay: float, ) -> None: diff --git a/tests/test_pytree.py b/tests/test_pytree.py new file mode 100644 index 00000000..5594e30b --- /dev/null +++ b/tests/test_pytree.py @@ -0,0 +1,214 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import torch + +import helpers +from torchopt import pytree + + +tree_a = (torch.randn(20, 10), torch.randn(20)) +tree_b = (torch.randn(20, 10), torch.randn(20)) + +tree_a_dict = ( + torch.tensor(1.0), + {'k1': torch.tensor(1.0), 'k2': (torch.tensor(1.0), torch.tensor(1.0))}, + torch.tensor(1.0), +) +tree_b_dict = ( + torch.tensor(1.0), + {'k1': torch.tensor(2.0), 'k2': (torch.tensor(3.0), torch.tensor(4.0))}, + torch.tensor(5.0), +) + +tensor_a = torch.randn(20) +tensor_b = torch.randn(20) + + +def test_tree_flatten_as_tuple() -> None: + expected_leaves, expected_treespec = (tensor_a,), pytree.tree_structure(tensor_a) + actual_leaves, actual_treespec = pytree.tree_flatten_as_tuple(tensor_a) + assert actual_leaves == expected_leaves + assert actual_treespec == expected_treespec + + leaves_a, treespec_a = pytree.tree_flatten(tree_a) + expected_leaves, expected_treespec = tuple(leaves_a), treespec_a + actual_leaves, actual_treespec = pytree.tree_flatten_as_tuple(tree_a) + assert actual_leaves == expected_leaves + assert actual_treespec == expected_treespec + + +def test_tree_pos() -> None: + expected = +tensor_a + actual = pytree.tree_pos(tensor_a) + helpers.assert_pytree_all_close(actual, expected) + + expected = (+tree_a[0], +tree_a[1]) + actual = pytree.tree_pos(tree_a) + helpers.assert_pytree_all_close(actual, expected) + + +def test_tree_neg() -> None: + expected = -tensor_a + actual = pytree.tree_neg(tensor_a) + helpers.assert_pytree_all_close(actual, expected) + + expected = (-tree_a[0], -tree_a[1]) + actual = pytree.tree_neg(tree_a) + helpers.assert_pytree_all_close(actual, expected) + + +def test_tree_add() -> None: + expected = tensor_a + tensor_b + actual = pytree.tree_add(tensor_a, tensor_b) + helpers.assert_pytree_all_close(actual, expected) + + expected = (tree_a[0] + tree_b[0], tree_a[1] + tree_b[1]) + actual = pytree.tree_add(tree_a, tree_b) + helpers.assert_pytree_all_close(actual, expected) + + +def test_tree_add_scalar_mul() -> None: + expected = (tree_a[0] + tree_b[0], tree_a[1] + tree_b[1]) + actual = pytree.tree_add_scalar_mul(tree_a, tree_b) + helpers.assert_pytree_all_close(actual, expected) + + expected = (tree_a[0] + 0.5 * tree_b[0], tree_a[1] + 0.5 * tree_b[1]) + actual = pytree.tree_add_scalar_mul(tree_a, tree_b, 0.5) + helpers.assert_pytree_all_close(actual, expected) + + +def test_tree_sub() -> None: + expected = tensor_a - tensor_b + actual = pytree.tree_sub(tensor_a, tensor_b) + helpers.assert_pytree_all_close(actual, expected) + + expected = (tree_a[0] - tree_b[0], tree_a[1] - tree_b[1]) + actual = pytree.tree_sub(tree_a, tree_b) + helpers.assert_pytree_all_close(actual, expected) + + +def test_tree_sub_scalar_mul() -> None: + expected = (tree_a[0] - tree_b[0], tree_a[1] - tree_b[1]) + actual = pytree.tree_sub_scalar_mul(tree_a, tree_b) + helpers.assert_pytree_all_close(actual, expected) + + expected = (tree_a[0] - 0.5 * tree_b[0], tree_a[1] - 0.5 * tree_b[1]) + actual = pytree.tree_sub_scalar_mul(tree_a, tree_b, 0.5) + helpers.assert_pytree_all_close(actual, expected) + + +def test_tree_mul() -> None: + expected = tensor_a * tensor_b + actual = pytree.tree_mul(tensor_a, tensor_b) + helpers.assert_pytree_all_close(actual, expected) + + expected = (tree_a[0] * tree_b[0], tree_a[1] * tree_b[1]) + actual = pytree.tree_mul(tree_a, tree_b) + helpers.assert_pytree_all_close(actual, expected) + + +def test_tree_matmul() -> None: + tree_a = (torch.randn(20, 10), torch.randn(20, 1)) + tree_b = (torch.randn(10, 20), torch.randn(1, 20)) + tensor_a = torch.randn(10, 20) + tensor_b = torch.randn(20) + expected = tensor_a @ tensor_b + actual = pytree.tree_matmul(tensor_a, tensor_b) + helpers.assert_pytree_all_close(actual, expected) + + expected = (tree_a[0] @ tree_b[0], tree_a[1] @ tree_b[1]) + actual = pytree.tree_matmul(tree_a, tree_b) + helpers.assert_pytree_all_close(actual, expected) + + +def test_tree_scalar_mul() -> None: + expected = 0.5 * tensor_a + actual = pytree.tree_scalar_mul(0.5, tensor_a) + helpers.assert_pytree_all_close(actual, expected) + + expected = (0.5 * tree_a[0], 0.5 * tree_a[1]) + actual = pytree.tree_scalar_mul(0.5, tree_a) + helpers.assert_pytree_all_close(actual, expected) + + +def test_tree_truediv() -> None: + expected = (tree_a[0] / tree_b[0], tree_a[1] / tree_b[1]) + actual = pytree.tree_truediv(tree_a, tree_b) + helpers.assert_pytree_all_close(actual, expected) + + actual = pytree.tree_truediv(tree_a_dict, tree_b_dict) + expected = ( + torch.tensor(1.0), + {'k1': torch.tensor(0.5), 'k2': (torch.tensor(1.0 / 3.0), torch.tensor(0.25))}, + torch.tensor(0.2), + ) + helpers.assert_pytree_all_close(actual, expected) + + +def test_tree_vdot_real() -> None: + expected = torch.vdot(tensor_a, tensor_b).real + actual = torch.tensor(pytree.tree_vdot_real(tensor_a, tensor_b)) + helpers.assert_pytree_all_close(actual, expected) + + expected = ( + torch.vdot(tree_a[0].contiguous().view(-1), tree_b[0].contiguous().view(-1)) + + torch.vdot(tree_a[1].contiguous().view(-1), tree_b[1].contiguous().view(-1)) + ).real + actual = torch.tensor(pytree.tree_vdot_real(tree_a, tree_b)) + helpers.assert_all_close(actual, expected) + + tensor_a_complex = torch.randn(20, dtype=torch.cfloat) + tensor_b_complex = torch.randn(20, dtype=torch.cfloat) + expected = torch.vdot(tensor_a_complex, tensor_b_complex).real + actual = torch.tensor(pytree.tree_vdot_real(tensor_a_complex, tensor_b_complex)) + helpers.assert_pytree_all_close(actual, expected) + + tree_a_complex, tree_b_complex = pytree.tree_map( + lambda x: torch.randn(x.size(), dtype=torch.cfloat), (tree_a, tree_b) + ) + expected = ( + torch.vdot(tree_a_complex[0].contiguous().view(-1), tree_b_complex[0].contiguous().view(-1)) + + torch.vdot( + tree_a_complex[1].contiguous().view(-1), tree_b_complex[1].contiguous().view(-1) + ) + ).real + actual = torch.tensor(pytree.tree_vdot_real(tree_a_complex, tree_b_complex)) + helpers.assert_all_close(actual, expected) + + +@helpers.parametrize( + tree_name=[ + 'tree_a', + 'tree_b', + 'tree_a_dict', + 'tree_b_dict', + 'tensor_a', + 'tensor_b', + ] +) +def test_tree_wait(tree_name: str) -> None: + tree = globals()[tree_name] + + future_tree = pytree.tree_map(lambda x: torch.futures.Future(), tree) + new_future_tree = pytree.tree_map( + lambda fut: fut.then(lambda f: torch.square(f.wait()) + 1.0), future_tree + ) + pytree.tree_map_(lambda fut, x: fut.set_result(x), future_tree, tree) + + expected = pytree.tree_map(lambda x: torch.square(x) + 1.0, tree) + actual = pytree.tree_wait(new_future_tree) + assert all(fut.done() for fut in pytree.tree_leaves(new_future_tree)) + helpers.assert_pytree_all_close(actual, expected) diff --git a/tests/test_schedule.py b/tests/test_schedule.py index 67e3429a..ae714875 100644 --- a/tests/test_schedule.py +++ b/tests/test_schedule.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,9 @@ # limitations under the License. # ============================================================================== -from typing import Callable, Tuple +from __future__ import annotations + +from typing import Callable import functorch import numpy as np @@ -22,6 +24,7 @@ import helpers import torchopt +from torchopt.alias.utils import _set_use_chain_flat def test_linear_schedule() -> None: @@ -55,15 +58,19 @@ def test_linear_schedule() -> None: ], inplace=[True, False], weight_decay=[0.0, 1e-2], + use_chain_flat=[True, False], ) def test_lr_linear_schedule( dtype: torch.dtype, lr: float, total_iters: int, - optimizers: Tuple[Callable, torch.optim.Optimizer], + optimizers: tuple[Callable, torch.optim.Optimizer], inplace: bool, weight_decay: float, + use_chain_flat: bool, ) -> None: + _set_use_chain_flat(use_chain_flat) + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) torchopt_optimizer, torch_optimizer = optimizers @@ -102,3 +109,4 @@ def test_lr_linear_schedule( torch_scheduler.step() helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) diff --git a/tests/test_transform.py b/tests/test_transform.py new file mode 100644 index 00000000..9598386d --- /dev/null +++ b/tests/test_transform.py @@ -0,0 +1,60 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import torch + +import torchopt + + +def test_nan_to_num() -> None: + fn = torchopt.nan_to_num(0.0, 1.0, -1.0) + nan = torch.tensor(torch.nan) + inf = torch.tensor(torch.inf) + ninf = torch.tensor(-torch.inf) + updated, _ = fn.update(nan, None, inplace=False) + assert torch.equal(updated, torch.tensor(0.0)) + assert updated is not nan + + updated, _ = fn.update(inf, None, inplace=False) + assert torch.equal(updated, torch.tensor(1.0)) + assert updated is not inf + + updated, _ = fn.update(ninf, None, inplace=False) + assert torch.equal(updated, torch.tensor(-1.0)) + assert updated is not ninf + + updated, _ = fn.update(nan, None, inplace=True) + assert torch.equal(updated, torch.tensor(0.0)) + assert updated is nan + + updated, _ = fn.update(inf, None, inplace=True) + assert torch.equal(updated, torch.tensor(1.0)) + assert updated is inf + + updated, _ = fn.update(ninf, None, inplace=True) + assert torch.equal(updated, torch.tensor(-1.0)) + assert updated is ninf + + +def test_masked() -> None: + fn = torchopt.nan_to_num(0.0, 1.0, -1.0) + nan = torch.tensor(torch.nan) + updates = [nan, nan, nan] + + masked_fn = torchopt.transform.masked(fn, [True, False, True]) + state = masked_fn.init(updates) + + updates, _ = masked_fn.update(updates, state) + assert nan is updates[1] diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..0c80cec0 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,140 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import torch + +import torchopt +from torchopt import pytree + + +def test_stop_gradient() -> None: + x = torch.tensor(1.0, requires_grad=True) + y = 2 * x + assert y.grad_fn is not None + torchopt.stop_gradient(y) + assert y.grad_fn is None + fc = torch.nn.Linear(1, 1, False) + fc._parameters['weight'] = fc.weight * 2 + assert fc.weight.grad_fn is not None + torchopt.stop_gradient(fc) + assert fc.weight.grad_fn is None + + +def test_module_clone() -> None: + x = torch.tensor(1.0, requires_grad=True) + y = 2 * x + assert y.grad_fn is not None + z = torchopt.module_clone(y, by='reference') + assert z is y + z = torchopt.module_clone(x, by='copy') + assert z is not x + assert z.grad_fn.next_functions[0][0].variable is x + + z = torchopt.module_clone(y, by='deepcopy') + assert z is not y + assert z.grad_fn is None + assert torch.equal(z, y) + + x = torch.tensor(1.0, requires_grad=True) + y = torchopt.module_clone(x, by='reference', device='meta') + assert y.grad_fn.next_functions[0][0].variable is x + assert y.is_meta + + y = torchopt.module_clone(x, by='copy', device='meta') + assert y is not x + assert y.grad_fn.next_functions[0][0].next_functions[0][0].variable is x + assert y.is_meta + + y = torchopt.module_clone(x, by='deepcopy', device='meta') + assert y is not x + assert y.grad_fn is None + assert y.is_meta + + if torch.cuda.is_available(): + x = torch.tensor(1.0, requires_grad=True) + y = torchopt.module_clone(x, by='reference', device='cuda') + assert y.grad_fn.next_functions[0][0].variable is x + assert y.is_cuda + + y = torchopt.module_clone(x, by='copy', device='cuda') + assert y is not x + assert y.grad_fn.next_functions[0][0].next_functions[0][0].variable is x + assert y.is_cuda + + y = torchopt.module_clone(x, by='deepcopy', device='cuda') + assert y is not x + assert y.grad_fn is None + assert torch.equal(y.to(x.device), x) + assert y.is_cuda + + +def test_extract_state_dict(): + fc = torch.nn.Linear(1, 1) + state_dict = torchopt.extract_state_dict(fc, by='reference', device=torch.device('meta')) + for param_dict in state_dict.params: + for k, v in param_dict.items(): + assert v.is_meta + assert v.grad_fn.next_functions[0][0].variable is fc._parameters[k] + + state_dict = torchopt.extract_state_dict(fc, by='copy', device=torch.device('meta')) + for param_dict in state_dict.params: + for k, v in param_dict.items(): + assert v.is_meta + assert v.grad_fn.next_functions[0][0].next_functions[0][0].variable is fc._parameters[k] + + state_dict = torchopt.extract_state_dict(fc, by='deepcopy', device=torch.device('meta')) + for param_dict in state_dict.params: + for k, v in param_dict.items(): + assert v.is_meta + assert v.grad_fn is None + + state_dict = torchopt.extract_state_dict(fc, by='reference') + for param_dict in state_dict.params: + for k, v in param_dict.items(): + assert v is fc._parameters[k] + + state_dict = torchopt.extract_state_dict(fc, by='copy') + for param_dict in state_dict.params: + for k, v in param_dict.items(): + assert torch.equal(v, fc._parameters[k]) + assert v.grad_fn.next_functions[0][0].variable is fc._parameters[k] + + state_dict = torchopt.extract_state_dict(fc, by='deepcopy') + for param_dict in state_dict.params: + for k, v in param_dict.items(): + assert torch.equal(v, fc._parameters[k]) + assert v.grad_fn is None + + optim = torchopt.MetaAdam(fc, 1.0) + loss = fc(torch.ones(1, 1)).sum() + optim.step(loss) + state_dict = torchopt.extract_state_dict(optim) + same = pytree.tree_map(lambda x, y: x is y, state_dict, tuple(optim.state_groups)) + assert all(pytree.tree_flatten(same)[0]) + + +def test_stop_gradient_for_state_dict() -> None: + fc = torch.nn.Linear(1, 1) + + state_dict = torchopt.extract_state_dict(fc, by='copy') + for param_dict in state_dict.params: + for k, v in param_dict.items(): + assert v.grad_fn.next_functions[0][0].variable is fc._parameters[k] + + torchopt.stop_gradient(state_dict) + for param_dict in state_dict.params: + for k, v in param_dict.items(): + assert v.grad_fn is None + assert torch.equal(v, fc._parameters[k]) diff --git a/tests/test_zero_order.py b/tests/test_zero_order.py index 32d3ae3b..ac7ae840 100644 --- a/tests/test_zero_order.py +++ b/tests/test_zero_order.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ import functorch import torch import torch.nn as nn +import torch.nn.functional as F import torch.types import helpers @@ -30,20 +31,17 @@ class FcNet(nn.Module): def __init__(self, dim, out): super().__init__() self.fc = nn.Linear(in_features=dim, out_features=out, bias=True) - nn.init.ones_(self.fc.weight) - nn.init.zeros_(self.fc.bias) def forward(self, x): return self.fc(x) @helpers.parametrize( - dtype=[torch.float64, torch.float32], lr=[1e-2, 1e-3], method=['naive', 'forward', 'antithetic'], sigma=[0.01, 0.1, 1], ) -def test_zero_order(dtype: torch.dtype, lr: float, method: str, sigma: float) -> None: +def test_zero_order(lr: float, method: str, sigma: float) -> None: helpers.seed_everything(42) input_size = 32 output_size = 1 @@ -56,24 +54,66 @@ def test_zero_order(dtype: torch.dtype, lr: float, method: str, sigma: float) -> fmodel, params = functorch.make_functional(model) x = torch.randn(batch_size, input_size) * coef - y = torch.randn(input_size) * coef + y = torch.randn(batch_size, 1) * coef distribution = torch.distributions.Normal(loc=0, scale=1) - @torchopt.diff.zero_order.zero_order( + @torchopt.diff.zero_order( distribution=distribution, method=method, argnums=0, sigma=sigma, num_samples=num_samples ) def forward_process(params, fn, x, y): y_pred = fn(params, x) - loss = torch.mean((y - y_pred) ** 2) + loss = F.mse_loss(y_pred, y) return loss optimizer = torchopt.adam(lr=lr) - opt_state = optimizer.init(params) + opt_state = optimizer.init(params) # init optimizer for i in range(num_iterations): - opt_state = optimizer.init(params) # init optimizer loss = forward_process(params, fmodel, x, y) # compute loss grads = torch.autograd.grad(loss, params) # compute gradients updates, opt_state = optimizer.update(grads, opt_state) # get updates params = torchopt.apply_updates(params, updates) # update network parameters + + +@helpers.parametrize( + lr=[1e-2, 1e-3], + method=['naive', 'forward', 'antithetic'], + sigma=[0.01, 0.1, 1], +) +def test_zero_order_module(lr: float, method: str, sigma: float) -> None: + helpers.seed_everything(42) + input_size = 32 + output_size = 1 + batch_size = BATCH_SIZE + coef = 0.1 + num_iterations = NUM_UPDATES + num_samples = 500 + + class FcNetWithLoss( + torchopt.nn.ZeroOrderGradientModule, method=method, sigma=sigma, num_samples=num_samples + ): + def __init__(self, dim, out): + super().__init__() + self.net = FcNet(dim, out) + self.loss = nn.MSELoss() + self.distribution = torch.distributions.Normal(loc=0, scale=1) + + def forward(self, x, y): + return self.loss(self.net(x), y) + + def sample(self, sample_shape=torch.Size()): + return self.distribution.sample(sample_shape) + + x = torch.randn(batch_size, input_size) * coef + y = torch.randn(batch_size, 1) * coef + model_with_loss = FcNetWithLoss(input_size, output_size) + + optimizer = torchopt.Adam(model_with_loss.parameters(), lr=lr) + + for i in range(num_iterations): + loss = model_with_loss(x, y) # compute loss + + optimizer.zero_grad() + loss.backward() # compute gradients + optimizer.step() # update network parameters diff --git a/torchopt/_C/adam_op.pyi b/torchopt/_C/adam_op.pyi index 39d51a5a..7ecfe7c2 100644 --- a/torchopt/_C/adam_op.pyi +++ b/torchopt/_C/adam_op.pyi @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,9 +14,8 @@ # ============================================================================== # pylint: disable=all -# isort: off -from typing import Tuple +from __future__ import annotations import torch @@ -29,7 +28,7 @@ def forward_( eps: float, eps_root: float, count: int, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... def forward_mu(updates: torch.Tensor, mu: torch.Tensor, b1: float) -> torch.Tensor: ... def forward_nu(updates: torch.Tensor, nu: torch.Tensor, b2: float) -> torch.Tensor: ... def forward_updates( @@ -43,10 +42,10 @@ def forward_updates( ) -> torch.Tensor: ... def backward_mu( dmu: torch.Tensor, updates: torch.Tensor, mu: torch.Tensor, b1: float -) -> Tuple[torch.Tensor, torch.Tensor]: ... +) -> tuple[torch.Tensor, torch.Tensor]: ... def backward_nu( dnu: torch.Tensor, updates: torch.Tensor, nu: torch.Tensor, b2: float -) -> Tuple[torch.Tensor, torch.Tensor]: ... +) -> tuple[torch.Tensor, torch.Tensor]: ... def backward_updates( dupdates: torch.Tensor, updates: torch.Tensor, @@ -54,5 +53,6 @@ def backward_updates( new_nu: torch.Tensor, b1: float, b2: float, + eps_root: float, count: int, -) -> Tuple[torch.Tensor, torch.Tensor]: ... +) -> tuple[torch.Tensor, torch.Tensor]: ... diff --git a/torchopt/__init__.py b/torchopt/__init__.py index db78f217..0c36ac07 100644 --- a/torchopt/__init__.py +++ b/torchopt/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,13 +15,18 @@ """TorchOpt: a high-performance optimizer library built upon PyTorch.""" from torchopt import ( + accelerated_op, + alias, + base, clip, combine, diff, distributed, hook, + linalg, linear_solve, nn, + optim, pytree, schedule, typing, @@ -32,7 +37,7 @@ from torchopt.clip import clip_grad_norm from torchopt.combine import chain from torchopt.hook import register_hook -from torchopt.optim import SGD, Adam, AdamW, Optimizer, RMSProp, RMSprop, meta +from torchopt.optim import SGD, Adam, AdamW, Optimizer, RMSProp, RMSprop from torchopt.optim.func import FuncOptimizer from torchopt.optim.meta import ( MetaAdam, @@ -56,7 +61,6 @@ __all__ = [ 'accelerated_op_available', - 'diff', 'adam', 'adamw', 'rmsprop', diff --git a/torchopt/accelerated_op/__init__.py b/torchopt/accelerated_op/__init__.py index 874174f2..ede60009 100644 --- a/torchopt/accelerated_op/__init__.py +++ b/torchopt/accelerated_op/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,16 +14,17 @@ # ============================================================================== """The accelerated Ops.""" -from typing import Iterable, Optional, Union +from __future__ import annotations + +from typing import Iterable import torch from torchopt.accelerated_op.adam_op import AdamOp +from torchopt.typing import Device -def is_available( - devices: Optional[Union[int, str, torch.device, Iterable[Union[int, str, torch.device]]]] = None -) -> bool: +def is_available(devices: Device | Iterable[Device] | None = None) -> bool: """Check the availability of accelerated optimizer.""" op = AdamOp() @@ -42,5 +43,5 @@ def is_available( updates = torch.tensor(1.0, device=device) op(updates, updates, updates, 1) return True - except BaseException: # pylint: disable=broad-except + except Exception: # pylint: disable=broad-except return False diff --git a/torchopt/accelerated_op/_src/adam_op.py b/torchopt/accelerated_op/_src/adam_op.py index 65752446..ab5ea195 100644 --- a/torchopt/accelerated_op/_src/adam_op.py +++ b/torchopt/accelerated_op/_src/adam_op.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ # pylint: disable=invalid-name,too-many-arguments,unused-argument -from typing import Tuple +from __future__ import annotations import torch @@ -30,29 +30,34 @@ def forward_( eps: float, eps_root: float, count: int, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Adam forward inplace.""" - inv_one_minus_pow_b1 = 1.0 / (1.0 - pow(b1, count)) - inv_one_minus_pow_b2 = 1.0 / (1.0 - pow(b2, count)) - mu = mu.mul_(b1).add_(updates, alpha=1.0 - b1) - nu = nu.mul_(b2).add_(updates.square(), alpha=1.0 - b2) + nu = nu.mul_(b2).addcmul_(updates, updates, value=1.0 - b2) updates.copy_( - mu.mul(inv_one_minus_pow_b1).div_( - nu.mul(inv_one_minus_pow_b2).add_(eps_root).sqrt_().add_(eps) + mu.div(1.0 - pow(b1, count)).div_( + nu.div(1.0 - pow(b2, count)).add_(eps_root).sqrt_().add_(eps) ) ) return updates, mu, nu -def forward_mu(updates: torch.Tensor, mu: torch.Tensor, b1: float) -> torch.Tensor: +def forward_mu( + updates: torch.Tensor, + mu: torch.Tensor, + b1: float, +) -> torch.Tensor: """Adam forward mu.""" return mu.mul(b1).add_(updates, alpha=1.0 - b1) -def forward_nu(updates: torch.Tensor, nu: torch.Tensor, b2: float) -> torch.Tensor: +def forward_nu( + updates: torch.Tensor, + nu: torch.Tensor, + b2: float, +) -> torch.Tensor: """Adam forward nu.""" - return nu.mul(b2).add_(updates.square(), alpha=1.0 - b2) + return nu.mul(b2).addcmul_(updates, updates, value=1.0 - b2) def forward_updates( @@ -65,16 +70,17 @@ def forward_updates( count: int, ) -> torch.Tensor: """Adam forward updates.""" - inv_one_minus_pow_b1 = 1.0 / (1.0 - pow(b1, count)) - inv_one_minus_pow_b2 = 1.0 / (1.0 - pow(b2, count)) - return new_mu.mul(inv_one_minus_pow_b1).div_( - new_nu.mul(inv_one_minus_pow_b2).add_(eps_root).sqrt_().add_(eps) + return new_mu.div(1.0 - pow(b1, count)).div_( + new_nu.div(1.0 - pow(b2, count)).add_(eps_root).sqrt_().add_(eps) ) def backward_mu( - dmu: torch.Tensor, updates: torch.Tensor, mu: torch.Tensor, b1: float -) -> Tuple[torch.Tensor, torch.Tensor]: + dmu: torch.Tensor, + updates: torch.Tensor, + mu: torch.Tensor, + b1: float, +) -> tuple[torch.Tensor, torch.Tensor]: """Adam backward mu.""" dupdates = dmu.mul(1.0 - b1) dmu = dmu.mul(b1) @@ -82,8 +88,11 @@ def backward_mu( def backward_nu( - dnu: torch.Tensor, updates: torch.Tensor, nu: torch.Tensor, b2: float -) -> Tuple[torch.Tensor, torch.Tensor]: + dnu: torch.Tensor, + updates: torch.Tensor, + nu: torch.Tensor, + b2: float, +) -> tuple[torch.Tensor, torch.Tensor]: """Adam backward nu.""" dupdates = updates.mul(dnu).mul_(2.0 * (1.0 - b2)) dnu = dnu.mul(b2) @@ -97,17 +106,18 @@ def backward_updates( new_nu: torch.Tensor, b1: float, b2: float, + eps_root: float, count: int, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """Adam backward updates.""" one_minus_pow_b1 = 1.0 - pow(b1, count) - inv_one_minus_pow_b2 = 1.0 / (1.0 - pow(b2, count)) + inv_one_minus_pow_b2 = 1.0 / (1.0 - pow(b2, count) + eps_root) updates_div_new_mu = updates.div(new_mu) - denominator = updates_div_new_mu.mul_(one_minus_pow_b1) dnew_mu_out = dupdates.mul(updates_div_new_mu) + denominator = updates_div_new_mu.mul_(one_minus_pow_b1) dnew_nu_out = ( - dupdates.mul(updates).mul_(denominator.square_()).mul_(-0.5 * inv_one_minus_pow_b2) + denominator.square_().mul_(dupdates).mul_(updates).mul_(-0.5 * inv_one_minus_pow_b2) ) mask = new_mu == 0 diff --git a/torchopt/accelerated_op/adam_op.py b/torchopt/accelerated_op/adam_op.py index 56792487..232513d6 100644 --- a/torchopt/accelerated_op/adam_op.py +++ b/torchopt/accelerated_op/adam_op.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,7 +16,10 @@ # pylint: disable=c-extension-no-member,invalid-name -from typing import Any, Optional, Tuple +from __future__ import annotations + +import contextlib +from typing import Any import torch @@ -35,11 +38,11 @@ class MuOp(torch.autograd.Function): # pylint: disable=abstract-method @staticmethod def jvp(ctx: Any, *grad_inputs: Any) -> Any: - """Defines a formula for differentiating the operation with forward mode automatic differentiation.""" + """Define a formula for differentiating the operation with forward mode automatic differentiation.""" @staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: - """Performs the operation.""" + """Perform the operation.""" updates, mu, b1 = args new_mu = adam_op.forward_mu(updates, mu, b1) ctx.save_for_backward(updates, mu) @@ -49,7 +52,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: @staticmethod def backward(ctx: Any, *args: Any) -> Any: # pylint: disable-next=line-too-long - """Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the :meth:`vjp` method).""" + """Define a formula for differentiating the operation with backward mode automatic differentiation (alias to the :meth:`vjp` method).""" dmu = args[0] updates, mu = ctx.saved_tensors b1 = ctx.b1 @@ -61,11 +64,11 @@ class NuOp(torch.autograd.Function): # pylint: disable=abstract-method @staticmethod def jvp(ctx: Any, *grad_inputs: Any) -> Any: - """Defines a formula for differentiating the operation with forward mode automatic differentiation.""" + """Define a formula for differentiating the operation with forward mode automatic differentiation.""" @staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: - """Performs the operation.""" + """Perform the operation.""" updates, nu, b2 = args new_nu = adam_op.forward_nu(updates, nu, b2) ctx.save_for_backward(updates, nu) @@ -75,7 +78,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: @staticmethod def backward(ctx: Any, *args: Any) -> Any: # pylint: disable-next=line-too-long - """Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the :meth:`vjp` function).""" + """Define a formula for differentiating the operation with backward mode automatic differentiation (alias to the :meth:`vjp` function).""" dnu = args[0] updates, nu = ctx.saved_tensors b2 = ctx.b2 @@ -87,11 +90,11 @@ class UpdatesOp(torch.autograd.Function): # pylint: disable=abstract-method @staticmethod def jvp(ctx: Any, *grad_inputs: Any) -> Any: - """Defines a formula for differentiating the operation with forward mode automatic differentiation.""" + """Define a formula for differentiating the operation with forward mode automatic differentiation.""" @staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: - """Performs the operation.""" + """Perform the operation.""" new_mu, new_nu, (b1, b2, eps, eps_root, count) = args new_updates = adam_op.forward_updates(new_mu, new_nu, b1, b2, eps, eps_root, count) ctx.save_for_backward(new_updates, new_mu, new_nu) @@ -101,11 +104,13 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: @staticmethod def backward(ctx: Any, *args: Any) -> Any: # pylint: disable-next=line-too-long - """Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the :meth:`vjp` function).""" + """Define a formula for differentiating the operation with backward mode automatic differentiation (alias to the :meth:`vjp` function).""" dupdates = args[0] updates, new_mu, new_nu = ctx.saved_tensors - b1, b2, _, _, count = ctx.others - result = adam_op.backward_updates(dupdates, updates, new_mu, new_nu, b1, b2, count) + b1, b2, _, eps_root, count = ctx.others + result = adam_op.backward_updates( + dupdates, updates, new_mu, new_nu, b1, b2, eps_root, count + ) return result[0], result[1], None # pylint: disable-next=too-many-arguments @@ -118,7 +123,7 @@ def __init__( eps_root: float = 0.0, inplace: bool = True, ) -> None: - """The :meth:`__init__` function.""" + """Initialize the Adam operator.""" self.b1 = b1 self.b2 = b2 self.eps = eps @@ -126,24 +131,44 @@ def __init__( self.inplace = inplace def __call__( - self, mu: torch.Tensor, nu: torch.Tensor, updates: Optional[torch.Tensor], count: int - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """The :meth:`__call__` function.""" + self, + mu: torch.Tensor, + nu: torch.Tensor, + updates: torch.Tensor | None, + count: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: + """Apply the Adam operator.""" if updates is None: return mu, nu, None - if updates.is_cuda: - current_device = torch.cuda.current_device() - torch.cuda.set_device(updates.device) - if self.inplace: - new_updates, new_mu, new_nu = adam_op.forward_( - updates, mu, nu, self.b1, self.b2, self.eps, self.eps_root, count - ) - else: - new_mu = self.MuOp.apply(updates, mu, self.b1) - new_nu = self.NuOp.apply(updates, nu, self.b2) - new_updates = self.UpdatesOp.apply( - new_mu, new_nu, (self.b1, self.b2, self.eps, self.eps_root, count) - ) - if updates.is_cuda: - torch.cuda.set_device(current_device) + device_context = ( + torch.cuda.device(torch.cuda.current_device()) + if updates.is_cuda + else contextlib.nullcontext() + ) + with device_context: # type: ignore[attr-defined] + if self.inplace: + new_updates, new_mu, new_nu = adam_op.forward_( + updates, + mu, + nu, + self.b1, + self.b2, + self.eps, + self.eps_root, + count, + ) + else: + new_mu = self.MuOp.apply(updates, mu, self.b1) + new_nu = self.NuOp.apply(updates, nu, self.b2) + new_updates = self.UpdatesOp.apply( + new_mu, + new_nu, + ( + self.b1, + self.b2, + self.eps, + self.eps_root, + count, + ), + ) return new_mu, new_nu, new_updates diff --git a/torchopt/alias/adam.py b/torchopt/alias/adam.py index 637b40c7..08654577 100644 --- a/torchopt/alias/adam.py +++ b/torchopt/alias/adam.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,10 +31,14 @@ # ============================================================================== """Preset :class:`GradientTransformation` for the Adam optimizer.""" -from typing import Tuple +from __future__ import annotations -from torchopt.alias.utils import flip_sign_and_add_weight_decay, scale_by_neg_lr -from torchopt.combine import chain_flat +from torchopt.alias.utils import ( + _get_use_chain_flat, + flip_sign_and_add_weight_decay, + scale_by_neg_lr, +) +from torchopt.combine import chain from torchopt.transform import scale_by_accelerated_adam, scale_by_adam from torchopt.typing import GradientTransformation, ScalarOrSchedule @@ -45,7 +49,7 @@ # pylint: disable-next=too-many-arguments def adam( lr: ScalarOrSchedule = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), + betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.0, *, @@ -54,7 +58,7 @@ def adam( maximize: bool = False, use_accelerated_op: bool = False, ) -> GradientTransformation: - """The functional Adam optimizer. + """Create a functional version of the Adam optimizer. Adam is an SGD variant with learning rate adaptation. The *learning rate* used for each weight is computed from estimates of first- and second-order moments of the gradients (using suitable @@ -64,26 +68,25 @@ def adam( - Kingma et al, 2014: https://arxiv.org/abs/1412.6980 Args: - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to avoid - dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - moment_requires_grad: (default: :data:`False`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + eps_root (float, optional): A small constant applied to denominator inside the square root + (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example + when computing (meta-)gradients through Adam. (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with + flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms. + (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of minimizing. + (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. + (default: :data:`False`) Returns: The corresponding :class:`GradientTransformation` instance. @@ -93,31 +96,40 @@ def adam( """ b1, b2 = betas # pylint: disable=invalid-name # pylint: disable=unneeded-not - if not (callable(lr) or 0.0 <= lr): + if not (callable(lr) or 0.0 <= lr): # pragma: no cover raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= eps: + if not 0.0 <= eps: # pragma: no cover raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= b1 < 1.0: + if not 0.0 <= b1 < 1.0: # pragma: no cover raise ValueError(f'Invalid beta parameter at index 0: {b1}') - if not 0.0 <= b2 < 1.0: + if not 0.0 <= b2 < 1.0: # pragma: no cover raise ValueError(f'Invalid beta parameter at index 1: {b2}') - if not 0.0 <= weight_decay: + if not 0.0 <= weight_decay: # pragma: no cover raise ValueError(f'Invalid weight_decay value: {weight_decay}') # pylint: enable=unneeded-not + chain_fn = chain + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay if use_accelerated_op: - adam_scaler = scale_by_accelerated_adam.flat # type: ignore[attr-defined] + adam_scaler_fn = scale_by_accelerated_adam else: - adam_scaler = scale_by_adam.flat # type: ignore[attr-defined] + adam_scaler_fn = scale_by_adam + scale_by_neg_lr_fn = scale_by_neg_lr - return chain_flat( - flip_sign_and_add_weight_decay(weight_decay=weight_decay, maximize=maximize), - adam_scaler( + if _get_use_chain_flat(): # default behavior + chain_fn = chain_fn.flat # type: ignore[attr-defined] + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined] + adam_scaler_fn = adam_scaler_fn.flat # type: ignore[attr-defined] + scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined] + + return chain_fn( + flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=maximize), + adam_scaler_fn( b1=b1, b2=b2, eps=eps, eps_root=eps_root, moment_requires_grad=moment_requires_grad, ), - scale_by_neg_lr(lr), + scale_by_neg_lr_fn(lr), ) diff --git a/torchopt/alias/adamw.py b/torchopt/alias/adamw.py index b088be60..21ef84ef 100644 --- a/torchopt/alias/adamw.py +++ b/torchopt/alias/adamw.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,31 +31,37 @@ # ============================================================================== """Preset :class:`GradientTransformation` for the AdamW optimizer.""" -from typing import Any, Callable, Optional, Tuple, Union +from __future__ import annotations -from torchopt.alias.utils import flip_sign_and_add_weight_decay, scale_by_neg_lr -from torchopt.combine import chain_flat +from typing import Callable + +from torchopt.alias.utils import ( + _get_use_chain_flat, + flip_sign_and_add_weight_decay, + scale_by_neg_lr, +) +from torchopt.combine import chain from torchopt.transform import add_decayed_weights, scale_by_accelerated_adam, scale_by_adam -from torchopt.typing import GradientTransformation, Params, ScalarOrSchedule +from torchopt.typing import GradientTransformation, OptState, Params, ScalarOrSchedule __all__ = ['adamw'] -# pylint: disable-next=too-many-arguments +# pylint: disable-next=too-many-arguments,too-many-locals def adamw( lr: ScalarOrSchedule = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), + betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 1e-2, *, eps_root: float = 0.0, - mask: Optional[Union[Any, Callable[[Params], Any]]] = None, + mask: OptState | Callable[[Params], OptState] | None = None, moment_requires_grad: bool = False, maximize: bool = False, use_accelerated_op: bool = False, ) -> GradientTransformation: - """Adam with weight decay regularization. + """Create a functional version of the Adam optimizer with weight decay regularization. AdamW uses weight decay to regularize learning towards small weights, as this leads to better generalization. In SGD you can also use L2 regularization @@ -66,35 +72,34 @@ def adamw( - Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101 Args: - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`1e-2`) - Strength of the weight decay regularization. Note that this weight decay is multiplied - with the learning rate. This is consistent with other frameworks such as PyTorch, but - different from (Loshchilov et al, 2019) where the weight decay is only multiplied with - the "schedule multiplier", but not the base learning rate. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to avoid - dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - mask: (default: :data:`None`) - A tree with same structure as (or a prefix of) the params PyTree, or a Callable that + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with other + frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight + decay is only multiplied with the "schedule multiplier", but not the base learning rate. + (default: :const:`1e-2`) + eps_root (float, optional): A small constant applied to denominator inside the square root + (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example + when computing (meta-)gradients through Adam. (default: :const:`0.0`) + mask (tree of Tensor, callable, or None, optional): + A tree with same structure as (or a prefix of) the params pytree, or a function that returns such a pytree given the params/updates. The leaves should be booleans, :data:`True` for leaves/subtrees you want to apply the weight decay to, and - :data:`False` for those you want to skip. Note that the Adam gradient - transformations are applied to all parameters. - moment_requires_grad: (default: :data:`False`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. + :data:`False` for those you want to skip. Note that the Adam gradient transformations + are applied to all parameters. (default: :data:`None`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with + flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms. + (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. + (default: :data:`False`) Returns: The corresponding :class:`GradientTransformation` instance. @@ -104,32 +109,43 @@ def adamw( """ b1, b2 = betas # pylint: disable=invalid-name # pylint: disable=unneeded-not - if not (callable(lr) or 0.0 <= lr): + if not (callable(lr) or 0.0 <= lr): # pragma: no cover raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= eps: + if not 0.0 <= eps: # pragma: no cover raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= b1 < 1.0: + if not 0.0 <= b1 < 1.0: # pragma: no cover raise ValueError(f'Invalid beta parameter at index 0: {b1}') - if not 0.0 <= b2 < 1.0: + if not 0.0 <= b2 < 1.0: # pragma: no cover raise ValueError(f'Invalid beta parameter at index 1: {b2}') - if not 0.0 <= weight_decay: + if not 0.0 <= weight_decay: # pragma: no cover raise ValueError(f'Invalid weight_decay value: {weight_decay}') # pylint: enable=unneeded-not + chain_fn = chain + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay if use_accelerated_op: - adam_scaler = scale_by_accelerated_adam.flat # type: ignore[attr-defined] + adam_scaler_fn = scale_by_accelerated_adam else: - adam_scaler = scale_by_adam.flat # type: ignore[attr-defined] + adam_scaler_fn = scale_by_adam + add_decayed_weights_fn = add_decayed_weights + scale_by_neg_lr_fn = scale_by_neg_lr + + if _get_use_chain_flat(): # default behavior + chain_fn = chain_fn.flat # type: ignore[attr-defined] + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined] + adam_scaler_fn = adam_scaler_fn.flat # type: ignore[attr-defined] + add_decayed_weights_fn = add_decayed_weights_fn.flat # type: ignore[attr-defined] + scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined] - return chain_flat( - flip_sign_and_add_weight_decay(weight_decay=0.0, maximize=maximize), - adam_scaler( + return chain_fn( + flip_sign_and_add_weight_decay_fn(weight_decay=0.0, maximize=maximize), + adam_scaler_fn( b1=b1, b2=b2, eps=eps, eps_root=eps_root, moment_requires_grad=moment_requires_grad, ), - add_decayed_weights.flat(weight_decay=weight_decay, mask=mask), # type: ignore[attr-defined] - scale_by_neg_lr(lr), + add_decayed_weights_fn(weight_decay=weight_decay, mask=mask), + scale_by_neg_lr_fn(lr), ) diff --git a/torchopt/alias/rmsprop.py b/torchopt/alias/rmsprop.py index 6d2ddeb3..f0eb92cd 100644 --- a/torchopt/alias/rmsprop.py +++ b/torchopt/alias/rmsprop.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,8 +31,12 @@ # ============================================================================== """Preset :class:`GradientTransformation` for the RMSProp optimizer.""" -from torchopt.alias.utils import flip_sign_and_add_weight_decay, scale_by_neg_lr -from torchopt.combine import chain_flat +from torchopt.alias.utils import ( + _get_use_chain_flat, + flip_sign_and_add_weight_decay, + scale_by_neg_lr, +) +from torchopt.combine import chain from torchopt.transform import scale_by_rms, scale_by_stddev, trace from torchopt.typing import GradientTransformation, ScalarOrSchedule @@ -53,7 +57,7 @@ def rmsprop( nesterov: bool = False, maximize: bool = False, ) -> GradientTransformation: - """The functional version of the RMSProp optimizer. + """Create a functional version of the RMSProp optimizer. RMSProp is an SGD variant with learning rate adaptation. The *learning rate* used for each weight is scaled by a suitable estimate of the magnitude of the gradients on previous steps. @@ -65,28 +69,25 @@ def rmsprop( - Graves, 2013: https://arxiv.org/abs/1308.0850 Args: - lr: (default: :const:`1e-2`) - This is a fixed global scaling factor. - alpha: (default: :const:`0.99`) - Smoothing constant, the decay used to track the magnitude of previous gradients. - eps: (default: :const:`1e-8`) - A small numerical constant to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - centered: (default: :data:`False`) - If :data:`True`, use the variance of the past gradients to rescale the latest - gradients. - initial_scale: (default: :data:`0.0`) - Initialization of accumulators tracking the magnitude of previous updates. PyTorch - uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When reproducing results from a - paper, verify the value used by the authors. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-2`) + alpha (float, optional): Smoothing constant, the decay used to track the magnitude of + previous gradients. (default: :const:`0.99`) + eps (float, optional): A small numerical constant to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + momentum (float, optional): The decay rate used by the momentum term. The momentum is not + used when it is set to :const:`0.0`. (default: :const:`0.0`) + centered (bool, optional): If :data:`True`, use the variance of the past gradients to + rescale the latest gradients. (default: :data:`False`) + initial_scale (float, optional): Initialization of accumulators tracking the magnitude of + previous updates. PyTorch uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When + reproducing results from a paper, verify the value used by the authors. + (default: :data:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) Returns: The corresponding :class:`GradientTransformation` instance. @@ -95,30 +96,41 @@ def rmsprop( The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. """ # pylint: disable=unneeded-not - if not (callable(lr) or 0.0 <= lr): + if not (callable(lr) or 0.0 <= lr): # pragma: no cover raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= alpha: + if not 0.0 <= alpha: # pragma: no cover raise ValueError(f'Invalid alpha value: {alpha}') - if not 0.0 <= eps: + if not 0.0 <= eps: # pragma: no cover raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= momentum: + if not 0.0 <= momentum: # pragma: no cover raise ValueError(f'Invalid momentum value: {momentum}') - if not 0.0 <= weight_decay: + if not 0.0 <= weight_decay: # pragma: no cover raise ValueError(f'Invalid weight_decay value: {weight_decay}') # pylint: enable=unneeded-not + chain_fn = chain + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay if centered: - rmsprop_scaler = scale_by_stddev.flat # type: ignore[attr-defined] + rmsprop_scaler_fn = scale_by_stddev else: - rmsprop_scaler = scale_by_rms.flat # type: ignore[attr-defined] + rmsprop_scaler_fn = scale_by_rms + trace_fn = trace + scale_by_neg_lr_fn = scale_by_neg_lr - return chain_flat( - flip_sign_and_add_weight_decay(weight_decay=weight_decay, maximize=maximize), - rmsprop_scaler( + if _get_use_chain_flat(): # default behavior + chain_fn = chain_fn.flat # type: ignore[attr-defined] + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined] + rmsprop_scaler_fn = rmsprop_scaler_fn.flat # type: ignore[attr-defined] + trace_fn = trace_fn.flat # type: ignore[attr-defined] + scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined] + + return chain_fn( + flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=maximize), + rmsprop_scaler_fn( alpha=alpha, eps=eps, initial_scale=initial_scale, ), - trace.flat(momentum=momentum, nesterov=nesterov), # type: ignore[attr-defined] - scale_by_neg_lr(lr), + trace_fn(momentum=momentum, nesterov=nesterov), + scale_by_neg_lr_fn(lr), ) diff --git a/torchopt/alias/sgd.py b/torchopt/alias/sgd.py index af87587f..7d86b538 100644 --- a/torchopt/alias/sgd.py +++ b/torchopt/alias/sgd.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,8 +31,12 @@ # ============================================================================== """Preset :class:`GradientTransformation` for the SGD optimizer.""" -from torchopt.alias.utils import flip_sign_and_add_weight_decay, scale_by_neg_lr -from torchopt.combine import chain_flat +from torchopt.alias.utils import ( + _get_use_chain_flat, + flip_sign_and_add_weight_decay, + scale_by_neg_lr, +) +from torchopt.combine import chain from torchopt.transform import trace from torchopt.typing import GradientTransformation, ScalarOrSchedule @@ -50,7 +54,7 @@ def sgd( moment_requires_grad: bool = False, maximize: bool = False, ) -> GradientTransformation: - """The functional version of the canonical Stochastic Gradient Descent optimizer. + """Create a functional version of the canonical Stochastic Gradient Descent optimizer. This implements stochastic gradient descent. It also includes support for momentum, and nesterov acceleration, as these are standard practice when using stochastic gradient descent to train @@ -60,21 +64,19 @@ def sgd( - Sutskever et al, 2013: http://proceedings.mlr.press/v28/sutskever13.pdf Args: - lr: This is a fixed global scaling factor. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - dampening: (default: :const:`0.0`) - Dampening for momentum. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - moment_requires_grad: (default: :data:`False`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. + lr (float or callable): This is a fixed global scaling factor or a learning rate + scheduler. + momentum (float, optional): The decay rate used by the momentum term. The momentum is not + used when it is set to :const:`0.0`. (default: :const:`0.0`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + dampening (float, optional): Dampening for momentum. (default: :const:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with + flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms. + (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) Returns: The corresponding :class:`GradientTransformation` instance. @@ -83,23 +85,34 @@ def sgd( The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. """ # pylint: disable=unneeded-not - if not (callable(lr) or 0.0 <= lr): + if not (callable(lr) or 0.0 <= lr): # pragma: no cover raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= momentum: + if not 0.0 <= momentum: # pragma: no cover raise ValueError(f'Invalid momentum value: {momentum}') - if not 0.0 <= weight_decay: + if not 0.0 <= weight_decay: # pragma: no cover raise ValueError(f'Invalid weight_decay value: {weight_decay}') - if nesterov and (momentum <= 0.0 or dampening != 0.0): + if nesterov and (momentum <= 0.0 or dampening != 0.0): # pragma: no cover raise ValueError('Nesterov momentum requires a momentum and zero dampening') # pylint: enable=unneeded-not - return chain_flat( - flip_sign_and_add_weight_decay(weight_decay=weight_decay, maximize=maximize), - trace.flat( # type: ignore[attr-defined] + chain_fn = chain + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay + trace_fn = trace + scale_by_neg_lr_fn = scale_by_neg_lr + + if _get_use_chain_flat(): # default behavior + chain_fn = chain_fn.flat # type: ignore[attr-defined] + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined] + trace_fn = trace_fn.flat # type: ignore[attr-defined] + scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined] + + return chain_fn( + flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=maximize), + trace_fn( momentum=momentum, dampening=dampening, nesterov=nesterov, moment_requires_grad=moment_requires_grad, ), - scale_by_neg_lr(lr), + scale_by_neg_lr_fn(lr), ) diff --git a/torchopt/alias/utils.py b/torchopt/alias/utils.py index 3ba3b6dc..b5088164 100644 --- a/torchopt/alias/utils.py +++ b/torchopt/alias/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,29 +13,90 @@ # limitations under the License. r"""Utilities for the aliases of preset :class:`GradientTransformation`\s for optimizers.""" +from __future__ import annotations + +import threading + +from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation, identity from torchopt.transform import scale, scale_by_schedule -from torchopt.transform.utils import tree_map_flat -from torchopt.typing import ScalarOrSchedule +from torchopt.transform.utils import tree_map_flat, tree_map_flat_ +from torchopt.typing import OptState, Params, ScalarOrSchedule, Updates __all__ = ['flip_sign_and_add_weight_decay', 'scale_by_neg_lr'] -def flip_sign_and_add_weight_decay(weight_decay: float = 0.0, maximize=False): - """Flips the sign of the updates and adds weight decay.""" - if not 0.0 <= weight_decay: # pylint: disable=unneeded-not +__USE_CHAIN_FLAT_LOCK = threading.Lock() +__USE_CHAIN_FLAT = True + + +def _set_use_chain_flat(use_chain_flat: bool) -> None: # only used for testing purposes + global __USE_CHAIN_FLAT # pylint: disable=global-statement + with __USE_CHAIN_FLAT_LOCK: + __USE_CHAIN_FLAT = use_chain_flat + + +def _get_use_chain_flat() -> bool: # only used for testing purposes + with __USE_CHAIN_FLAT_LOCK: + return __USE_CHAIN_FLAT + + +def flip_sign_and_add_weight_decay( + weight_decay: float = 0.0, maximize=False +) -> GradientTransformation: + """Flip the sign of the updates and adds weight decay.""" + return _flip_sign_and_add_weight_decay( + weight_decay=weight_decay, + maximize=maximize, + already_flattened=False, + ) + + +def _flip_sign_and_add_weight_decay_flat( + weight_decay: float = 0.0, maximize=False +) -> GradientTransformation: + """Flip the sign of the updates and adds weight decay.""" + return _flip_sign_and_add_weight_decay( + weight_decay=weight_decay, + maximize=maximize, + already_flattened=True, + ) + + +def _flip_sign_and_add_weight_decay( + weight_decay: float = 0.0, + maximize=False, + *, + already_flattened: bool = False, +) -> GradientTransformation: + """Flip the sign of the updates and adds weight decay.""" + # pylint: disable-next=unneeded-not + if not 0.0 <= weight_decay: # pragma: no cover raise ValueError(f'Invalid weight_decay value: {weight_decay}') if not maximize and weight_decay == 0.0: return identity() - def init_fn(params): # pylint: disable=unused-argument + if already_flattened: + tree_map = tree_map_flat + tree_map_ = tree_map_flat_ + else: + tree_map = pytree.tree_map # type: ignore[assignment] + tree_map_ = pytree.tree_map_ # type: ignore[assignment] + + def init_fn(params: Params) -> OptState: # pylint: disable=unused-argument return EmptyState() if not maximize: # gradient descent - def update_fn(updates, state, *, params=None, inplace=True): + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, + inplace: bool = True, + ) -> tuple[Updates, OptState]: assert params is not None, ( 'Parameters are required for weight decay. ' 'Call `update(updates, state, params=params)` instead.' @@ -48,35 +109,52 @@ def f(g, p): return g.add_(p, alpha=weight_decay) return g.add_(p.data, alpha=weight_decay) + updates = tree_map_(f, updates, params) + else: def f(g, p): return g.add(p, alpha=weight_decay) - updates = tree_map_flat(f, updates, params) + updates = tree_map(f, updates, params) + return updates, state else: # gradient ascent - if weight_decay == 0.0: - # pylint: disable-next=unused-argument - def update_fn(updates, state, *, params=None, inplace=True): + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: if inplace: def f(g): return g.neg_() + updates = tree_map_(f, updates) + else: def f(g): return g.neg() - updates = tree_map_flat(f, updates) + updates = tree_map(f, updates) + return updates, state else: - def update_fn(updates, state, *, params=None, inplace=True): + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, + inplace: bool = True, + ) -> tuple[Updates, OptState]: assert params is not None, ( 'Parameters are required for weight decay. ' 'Call `update(updates, state, params=params)` instead.' @@ -85,26 +163,39 @@ def update_fn(updates, state, *, params=None, inplace=True): if inplace: def f(g, p): - if g is not None: - if g.requires_grad: - return g.neg_().add_(p, alpha=weight_decay) - return g.neg_().add_(p.data, alpha=weight_decay) - return None + if g.requires_grad: + return g.neg_().add_(p, alpha=weight_decay) + return g.neg_().add_(p.data, alpha=weight_decay) + + updates = tree_map_(f, updates, params) else: def f(g, p): return g.neg().add_(p, alpha=weight_decay) - updates = tree_map_flat(f, updates, params) + updates = tree_map(f, updates, params) + return updates, state return GradientTransformation(init_fn, update_fn) -def scale_by_neg_lr(lr: ScalarOrSchedule): - """Scales the updates by the negative learning rate.""" - if not (callable(lr) or 0.0 <= lr): +flip_sign_and_add_weight_decay.flat = _flip_sign_and_add_weight_decay_flat # type: ignore[attr-defined] +flip_sign_and_add_weight_decay.impl = _flip_sign_and_add_weight_decay # type: ignore[attr-defined] + + +def scale_by_neg_lr(lr: ScalarOrSchedule) -> GradientTransformation: + """Scale the updates by the negative learning rate.""" + return _scale_by_neg_lr(lr=lr, already_flattened=False) + + +def _scale_by_neg_lr_flat(lr: ScalarOrSchedule) -> GradientTransformation: + return _scale_by_neg_lr(lr=lr, already_flattened=True) + + +def _scale_by_neg_lr(lr: ScalarOrSchedule, *, already_flattened=False) -> GradientTransformation: + if not (callable(lr) or 0.0 <= lr): # pragma: no cover raise ValueError(f'Invalid learning rate: {lr}') if callable(lr): @@ -112,5 +203,12 @@ def scale_by_neg_lr(lr: ScalarOrSchedule): def schedule_wrapper(count): return -lr(count) # type: ignore[operator] - return scale_by_schedule.flat(schedule_wrapper) # type: ignore[attr-defined] - return scale.flat(-lr) # type: ignore[attr-defined] + return scale_by_schedule.impl( # type: ignore[attr-defined] + schedule_wrapper, + already_flattened=already_flattened, + ) + return scale.impl(-lr, already_flattened=already_flattened) # type: ignore[attr-defined] + + +scale_by_neg_lr.flat = _scale_by_neg_lr_flat # type: ignore[attr-defined] +scale_by_neg_lr.impl = _scale_by_neg_lr # type: ignore[attr-defined] diff --git a/torchopt/base.py b/torchopt/base.py index 5706957e..b250c387 100644 --- a/torchopt/base.py +++ b/torchopt/base.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,13 +31,15 @@ # ============================================================================== """The base classes for gradient transformation.""" +from __future__ import annotations + import itertools from abc import abstractmethod -from typing import TYPE_CHECKING, Callable, NamedTuple, Optional, Tuple +from typing import TYPE_CHECKING, Callable, NamedTuple from typing_extensions import Protocol # Python 3.8+ -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from torchopt.typing import OptState, Params, Updates @@ -67,12 +69,11 @@ class TransformInitFn(Protocol): # pylint: disable=too-few-public-methods """ @abstractmethod - def __call__(self, params: 'Params') -> 'OptState': - """The ``init`` function. + def __call__(self, params: Params) -> OptState: + """Initialize the gradient transformation state. Args: - params: - The initial value of the parameters. + params (tree of Tensor): The initial value of the parameters. Returns: The initial state of the gradient transformation. @@ -93,21 +94,21 @@ class TransformUpdateFn(Protocol): # pylint: disable=too-few-public-methods @abstractmethod def __call__( self, - updates: 'Updates', - state: 'OptState', + updates: Updates, + state: OptState, *, - params: Optional['Params'] = None, + params: Params | None = None, inplace: bool = True, - ) -> Tuple['Updates', 'OptState']: - """The ``update`` function. + ) -> tuple[Updates, OptState]: + """Transform the updates and state. Args: - updates: A tree of candidate updates. - state: The state of the gradient transformation. - params: (optional) - The current value of the parameters. - inplace: (optional) - If :data:`True`, modify updates and state using inplace operations. + updates (tree of Tensor): A tree of candidate updates. + state (tree of Tensor): The state of the gradient transformation. + params (tree of Tensor or None, optional): The current value of the parameters. + (default: :data:`None`) + inplace (bool, optional): If :data:`True`, modify updates and state using inplace + operations. (default: :data:`True`) Returns: The transformed ``updates``, and the updated ``state``. @@ -134,9 +135,9 @@ class GradientTransformation(NamedTuple): optimizer state. update: A pure function which takes as input a pytree of updates (with the same tree structure - as the original params ``pytree`` passed to :attr:`init`), the previous optimizer state - (which may have been initialized using the :attr:`init` function), and optionally the - ``inplace`` flag. The :attr:`update` function then returns the computed gradient + as the original params ``pytree`` passed to ``init``), the previous optimizer state + (which may have been initialized using the ``init`` function), and optionally the + ``inplace`` flag. The ``update`` function then returns the computed gradient updates, and a updates optimizer state. If the ``inplace`` flag is :data:`True`, the output results are the same instance as the input. """ @@ -145,7 +146,7 @@ class GradientTransformation(NamedTuple): update: TransformUpdateFn # pylint: disable-next=redefined-builtin - def chain(self, next: 'GradientTransformation') -> 'ChainedGradientTransformation': + def chain(self, next: GradientTransformation) -> ChainedGradientTransformation: """Chain two gradient transformations together.""" return ChainedGradientTransformation(self, next) @@ -157,10 +158,10 @@ class ChainedGradientTransformation(GradientTransformation): gradient transformations. """ - transformations: Tuple[GradientTransformation, ...] + transformations: tuple[GradientTransformation, ...] - def __new__(cls, *transformations: GradientTransformation) -> 'ChainedGradientTransformation': - """Creates a new chained gradient transformation.""" + def __new__(cls, *transformations: GradientTransformation) -> ChainedGradientTransformation: + """Create a new chained gradient transformation.""" transformations = tuple( itertools.chain.from_iterable( t.transformations @@ -170,12 +171,21 @@ def __new__(cls, *transformations: GradientTransformation) -> 'ChainedGradientTr ) ) + if len(transformations) == 0: + transformations = (IdentityGradientTransformation(),) + init_fns, update_fns = tuple(zip(*transformations)) - def init_fn(params): + def init_fn(params: Params) -> OptState: return tuple(fn(params) for fn in init_fns) - def update_fn(updates, state, *, params=None, inplace=True): + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, + inplace: bool = True, + ) -> tuple[Updates, OptState]: if len(update_fns) != len(state): raise ValueError( 'The number of updates and states has to be the same in chain! Make sure you' @@ -191,16 +201,15 @@ def update_fn(updates, state, *, params=None, inplace=True): instance.transformations = transformations return instance - def __str__(self) -> str: - """Returns a string representation of the chained gradient transformation.""" - return '{}(\n {}\n)'.format( - self.__class__.__name__, ',\n '.join(repr(t) for t in self.transformations) + def __repr__(self) -> str: + """Return a string representation of the chained gradient transformation.""" + return '{}(\n {},\n)'.format( + self.__class__.__name__, + ',\n '.join(repr(t) for t in self.transformations), ) - __repr__ = __str__ - def __eq__(self, other: object) -> bool: - """Returns whether two chained gradient transformations are equal.""" + """Return whether two chained gradient transformations are equal.""" if isinstance(other, ChainedGradientTransformation): return self.transformations == other.transformations if isinstance(other, GradientTransformation): @@ -208,19 +217,19 @@ def __eq__(self, other: object) -> bool: return False def __hash__(self) -> int: - """Returns the hash of the chained gradient transformation.""" + """Return the hash of the chained gradient transformation.""" return hash(self.transformations) - def __getstate__(self) -> Tuple[GradientTransformation, ...]: - """Returns the state of the chained gradient transformation for serialization.""" + def __getstate__(self) -> tuple[GradientTransformation, ...]: + """Return the state of the chained gradient transformation for serialization.""" return self.transformations - def __setstate__(self, state: Tuple[GradientTransformation, ...]) -> None: - """Sets the state of the chained gradient transformation from serialization.""" + def __setstate__(self, state: tuple[GradientTransformation, ...]) -> None: + """Set the state of the chained gradient transformation from serialization.""" self.transformations = state - def __reduce__(self) -> Tuple[Callable, Tuple[Tuple[GradientTransformation, ...]]]: - """Serialization support for chained gradient transformation.""" + def __reduce__(self) -> tuple[Callable, tuple[tuple[GradientTransformation, ...]]]: + """Serialize the chained gradient transformation.""" return ChainedGradientTransformation, (self.transformations,) @@ -232,19 +241,19 @@ def __new__(cls): return super().__new__(cls, init=cls.init_fn, update=cls.update_fn) @staticmethod - def init_fn(params: 'Params') -> 'OptState': # pylint: disable=unused-argument - """Returns empty state.""" + def init_fn(params: Params) -> OptState: # pylint: disable=unused-argument + """Return empty state.""" return EmptyState() @staticmethod def update_fn( - updates: 'Updates', - state: 'OptState', + updates: Updates, + state: OptState, *, - params: Optional['Params'] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, # pylint: disable=unused-argument - ) -> Tuple['Updates', 'OptState']: - """Returns updates unchanged.""" + ) -> tuple[Updates, OptState]: + """Return updates unchanged.""" return updates, state diff --git a/torchopt/clip.py b/torchopt/clip.py index 29c26032..b2aafb48 100644 --- a/torchopt/clip.py +++ b/torchopt/clip.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,12 +17,13 @@ # ============================================================================== """Utilities for gradient clipping.""" -from typing import Union +from __future__ import annotations import torch from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation +from torchopt.typing import OptState, Params, Updates __all__ = ['clip_grad_norm'] @@ -32,27 +33,34 @@ def clip_grad_norm( - max_norm: Union[float, int], - norm_type: Union[float, int] = 2.0, + max_norm: float | int, + norm_type: float | int = 2.0, error_if_nonfinite: bool = False, ) -> GradientTransformation: - """Clips gradient norm of an iterable of parameters. + """Clip gradient norm of an iterable of parameters. Args: max_norm (float or int): The maximum absolute value for each element in the update. - norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for - infinity norm. - error_if_nonfinite (bool): if :data:`True`, an error is thrown if the total norm of the - gradients from :attr:`updates` is ``nan``, ``inf``, or ``-inf``. + norm_type (float or int, optional): Type of the used p-norm. Can be ``'inf'`` for infinity + norm. (default: :const:`2.0`) + error_if_nonfinite (bool, optional): If :data:`True`, an error is thrown if the total norm + of the gradients from ``updates`` is ``nan``, ``inf``, or ``-inf``. + (default: :data:`False`) Returns: An ``(init_fn, update_fn)`` tuple. """ - def init_fn(params): # pylint: disable=unused-argument + def init_fn(params: Params) -> OptState: # pylint: disable=unused-argument return ClipState() - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: available_updates = pytree.tree_leaves(updates) if len(available_updates) == 0: return updates, state diff --git a/torchopt/combine.py b/torchopt/combine.py index 26f66214..0f1ed8ec 100644 --- a/torchopt/combine.py +++ b/torchopt/combine.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,24 +31,26 @@ # ============================================================================== """Utilities to define a chained transformation.""" +from __future__ import annotations + from torchopt import pytree from torchopt.base import ChainedGradientTransformation, GradientTransformation, identity -from torchopt.typing import Updates +from torchopt.typing import OptState, Params, Updates __all__ = ['chain', 'chain_flat'] def chain(*transformations: GradientTransformation) -> GradientTransformation: - """Applies a list of chainable update transformations. + """Apply a list of chainable update transformations. Given a sequence of chainable transforms, :func:`chain` returns an :func:`init_fn` that constructs a ``state`` by concatenating the states of the individual transforms, and returns an :func:`update_fn` which chains the update transformations feeding the appropriate state to each. Args: - *transformations: - A sequence of chainable ``(init_fn, update_fn)`` tuples. + *transformations (iterable of GradientTransformation): A sequence of chainable + ``(init_fn, update_fn)`` tuples. Returns: A single ``(init_fn, update_fn)`` tuple. @@ -61,11 +63,11 @@ def chain(*transformations: GradientTransformation) -> GradientTransformation: def chain_flat(*transformations: GradientTransformation) -> GradientTransformation: - """Wraps around the inner transformations that manipulates the flattened tree structure (:class:``list``). + """Wrap around the inner transformations that manipulate the flattened tree structure (:class:``list``). Args: - *transformations: - A sequence of chainable ``(init_fn, update_fn)`` tuples. + *transformations (iterable of GradientTransformation): A sequence of chainable + ``(init_fn, update_fn)`` tuples. Returns: A single ``(init_fn, update_fn)`` tuple. @@ -77,10 +79,16 @@ def chain_flat(*transformations: GradientTransformation) -> GradientTransformati else: inner = chain(*transformations) - def init_fn(params): + def init_fn(params: Params) -> OptState: return inner.init(pytree.tree_leaves(params, none_is_leaf=True)) - def update_fn(updates, state, *, params=None, inplace=True): + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, + inplace: bool = True, + ) -> tuple[Updates, OptState]: flat_updates, treespec = pytree.tree_flatten(updates, none_is_leaf=True) if params is not None: flat_params = pytree.tree_leaves(params, none_is_leaf=True) diff --git a/torchopt/diff/__init__.py b/torchopt/diff/__init__.py index 45674fcf..984841ed 100644 --- a/torchopt/diff/__init__.py +++ b/torchopt/diff/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,3 +16,4 @@ from torchopt.diff import implicit, zero_order from torchopt.diff.implicit import ImplicitMetaGradientModule +from torchopt.diff.zero_order import ZeroOrderGradientModule diff --git a/torchopt/diff/implicit/decorator.py b/torchopt/diff/implicit/decorator.py index aaeda594..a5908963 100644 --- a/torchopt/diff/implicit/decorator.py +++ b/torchopt/diff/implicit/decorator.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,9 +16,11 @@ # pylint: disable=invalid-name +from __future__ import annotations + import functools import inspect -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, Sequence, Tuple import functorch import torch @@ -47,7 +49,7 @@ def __init__( optimality_fn: Callable[..., TensorOrTensors], solution: TensorOrTensors, output_is_tensor: bool, - argnums: Tuple[int, ...], + argnums: tuple[int, ...], *args: Any, ) -> None: self.optimality_fn = optimality_fn @@ -88,10 +90,9 @@ def _root_vjp( args: Args, grad_outputs: TupleOfTensors, output_is_tensor: bool, - argnums: Tuple[int, ...], + argnums: tuple[int, ...], solve: Callable[..., TensorOrTensors] = linear_solve.solve_normal_cg(), ) -> TupleOfOptionalTensors: - if output_is_tensor: def optimality_cond(solution: TupleOfTensors) -> TensorOrTensors: @@ -146,14 +147,14 @@ def matvec(u: TupleOfTensors) -> TupleOfTensors: return tuple(true_output) -def _extract_kwargs(kwarg_keys: Sequence[str], flat_args: Tuple[Any, ...]) -> Tuple[Args, KwArgs]: +def _extract_kwargs(kwarg_keys: Sequence[str], flat_args: tuple[Any, ...]) -> tuple[Args, KwArgs]: nargs = len(flat_args) - len(kwarg_keys) args, kwarg_vals = flat_args[:nargs], flat_args[nargs:] kwargs = dict(zip(kwarg_keys, kwarg_vals)) return args, kwargs -def _signature_bind(signature: inspect.Signature, *args: Any, **kwargs: Any) -> Tuple[Args, KwArgs]: +def _signature_bind(signature: inspect.Signature, *args: Any, **kwargs: Any) -> tuple[Args, KwArgs]: bound = signature.bind(*args, **kwargs) bound.apply_defaults() return bound.args, bound.kwargs @@ -161,7 +162,7 @@ def _signature_bind(signature: inspect.Signature, *args: Any, **kwargs: Any) -> def _signature_bind_and_match( signature: inspect.Signature, *args: Any, **kwargs: Any -) -> Tuple[Args, KwArgs, Callable[[Args], Tuple[Args, KwArgs]]]: +) -> tuple[Args, KwArgs, Callable[[Args], tuple[Args, KwArgs]]]: # We want to bind *args and **kwargs based on the provided signature, but also to associate the # resulting positional arguments back. To achieve this, we lift arguments to a triple: # @@ -194,13 +195,13 @@ def map_args_back(out_args): def _split_tensor_and_others( - mixed_tuple: Tuple[Any, ...], -) -> Tuple[pytree.PyTreeSpec, Tuple[bool, ...], TupleOfTensors, Tuple[Any, ...]]: - flattened: List[Any] + mixed_tuple: tuple[Any, ...], +) -> tuple[pytree.PyTreeSpec, tuple[bool, ...], TupleOfTensors, tuple[Any, ...]]: + flattened: list[Any] flattened, treespec = pytree.tree_flatten(mixed_tuple, none_is_leaf=True) # type: ignore[arg-type] tensors: ListOfTensors = [] - non_tensors: List[Any] = [] - is_tensor_mask: List[bool] = [] + non_tensors: list[Any] = [] + is_tensor_mask: list[bool] = [] for item in flattened: is_tensor = isinstance(item, torch.Tensor) is_tensor_mask.append(is_tensor) @@ -213,10 +214,10 @@ def _split_tensor_and_others( def _merge_tensor_and_others( treespec: pytree.PyTreeSpec, - is_tensor_mask: Tuple[bool, ...], + is_tensor_mask: tuple[bool, ...], tensors: TupleOfTensors, - non_tensors: Tuple[Any, ...], -) -> Tuple[Any, ...]: + non_tensors: tuple[Any, ...], +) -> tuple[Any, ...]: tensor_counter = 0 non_tensor_counter = 0 results = [] @@ -232,13 +233,13 @@ def _merge_tensor_and_others( # pylint: disable-next=too-many-arguments,too-many-statements def _custom_root( - solver_fn: Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]], + solver_fn: Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]], optimality_fn: Callable[..., TensorOrTensors], solve: Callable[..., TensorOrTensors], - argnums: Tuple[int, ...], + argnums: tuple[int, ...], has_aux: bool, - reference_signature: Optional[Union[inspect.Signature, Callable]] = None, -) -> Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]]: + reference_signature: inspect.Signature | Callable | None = None, +) -> Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]]: solver_fn_signature = inspect.signature(solver_fn) if reference_signature is None: @@ -250,16 +251,16 @@ def _custom_root( reference_signature = inspect.signature(fn) def make_custom_vjp_solver_fn( - solver_fn: Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]], + solver_fn: Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]], kwarg_keys: Sequence[str], - args_signs: Tuple[Tuple[int, int, Optional[Union[Type[tuple], Type[list]]]], ...], - ) -> Type[Function]: + args_signs: tuple[tuple[int, int, type[tuple] | type[list] | None], ...], + ) -> type[Function]: # pylint: disable-next=missing-class-docstring,abstract-method class ImplicitMetaGradient(Function): @staticmethod def forward( # type: ignore[override] # pylint: disable=arguments-differ ctx: Any, *flat_args: Any - ) -> Tuple[Any, ...]: + ) -> tuple[Any, ...]: output, aux, output_is_tensor = None, None, False args = [] @@ -289,6 +290,7 @@ def forward( # type: ignore[override] # pylint: disable=arguments-differ f'solver_fn should be a torch.Tensor or a tuple of torch.Tensor. ' f'Got {output}' ) + output = tuple(t.data for t in output) ( args_treespec, @@ -361,12 +363,12 @@ def backward( # pylint: disable=too-many-locals @functools.wraps(solver_fn) def wrapped_solver_fn( *args: Any, **kwargs: Any - ) -> Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]: + ) -> TensorOrTensors | tuple[TensorOrTensors, Any]: args, kwargs = _signature_bind(solver_fn_signature, *args, **kwargs) keys, vals = list(kwargs.keys()), list(kwargs.values()) - args_signs: List[Tuple[int, int, Optional[Union[Type[tuple], Type[list]]]]] = [] - flat_args: List[Any] = [] + args_signs: list[tuple[int, int, type[tuple] | type[list] | None]] = [] + flat_args: list[Any] = [] args_offset = 0 for idx, arg in enumerate(args): if idx in argnums: @@ -410,14 +412,14 @@ def wrapped_solver_fn( def custom_root( optimality_fn: Callable[..., TensorOrTensors], - argnums: Union[int, Tuple[int, ...]], + argnums: int | tuple[int, ...], has_aux: bool = False, solve: Callable[..., TensorOrTensors] = linear_solve.solve_normal_cg(), ) -> Callable[ - [Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]]], - Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]], + [Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]]], + Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]], ]: - """Decorator for adding implicit differentiation to a root solver. + """Return a decorator for adding implicit differentiation to a root solver. This wrapper should be used as a decorator: @@ -442,18 +444,17 @@ def solver_fn(params, arg1, arg2, ...): **In best practice, the ``optimality_fn`` should have the same signature as ``solver_fn``.** Args: - optimality_fn: (callable) - An equation function, ``optimality_fn(params, *args)``. The invariant is - ``optimality_fn(solution, *args) == 0`` at the solution / root of ``solution``. - argnums: (int or tuple of ints) - Specifies arguments to compute gradients with respect to. The ``argnums`` can be an - integer or a tuple of integers, which respect to the zero-based indices of the arguments - of the ``solver_fn(params, *args)`` function. The argument ``params`` is included - for the counting, while it is indexed as ``argnums=0``. - has_aux: (default: :data:`False`) - Whether the decorated solver function returns auxiliary data. - solve: (callable, optional, default: :func:`linear_solve.solve_normal_cg`) - a linear solver of the form ``solve(matvec, b)``. + optimality_fn (callable): An equation function, ``optimality_fn(params, *args)``. The + invariant is ``optimality_fn(solution, *args) == 0`` at the solution / root of + ``solution``. + argnums (int or tuple of int): Specifies arguments to compute gradients with respect to. The + ``argnums`` can be an integer or a tuple of integers, which respect to the zero-based + indices of the arguments of the ``solver_fn(params, *args)`` function. The argument + ``params`` is included for the counting, while it is indexed as ``argnums=0``. + has_aux (bool, optional): Whether the decorated solver function returns auxiliary data. + (default: :data:`False`) + solve (callable, optional): A linear solver of the form ``solve(matvec, b)``. + (default: :func:`linear_solve.solve_normal_cg`) Returns: A solver function decorator, i.e., ``custom_root(optimality_fn)(solver_fn)``. diff --git a/torchopt/diff/implicit/nn/__init__.py b/torchopt/diff/implicit/nn/__init__.py index 95a2ea85..5bc7aa8d 100644 --- a/torchopt/diff/implicit/nn/__init__.py +++ b/torchopt/diff/implicit/nn/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,9 +14,10 @@ # ============================================================================== """The base class for differentiable implicit meta-gradient models.""" -# Preload to resolve circular references -import torchopt.nn.module # pylint: disable=unused-import +import torchopt.nn.module # preload to resolve circular references from torchopt.diff.implicit.nn.module import ImplicitMetaGradientModule __all__ = ['ImplicitMetaGradientModule'] + +del torchopt diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py index ed27b14c..bbae37c9 100644 --- a/torchopt/diff/implicit/nn/module.py +++ b/torchopt/diff/implicit/nn/module.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,158 +16,139 @@ # pylint: disable=redefined-builtin -import contextlib +from __future__ import annotations + +import abc import functools import itertools -from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, Type +from typing import Any, Iterable import functorch import torch -from torchopt import pytree from torchopt.diff.implicit.decorator import custom_root from torchopt.nn.module import MetaGradientModule -from torchopt.typing import LinearSolver, TensorTree, TupleOfTensors -from torchopt.utils import extract_module_containers +from torchopt.nn.stateless import reparametrize, swap_state +from torchopt.typing import LinearSolver, TupleOfTensors __all__ = ['ImplicitMetaGradientModule'] -def update_containers( - dst_containers: Iterable[Dict[str, Optional[torch.Tensor]]], - src_containers: Iterable[Dict[str, Optional[torch.Tensor]]], -) -> None: - """Update the tensor containers in ``dst_containers`` with the ones in ``src_containers``.""" - for src_container, dst_container in zip(src_containers, dst_containers): - dst_container.update(src_container) - - -@contextlib.contextmanager -def container_context( - orig_containers: Iterable[Dict[str, Optional[torch.Tensor]]], - args_containers: Iterable[Dict[str, Optional[torch.Tensor]]], -) -> Generator[None, None, None]: - # pylint: disable-next=line-too-long - """A context manager that temporarily updates the containers in ``orig_containers`` with the ones in ``args_containers``.""" - if not isinstance(orig_containers, (list, tuple)): - orig_containers = list(orig_containers) - orig_containers_backups = [container.copy() for container in orig_containers] - try: - update_containers(orig_containers, args_containers) - yield - finally: - update_containers(orig_containers, orig_containers_backups) +def _stateless_objective_fn( + __flat_params: TupleOfTensors, + __flat_meta_params: TupleOfTensors, + __params_names: Iterable[str], + __meta_params_names: Iterable[str], + self: ImplicitMetaGradientModule, + *input, + **kwargs, +) -> torch.Tensor: + with reparametrize( + self, + itertools.chain( + zip(__params_names, __flat_params), + zip(__meta_params_names, __flat_meta_params), + ), + ): + return self.objective(*input, **kwargs) + + +def _stateless_optimality_fn( + __flat_params: TupleOfTensors, + __flat_meta_params: TupleOfTensors, + __params_names: Iterable[str], + __meta_params_names: Iterable[str], + self: ImplicitMetaGradientModule, + *input, + **kwargs, +) -> TupleOfTensors: + with reparametrize( + self, + itertools.chain( + zip(__params_names, __flat_params), + zip(__meta_params_names, __flat_meta_params), + ), + ): + return self.optimality(*input, **kwargs) def make_optimality_from_objective( - objective: Callable[..., torch.Tensor] -) -> Callable[..., TupleOfTensors]: - """Make a function that computes the optimality function of the objective function.""" - - def optimality(self: 'ImplicitMetaGradientModule', *input, **kwargs) -> TupleOfTensors: - params_containers = extract_module_containers(self, with_buffers=False)[0] - flat_params: TupleOfTensors - # pylint: disable-next=line-too-long - flat_params, params_containers_treespec = pytree.tree_flatten_as_tuple(params_containers) # type: ignore[arg-type] - - def objective_fn(__flat_params: TupleOfTensors, *input, **kwargs) -> torch.Tensor: - flat_grad_tracking_params = __flat_params - grad_tracking_params_containers: Tuple[ - Dict[str, Optional[torch.Tensor]], ... - ] = pytree.tree_unflatten( # type: ignore[assignment] - params_containers_treespec, flat_grad_tracking_params - ) - - with container_context(params_containers, grad_tracking_params_containers): - return objective(self, *input, **kwargs) - - objective_grad_fn = functorch.grad(objective_fn, argnums=0) - flat_grads = objective_grad_fn(flat_params, *input, **kwargs) + cls: type[ImplicitMetaGradientModule], +) -> type[ImplicitMetaGradientModule]: + """Derives the optimality function of the objective function.""" + if ( + getattr(cls, 'objective', ImplicitMetaGradientModule.objective) + is ImplicitMetaGradientModule.objective + ): + raise TypeError('The objective function is not defined.') + + def optimality(self: ImplicitMetaGradientModule, *input, **kwargs) -> TupleOfTensors: + params_names, flat_params = tuple(zip(*self.named_parameters())) + meta_params_names, flat_meta_params = tuple(zip(*self.named_meta_parameters())) + + objective_grad_fn = functorch.grad(_stateless_objective_fn, argnums=0) + flat_grads = objective_grad_fn( + flat_params, + flat_meta_params, + params_names, + meta_params_names, + self, + *input, + **kwargs, + ) return flat_grads - return optimality + cls.optimality = optimality # type: ignore[assignment] + return cls def enable_implicit_gradients( - cls: Type['ImplicitMetaGradientModule'], -) -> Type['ImplicitMetaGradientModule']: - """Enables implicit gradients for the :func:`solve` method.""" + cls: type[ImplicitMetaGradientModule], +) -> type[ImplicitMetaGradientModule]: + """Enable implicit gradients for the :func:`solve` method.""" cls_solve = cls.solve if getattr(cls_solve, '__implicit_gradients_enabled__', False): raise TypeError('Implicit gradients are already enabled for the `solve` method.') if cls.linear_solve is not None: - solve_kwargs = dict(solve=cls.linear_solve) + solve_kwargs = {'solve': cls.linear_solve} else: solve_kwargs = {} - @functools.wraps(cls_solve) - def wrapped( # pylint: disable=too-many-locals - self: 'ImplicitMetaGradientModule', *input, **kwargs - ) -> Any: + @custom_root(_stateless_optimality_fn, argnums=1, has_aux=True, **solve_kwargs) + def stateless_solver_fn( + # pylint: disable=unused-argument + __flat_params: TupleOfTensors, + __flat_meta_params: TupleOfTensors, + __params_names: Iterable[str], + __meta_params_names: Iterable[str], + # pylint: enable=unused-argument + self: ImplicitMetaGradientModule, + *input, + **kwargs, + ) -> tuple[TupleOfTensors, Any]: """Solve the optimization problem.""" - params_containers = extract_module_containers(self, with_buffers=False)[0] - meta_params_containers = [self._meta_parameters] # pylint: disable=protected-access - for meta_module in self.meta_children(): - meta_params_containers.extend( - extract_module_containers(meta_module, with_buffers=False)[0] - ) - meta_params_containers = tuple(meta_params_containers) + output = cls_solve(self, *input, **kwargs) + flat_optimal_params = tuple(p.detach_() for p in self.parameters()) + return flat_optimal_params, output - flat_params: TupleOfTensors - flat_meta_params: TupleOfTensors - flat_params, params_containers_treespec = pytree.tree_flatten_as_tuple( - params_containers # type: ignore[arg-type] - ) - flat_meta_params, meta_params_containers_treespec = pytree.tree_flatten_as_tuple( - meta_params_containers # type: ignore[arg-type] - ) - - def optimality_fn( - __flat_params: TupleOfTensors, - __flat_meta_params: TupleOfTensors, - *input, - **kwargs, - ) -> TupleOfTensors: - flat_grad_tracking_params = __flat_params - grad_tracking_params_containers: Tuple[ - Dict[str, Optional[torch.Tensor]], ... - ] = pytree.tree_unflatten( # type: ignore[assignment] - params_containers_treespec, flat_grad_tracking_params - ) - flat_grad_tracking_meta_params = __flat_meta_params - grad_tracking_meta_params_containers: Tuple[ - Dict[str, Optional[torch.Tensor]], ... - ] = pytree.tree_unflatten( # type: ignore[assignment] - meta_params_containers_treespec, flat_grad_tracking_meta_params - ) - - with container_context( - itertools.chain( - params_containers, - meta_params_containers, - ), - itertools.chain( - grad_tracking_params_containers, - grad_tracking_meta_params_containers, - ), - ): - return self.optimality(*input, **kwargs) - - @custom_root(optimality_fn, argnums=1, has_aux=True, **solve_kwargs) - def solver_fn( - __flat_params: TupleOfTensors, # pylint: disable=unused-argument - __flat_meta_params: TupleOfTensors, # pylint: disable=unused-argument + @functools.wraps(cls_solve) + def wrapped(self: ImplicitMetaGradientModule, *input, **kwargs) -> Any: + """Solve the optimization problem.""" + params_names, flat_params = tuple(zip(*self.named_parameters())) + meta_params_names, flat_meta_params = tuple(zip(*self.named_meta_parameters())) + + flat_optimal_params, output = stateless_solver_fn( + flat_params, + flat_meta_params, + params_names, + meta_params_names, + self, *input, **kwargs, - ) -> Tuple[TupleOfTensors, Any]: - output = cls_solve(self, *input, **kwargs) - flat_optimal_params: TupleOfTensors = tuple(pytree.tree_leaves(params_containers)) # type: ignore[arg-type] - return flat_optimal_params, output - - # pylint: disable-next=unused-variable - flat_optimal_params, output = solver_fn(flat_params, flat_meta_params, *input, **kwargs) + ) + swap_state(self, zip(params_names, flat_optimal_params)) return output wrapped.__implicit_gradients_enabled__ = True # type: ignore[attr-defined] @@ -180,10 +161,10 @@ class ImplicitMetaGradientModule(MetaGradientModule): _custom_optimality: bool _custom_objective: bool - linear_solve: Optional[LinearSolver] + linear_solve: LinearSolver | None - def __init_subclass__(cls, linear_solve: Optional[LinearSolver] = None) -> None: - """Validates and initializes the subclass.""" + def __init_subclass__(cls, linear_solve: LinearSolver | None = None) -> None: + """Validate and initialize the subclass.""" super().__init_subclass__() cls.linear_solve = linear_solve @@ -211,15 +192,15 @@ def __init_subclass__(cls, linear_solve: Optional[LinearSolver] = None) -> None: if not callable(objective): raise TypeError('method objective() must be callable.') - cls.optimality = make_optimality_from_objective(objective) # type: ignore[assignment] + make_optimality_from_objective(cls) enable_implicit_gradients(cls) + @abc.abstractmethod def solve(self, *input, **kwargs) -> Any: - """Solves the inner optimization problem. + """Solve the inner optimization problem. .. warning:: - For gradient-based optimization methods, the parameter inputs should be explicitly specified in the :func:`torch.autograd.backward` function as argument ``inputs``. Otherwise, if not provided, the gradient is accumulated into all the leaf Tensors @@ -243,8 +224,8 @@ def solve(self, batch, labels): """ raise NotImplementedError # update parameters - def optimality(self, *input, **kwargs) -> TensorTree: - r"""Computes the optimality residual. + def optimality(self, *input, **kwargs) -> TupleOfTensors: + r"""Compute the optimality residual. This method stands for the optimality residual to the optimal parameters after solving the inner optimization problem (:meth:`solve`), i.e.: @@ -280,13 +261,14 @@ def optimality(self, *input, **kwargs) -> TensorTree: :math:`\boldsymbol{\theta}` is the joint vector of the meta-parameters. Returns: - A tree of tensors, the optimality residual to the optimal parameters after solving the - inner optimization problem. + A tuple of tensors, the optimality residual to the optimal parameters after solving the + inner optimization problem. The returned tensors should correspond to the outputs of + `tuple(self.parameters())`. """ # pylint: disable=line-too-long raise NotImplementedError def objective(self, *input, **kwargs) -> torch.Tensor: - """Computes the objective function value. + """Compute the objective function value. This method is used to calculate the :meth:`optimality` if it is not implemented. Otherwise, this method is optional. diff --git a/torchopt/diff/zero_order/__init__.py b/torchopt/diff/zero_order/__init__.py index a76dcb9a..5b85d03d 100644 --- a/torchopt/diff/zero_order/__init__.py +++ b/torchopt/diff/zero_order/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,10 +17,12 @@ import sys as _sys from types import ModuleType as _ModuleType +from torchopt.diff.zero_order import nn from torchopt.diff.zero_order.decorator import zero_order +from torchopt.diff.zero_order.nn import ZeroOrderGradientModule -__all__ = ['zero_order'] +__all__ = ['zero_order', 'ZeroOrderGradientModule'] class _CallableModule(_ModuleType): # pylint: disable=too-few-public-methods diff --git a/torchopt/diff/zero_order/decorator.py b/torchopt/diff/zero_order/decorator.py index 361da4ff..43522028 100644 --- a/torchopt/diff/zero_order/decorator.py +++ b/torchopt/diff/zero_order/decorator.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,8 +14,10 @@ # ============================================================================== """Zero-Order Gradient Estimation.""" +from __future__ import annotations + import functools -from typing import Any, Callable, List, Tuple, Union +from typing import Any, Callable, Sequence from typing_extensions import Literal # Python 3.8+ from typing_extensions import TypeAlias # Python 3.10+ @@ -23,14 +25,7 @@ from torch.autograd import Function from torchopt import pytree -from torchopt.typing import ( - ListOfTensors, - Numeric, - Samplable, - SampleFunc, - Sequence, - TupleOfOptionalTensors, -) +from torchopt.typing import ListOfTensors, Numeric, Samplable, SampleFunc, TupleOfOptionalTensors class WrappedSamplable(Samplable): # pylint: disable=too-few-public-methods @@ -40,25 +35,23 @@ def __init__(self, sample_fn: SampleFunc) -> None: """Wrap a sample function to make it a :class:`Samplable` object.""" self.sample_fn = sample_fn - def sample( - self, sample_shape: torch.Size = torch.Size() - ) -> Union[torch.Tensor, Sequence[Numeric]]: + def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor | Sequence[Numeric]: # pylint: disable-next=line-too-long - """Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.""" + """Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.""" return self.sample_fn(sample_shape) def _zero_order_naive( # pylint: disable=too-many-statements fn: Callable[..., torch.Tensor], distribution: Samplable, - argnums: Tuple[int, ...], + argnums: tuple[int, ...], num_samples: int, sigma: Numeric, ) -> Callable[..., torch.Tensor]: @functools.wraps(fn) def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements diff_params = [args[argnum] for argnum in argnums] - flat_diff_params: List[Any] + flat_diff_params: list[Any] flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-method @@ -66,7 +59,7 @@ class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-m def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: flat_diff_params = args[:-1] origin_args = list(args[-1][0]) - flat_args: List[Any] + flat_args: list[Any] flat_args, args_treespec = pytree.tree_flatten(origin_args, none_is_leaf=True) # type: ignore[arg-type] ctx.args_treespec = args_treespec @@ -114,7 +107,7 @@ def backward( # pylint: disable=too-many-locals flat_args.append(non_tensors[non_tensors_counter]) non_tensors_counter += 1 - args: List[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] + args: list[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] def add_perturbation(tensor, noises): return tensor.add(noises, alpha=sigma) @@ -126,7 +119,7 @@ def add_perturbation(tensor, noises): flat_noisy_params = [ add_perturbation(t, n) for t, n in zip(flat_diff_params, noises) ] - noisy_params: List[Any] = pytree.tree_unflatten( # type: ignore[assignment] + noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment] diff_params_treespec, flat_noisy_params ) @@ -152,14 +145,14 @@ def add_perturbation(tensor, noises): def _zero_order_forward( # pylint: disable=too-many-statements fn: Callable[..., torch.Tensor], distribution: Samplable, - argnums: Tuple[int, ...], + argnums: tuple[int, ...], num_samples: int, sigma: Numeric, ) -> Callable[..., torch.Tensor]: @functools.wraps(fn) def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements diff_params = [args[argnum] for argnum in argnums] - flat_diff_params: List[Any] + flat_diff_params: list[Any] flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-method @@ -167,7 +160,7 @@ class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-m def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: flat_diff_params = args[:-1] origin_args = list(args[-1][0]) - flat_args: List[Any] + flat_args: list[Any] flat_args, args_treespec = pytree.tree_flatten(origin_args, none_is_leaf=True) # type: ignore[arg-type] ctx.args_treespec = args_treespec @@ -216,7 +209,7 @@ def backward( # pylint: disable=too-many-locals flat_args.append(non_tensors[non_tensors_counter]) non_tensors_counter += 1 - args: List[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] + args: list[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] def add_perturbation(tensor, noises): return tensor.add(noises, alpha=sigma) @@ -228,7 +221,7 @@ def add_perturbation(tensor, noises): flat_noisy_params = [ add_perturbation(t, n) for t, n in zip(flat_diff_params, noises) ] - noisy_params: List[Any] = pytree.tree_unflatten( # type: ignore[assignment] + noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment] diff_params_treespec, flat_noisy_params ) @@ -255,14 +248,14 @@ def add_perturbation(tensor, noises): def _zero_order_antithetic( # pylint: disable=too-many-statements fn: Callable[..., torch.Tensor], distribution: Samplable, - argnums: Tuple[int, ...], + argnums: tuple[int, ...], num_samples: int, sigma: Numeric, ) -> Callable[..., torch.Tensor]: @functools.wraps(fn) def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements diff_params = [args[argnum] for argnum in argnums] - flat_diff_params: List[Any] + flat_diff_params: list[Any] flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-method @@ -270,7 +263,7 @@ class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-m def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: flat_diff_params = args[:-1] origin_args = list(args[-1][0]) - flat_args: List[Any] + flat_args: list[Any] flat_args, args_treespec = pytree.tree_flatten(origin_args, none_is_leaf=True) # type: ignore[arg-type] ctx.args_treespec = args_treespec @@ -316,7 +309,7 @@ def backward(ctx: Any, *grad_outputs: Any): # pylint: disable=too-many-locals flat_args.append(non_tensors[non_tensors_counter]) non_tensors_counter += 1 - args: List[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] + args: list[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] param_grads: ListOfTensors = [0.0 for _ in range(len(flat_diff_params))] # type: ignore[misc] @@ -325,7 +318,7 @@ def get_output(add_perturbation_fn, noises) -> torch.Tensor: add_perturbation_fn(t, n, alpha=sigma) for t, n in zip(flat_diff_params, noises) ] - noisy_params: List[Any] = pytree.tree_unflatten( # type: ignore[assignment] + noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment] diff_params_treespec, flat_noisy_params ) @@ -356,28 +349,28 @@ def get_output(add_perturbation_fn, noises) -> torch.Tensor: def zero_order( - distribution: Union[SampleFunc, Samplable], + distribution: SampleFunc | Samplable, method: Method = 'naive', - argnums: Union[int, Tuple[int, ...]] = (0,), + argnums: int | tuple[int, ...] = (0,), num_samples: int = 1, sigma: Numeric = 1.0, ) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]: - """Decorator for applying zero-order differentiation. + """Return a decorator for applying zero-order differentiation. Args: - distribution: (function or Samplable) - A samplable object that has method ``samplable.sample(sample_shape)`` or a function that - takes the shape as input and returns a shaped batch of samples. This is used to sample - perturbations from the given distribution. The distribution should be sphere symmetric. - method: (str) - The algorithm to use. The currently supported algorithms are :const:`'naive'`, - :const:`'forward'`, and :const:`'antithetic'`. Defaults to :const:`'naive'`. - argnums: (int or tuple of int, default: :const:`0`) - Specifies arguments to compute gradients with respect to. - num_samples: (int, default :const:`1`) - The number of sample to get the averaged estimated gradient. - sigma: (Numeric) - The standard deviation of the perturbation. Defaults to :const:`1.0`. + distribution (callable or Samplable): A samplable object that has method + ``samplable.sample(sample_shape)`` or a function that takes the shape as input and + returns a shaped batch of samples. This is used to sample perturbations from the given + distribution. The distribution should be sphere symmetric. + method (str, optional): The algorithm to use. The currently supported algorithms are + :const:`'naive'`, :const:`'forward'`, and :const:`'antithetic'`. + (default: :const:`'naive'`) + argnums (int or tuple of int, optional): Specifies arguments to compute gradients with + respect to. (default: :const:`0`) + num_samples (int, optional): The number of sample to get the averaged estimated gradient. + (default: :const:`1`) + sigma (float or Tensor, optional): The standard deviation of the perturbation. + (default: :const:`1.0`) Returns: A function decorator that enables zero-order gradient estimation. diff --git a/torchopt/diff/zero_order/nn/__init__.py b/torchopt/diff/zero_order/nn/__init__.py new file mode 100644 index 00000000..1bf64efe --- /dev/null +++ b/torchopt/diff/zero_order/nn/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The base class for zero-order gradient models.""" + +import torchopt.nn.module # preload to resolve circular references +from torchopt.diff.zero_order.nn.module import ZeroOrderGradientModule + + +__all__ = ['ZeroOrderGradientModule'] + +del torchopt diff --git a/torchopt/diff/zero_order/nn/module.py b/torchopt/diff/zero_order/nn/module.py new file mode 100644 index 00000000..65014fb9 --- /dev/null +++ b/torchopt/diff/zero_order/nn/module.py @@ -0,0 +1,99 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The base class for zero-order gradient models.""" + +# pylint: disable=redefined-builtin + +from __future__ import annotations + +import abc +import functools +from typing import Sequence + +import torch +import torch.nn as nn + +from torchopt.diff.zero_order.decorator import Method, Samplable, zero_order +from torchopt.nn.stateless import reparametrize +from torchopt.typing import Numeric, TupleOfTensors + + +__all__ = ['ZeroOrderGradientModule'] + + +def enable_zero_order_gradients( + cls: type[ZeroOrderGradientModule], + method: Method = 'naive', + num_samples: int = 1, + sigma: Numeric = 1.0, +) -> type[ZeroOrderGradientModule]: + """Enable zero-order gradient estimation for the :func:`forward` method.""" + cls_forward = cls.forward + if getattr(cls_forward, '__zero_order_gradients_enabled__', False): + raise TypeError( + 'Zero-order gradient estimation is already enabled for the `forward` method.' + ) + + @functools.wraps(cls_forward) + def wrapped(self: ZeroOrderGradientModule, *input, **kwargs) -> torch.Tensor: + """Do the forward pass calculation.""" + params_names, flat_params = tuple(zip(*self.named_parameters())) + + @zero_order(self.sample, argnums=0, method=method, num_samples=num_samples, sigma=sigma) + def forward_fn( + __flat_params: TupleOfTensors, + *input, + **kwargs, + ) -> torch.Tensor: + with reparametrize(self, zip(params_names, __flat_params)): + return cls_forward(self, *input, **kwargs) + + return forward_fn(flat_params, *input, **kwargs) + + wrapped.__zero_order_gradients_enabled__ = True # type: ignore[attr-defined] + cls.forward = wrapped # type: ignore[assignment] + return cls + + +class ZeroOrderGradientModule(nn.Module, Samplable): + """The base class for zero-order gradient models.""" + + def __init_subclass__( # pylint: disable=arguments-differ + cls, + method: Method = 'naive', + num_samples: int = 1, + sigma: Numeric = 1.0, + ) -> None: + """Validate and initialize the subclass.""" + super().__init_subclass__() + enable_zero_order_gradients( + cls, + method=method, + num_samples=num_samples, + sigma=sigma, + ) + + @abc.abstractmethod + def forward(self, *args, **kwargs) -> torch.Tensor: + """Do the forward pass of the model.""" + raise NotImplementedError + + @abc.abstractmethod + def sample( + self, sample_shape: torch.Size = torch.Size() # pylint: disable=unused-argument + ) -> torch.Tensor | Sequence[Numeric]: + # pylint: disable-next=line-too-long + """Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.""" + raise NotImplementedError diff --git a/torchopt/distributed/__init__.py b/torchopt/distributed/__init__.py index d966691c..4272e37a 100644 --- a/torchopt/distributed/__init__.py +++ b/torchopt/distributed/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,6 +25,6 @@ __all__ = ['is_available', *api.__all__, *world.__all__] -def is_available(): +def is_available() -> bool: """Check if the distributed module is available.""" return dist.is_available() and rpc.is_available() and autograd.is_available() diff --git a/torchopt/distributed/api.py b/torchopt/distributed/api.py index 0c06fa91..b46ad67e 100644 --- a/torchopt/distributed/api.py +++ b/torchopt/distributed/api.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +14,8 @@ # ============================================================================== """Distributed APIs.""" +from __future__ import annotations + import functools import sys from typing import ( @@ -33,7 +35,7 @@ import torch import torch.distributed.rpc as rpc -import torchopt.pytree as pytree +from torchopt import pytree from torchopt.distributed.world import get_worker_id, get_world_rank, get_world_size from torchopt.typing import Future @@ -73,8 +75,8 @@ class TensorDimensionPartitioner: while the non-tensor values will be broadcasted to partitions. Args: - dim: The dimension to partition. - exclusive: Whether to partition the batch exclusively. + dim (int): The dimension to partition. + exclusive (bool, optional): Whether to partition the batch exclusively. (default: :data:`False`) If :data:`True`, the batch will be partitioned into ``batch_size`` partitions, where ``batch_size`` is the size of the batch along the given dimension. Each batch sample will be assigned to a separate RPC call. @@ -82,11 +84,12 @@ class TensorDimensionPartitioner: partitions, where ``num_workers`` is the number of workers in the world. When ``batch_size > num_workers``, there can be multiple batch samples forward in a single RPC call. - keepdim: Whether to keep the partitioned dimension. Defaults to :data:`True`, i.e., keep the - batch dimension. If :data:`False`, use select instead of slicing. This functionality - should be used with ``exclusive=True``. - workers: The workers to partition the batch to. If :data:`None`, the batch will be - partitioned to all workers in the world. + keepdim (bool, optional): Whether to keep the partitioned dimension. (default: :data:`True`) + If :data:`True`, keep the batch dimension. If :data:`False`, use select instead of + slicing. This functionality should be used with ``exclusive=True``. + workers (sequence of int or str, or None, optional): The workers to partition the batch to. + If :data:`None`, the batch will be partitioned to all workers in the world. + (default: :data:`None`) """ def __init__( @@ -95,7 +98,7 @@ def __init__( *, exclusive: bool = False, keepdim: bool = False, - workers: Optional[Sequence[Union[int, str]]] = None, + workers: Sequence[int | str] | None = None, ) -> None: """Initialize the partitioner instance.""" if not keepdim and not exclusive: @@ -111,7 +114,7 @@ def __call__( self, *args: Any, **kwargs: Any, - ) -> List[Tuple[int, Optional[Args], Optional[KwArgs]]]: + ) -> list[tuple[int, Args | None, KwArgs | None]]: """Partition the batch of inputs along the given dimension.""" if self.workers is None: workers = list(range(get_world_size())) @@ -120,7 +123,7 @@ def __call__( num_workers = len(workers) args_tree = (args, kwargs) - flat_args: List[Any] + flat_args: list[Any] flat_args, treespec = pytree.tree_flatten(args_tree) # type: ignore[arg-type] batch_size = None @@ -137,8 +140,8 @@ def __call__( if batch_size is None: return [(get_world_rank(), args, kwargs.copy())] - dim_slices: List[Union[int, slice]] - batch_slices: List[Tuple[Union[int, slice, Ellipsis.__class__], ...]] # type: ignore[name-defined] + dim_slices: list[int | slice] + batch_slices: list[tuple[int | slice | Ellipsis.__class__, ...]] # type: ignore[name-defined] if self.exclusive: num_replicas = batch_size if self.keepdim: @@ -172,7 +175,7 @@ def __call__( for dim_slice in dim_slices ] - flat_args_replicas: List[List[Any]] = [[] for _ in range(num_replicas)] + flat_args_replicas: list[list[Any]] = [[] for _ in range(num_replicas)] for arg in flat_args: if isinstance(arg, torch.Tensor): for i, batch_slice in enumerate(batch_slices): @@ -181,7 +184,7 @@ def __call__( for i in range(num_replicas): flat_args_replicas[i].append(arg) - args_replicas: List[Tuple[Args, KwArgs]] = [ + args_replicas: list[tuple[Args, KwArgs]] = [ pytree.tree_unflatten(treespec, args_replica) # type: ignore[misc] for args_replica in flat_args_replicas ] @@ -193,16 +196,16 @@ def __call__( def __reduce__( self, - ) -> Tuple[ - Callable[..., 'TensorDimensionPartitioner'], - Tuple[int], - Dict[str, Union[bool, Optional[Sequence[Union[int, str]]]]], + ) -> tuple[ + Callable[..., TensorDimensionPartitioner], + tuple[int], + dict[str, bool | Sequence[int | str] | None], ]: """Return a tuple that allows the partitioner to be pickled.""" return ( TensorDimensionPartitioner, (self.dim,), - dict(exclusive=self.exclusive, keepdim=self.keepdim, workers=self.workers), + {'exclusive': self.exclusive, 'keepdim': self.keepdim, 'workers': self.workers}, ) @@ -211,7 +214,7 @@ def dim_partitioner( *, exclusive: bool = False, keepdim: bool = True, - workers: Optional[Sequence[Union[int, str]]] = None, + workers: Sequence[int | str] | None = None, ) -> PartitionFunction: """Partition a batch of inputs along a given dimension. @@ -219,8 +222,8 @@ def dim_partitioner( while the non-tensor values will be broadcasted to partitions. Args: - dim: The dimension to partition. - exclusive: Whether to partition the batch exclusively. + dim (int, optional): The dimension to partition. (default: :const:`0`) + exclusive (bool, optional): Whether to partition the batch exclusively. (default: :data:`False`) If :data:`True`, the batch will be partitioned into ``batch_size`` partitions, where ``batch_size`` is the size of the batch along the given dimension. Each batch sample will be assigned to a separate RPC call. @@ -228,11 +231,12 @@ def dim_partitioner( partitions, where ``num_workers`` is the number of workers in the world. When ``batch_size > num_workers``, there can be multiple batch samples forward in a single RPC call. - keepdim: Whether to keep the partitioned dimension. Defaults to :data:`True`, i.e., keep the - batch dimension. If :data:`False`, use select instead of slicing. This functionality - should be used with ``exclusive=True``. - workers: The workers to partition the batch to. If :data:`None`, the batch will be - partitioned to all workers in the world. + keepdim (bool, optional): Whether to keep the partitioned dimension. (default: :data:`False`) + If :data:`True`, keep the batch dimension. If :data:`False`, use select instead of + slicing. This functionality should be used with ``exclusive=True``. + workers (sequence of int or str, or None, optional): The workers to partition the batch to. + If :data:`None`, the batch will be partitioned to all workers in the world. + (default: :data:`None`) Returns: A partition function. @@ -273,26 +277,26 @@ def sum_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor: def remote_async_call( func: Callable[..., T], *, - args: Optional[Args] = None, - kwargs: Optional[KwArgs] = None, - partitioner: Optional[Partitioner] = None, - reducer: Optional[Callable[[Iterable[T]], U]] = None, - timeout: Optional[float] = UNSET_RPC_TIMEOUT, -) -> Union[Future[List[T]], Future[U]]: + args: Args | None = None, + kwargs: KwArgs | None = None, + partitioner: Partitioner | None = None, + reducer: Callable[[Iterable[T]], U] | None = None, + timeout: float | None = UNSET_RPC_TIMEOUT, +) -> Future[list[T]] | Future[U]: """Asynchronously do an RPC on remote workers and return the a :class:`torch.Future` instance at the current worker. Args: - func (Callable[..., T]): The function to call. - args (Optional[Args], optional): The arguments to pass to the function. Defaults to - :data:`None`. - kwargs (Optional[KwArgs], optional): The keyword arguments to pass to the function. Defaults - to :data:`None`. - partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple - workers. Defaults to :func:`batch_partitioner`. - reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from - multiple workers. Defaults to :data:`None`. - timeout (float, optional): The timeout for the RPC call. Defaults to - :data:`rpc.api.UNSET_RPC_TIMEOUT`. + func (callable): The function to call. + args (tuple of object or None, optional): The arguments to pass to the function. + (default: :data:`None`) + kwargs (dict[str, object] or None, optional): The keyword arguments to pass to the function. + (default: :data:`None`) + partitioner (int, str, or callable, optional): A partitioner that partitions the arguments + to multiple workers. (default: :func:`batch_partitioner`) + reducer (callable or None, optional): A reducer that reduces the results from multiple + workers. If :data:`None`, do not reduce the results. (default: :data:`None`) + timeout (float, optional): The timeout for the RPC call. + (default: :data:`rpc.api.UNSET_RPC_TIMEOUT`) Returns: A :class:`torch.Future` instance for the result. The result is at the current worker. @@ -330,26 +334,26 @@ def remote_async_call( def remote_sync_call( func: Callable[..., T], *, - args: Optional[Args] = None, - kwargs: Optional[KwArgs] = None, - partitioner: Optional[Partitioner] = None, - reducer: Optional[Callable[[Iterable[T]], U]] = None, - timeout: Optional[float] = UNSET_RPC_TIMEOUT, -) -> Union[List[T], U]: - """Synchronously do an RPC on remote workers and return the result to the current worker. + args: Args | None = None, + kwargs: KwArgs | None = None, + partitioner: Partitioner | None = None, + reducer: Callable[[Iterable[T]], U] | None = None, + timeout: float | None = UNSET_RPC_TIMEOUT, +) -> list[T] | U: + """Do an RPC synchronously on remote workers and return the result to the current worker. Args: - func (Callable[..., T]): The function to call. - args (Optional[Args], optional): The arguments to pass to the function. Defaults to - :data:`None`. - kwargs (Optional[KwArgs], optional): The keyword arguments to pass to the function. Defaults - to :data:`None`. - partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple - workers. Defaults to :func:`batch_partitioner`. - reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from - multiple workers. Defaults to :data:`None`. - timeout (float, optional): The timeout for the RPC call. Defaults to - :data:`rpc.api.UNSET_RPC_TIMEOUT`. + func (callable): The function to call. + args (tuple of object or None, optional): The arguments to pass to the function. + (default: :data:`None`) + kwargs (dict[str, object] or None, optional): The keyword arguments to pass to the function. + (default: :data:`None`) + partitioner (int, str, or callable, optional): A partitioner that partitions the arguments + to multiple workers. (default: :func:`batch_partitioner`) + reducer (callable or None, optional): A reducer that reduces the results from multiple + workers. If :data:`None`, do not reduce the results. (default: :data:`None`) + timeout (float, optional): The timeout for the RPC call. + (default: :data:`rpc.api.UNSET_RPC_TIMEOUT`) Returns: The result of the RPC call. The result is at the current worker. @@ -365,24 +369,23 @@ def remote_sync_call( def parallelize_async( - partitioner: Optional[Partitioner] = None, - reducer: Optional[Callable[[Iterable[T]], U]] = None, - timeout: Optional[float] = UNSET_RPC_TIMEOUT, -) -> Callable[[Callable[..., T]], Callable[..., Union[Future[List[T]], Future[U]]]]: - """Decorator for parallelizing a function. + partitioner: Partitioner | None = None, + reducer: Callable[[Iterable[T]], U] | None = None, + timeout: float | None = UNSET_RPC_TIMEOUT, +) -> Callable[[Callable[..., T]], Callable[..., Future[list[T]] | Future[U]]]: + """Return a decorator for parallelizing a function. This decorator can be used to parallelize a function call across multiple workers. The function will be called asynchronously on remote workers. The decorated function will return a :class:`torch.Future` instance of the result. Args: - partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple - workers. Defaults to :func:`batch_partitioner`. - reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from - multiple workers. Defaults to :func:`mean_reducer` if the ``partitioner`` is not - specified, i.e., :func:`batch_partitioner`. Otherwise, it defaults to :data:`None`. - timeout (float, optional): The timeout for the RPC call. Defaults to - :data:`rpc.api.UNSET_RPC_TIMEOUT`. + partitioner (int, str, or callable, optional): A partitioner that partitions the arguments + to multiple workers. (default: :func:`batch_partitioner`) + reducer (callable or None, optional): A reducer that reduces the results from multiple + workers. If :data:`None`, do not reduce the results. (default: :data:`None`) + timeout (float, optional): The timeout for the RPC call. + (default: :data:`rpc.api.UNSET_RPC_TIMEOUT`) Returns: The decorator function. @@ -392,9 +395,9 @@ def parallelize_async( if reducer is None: reducer = mean_reducer # type: ignore[assignment] - def wrapper(func: Callable[..., T]) -> Callable[..., Union[Future[List[T]], Future[U]]]: + def wrapper(func: Callable[..., T]) -> Callable[..., Future[list[T]] | Future[U]]: @functools.wraps(func) - def wrapped(*args: Any, **kwargs: Any) -> Union[Future[List[T]], Future[U]]: + def wrapped(*args: Any, **kwargs: Any) -> Future[list[T]] | Future[U]: return remote_async_call( func, args=args, @@ -423,22 +426,21 @@ def wrapped(*args: Any, **kwargs: Any) -> Union[Future[List[T]], Future[U]]: def parallelize( - partitioner: Optional[Partitioner] = None, - reducer: Optional[Callable[[Iterable[T]], U]] = None, - timeout: Optional[float] = UNSET_RPC_TIMEOUT, -) -> Callable[[Callable[..., T]], Callable[..., Union[List[T], U]]]: - """Decorator for parallelizing a function. + partitioner: Partitioner | None = None, + reducer: Callable[[Iterable[T]], U] | None = None, + timeout: float | None = UNSET_RPC_TIMEOUT, +) -> Callable[[Callable[..., T]], Callable[..., list[T] | U]]: + """Return a decorator for parallelizing a function. This decorator can be used to parallelize a function call across multiple workers. Args: - partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple - workers. Defaults to :func:`batch_partitioner`. - reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from - multiple workers. Defaults to :func:`mean_reducer` if the ``partitioner`` is not - specified, i.e., :func:`batch_partitioner`. Otherwise, it defaults to :data:`None`. - timeout (float, optional): The timeout for the RPC call. Defaults to - :data:`rpc.api.UNSET_RPC_TIMEOUT`. + partitioner (int, str, or callable, optional): A partitioner that partitions the arguments + to multiple workers. (default: :func:`batch_partitioner`) + reducer (callable or None, optional): A reducer that reduces the results from multiple + workers. If :data:`None`, do not reduce the results. (default: :data:`None`) + timeout (float, optional): The timeout for the RPC call. + (default: :data:`rpc.api.UNSET_RPC_TIMEOUT`) Returns: The decorator function. @@ -448,9 +450,9 @@ def parallelize( if reducer is None: reducer = mean_reducer # type: ignore[assignment] - def wrapper(func: Callable[..., T]) -> Callable[..., Union[List[T], U]]: + def wrapper(func: Callable[..., T]) -> Callable[..., list[T] | U]: @functools.wraps(func) - def wrapped(*args: Any, **kwargs: Any) -> Union[List[T], U]: + def wrapped(*args: Any, **kwargs: Any) -> list[T] | U: return remote_sync_call( func, args=args, diff --git a/torchopt/distributed/autograd.py b/torchopt/distributed/autograd.py index 41b6b461..17fa9463 100644 --- a/torchopt/distributed/autograd.py +++ b/torchopt/distributed/autograd.py @@ -14,14 +14,15 @@ # ============================================================================== """Distributed Autograd.""" +from __future__ import annotations + from threading import Lock -from typing import Optional, overload import torch import torch.distributed.autograd as autograd from torch.distributed.autograd import context -from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors, TupleOfTensors +from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors __all__ = ['is_available', 'context'] @@ -43,22 +44,23 @@ def backward( autograd_ctx_id: int, tensors: TensorOrTensors, retain_graph: bool = False, - inputs: Optional[TensorOrTensors] = None, + inputs: TensorOrTensors | None = None, ) -> None: """Perform distributed backward pass for local parameters. - Computes the sum of gradients of given tensors with respect to graph leaves. + Compute the sum of gradients of given tensors with respect to graph leaves. Args: - autograd_ctx_id: The autograd context id. - tensors (Sequence[Tensor] or Tensor): Tensors of which the derivative will be computed. + autograd_ctx_id (int): The autograd context id. + tensors (Tensor or sequence of Tensor): Tensors of which the derivative will be computed. retain_graph (bool, optional): If :data:`False`, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to :data:`True` is not needed and often can be worked around in a much more efficient way. - inputs (Sequence[Tensor] or Tensor, optional): Inputs w.r.t. which the gradient be will - accumulated into ``.grad``. All other Tensors will be ignored. If not provided, the - gradient is accumulated into all the leaf Tensors that were used to compute the - attr::tensors. + (default: :data:`False`) + inputs (Tensor, sequence of Tensor, or None, optional): Inputs w.r.t. which the gradient + be will accumulated into ``.grad``. All other Tensors will be ignored. If not + provided, the gradient is accumulated into all the leaf Tensors that were used to + compute the ``tensors``. (default: :data:`None`) """ if inputs is not None: if isinstance(inputs, torch.Tensor): @@ -85,25 +87,6 @@ def backward( else: p.grad = g - @overload - def grad( - autograd_ctx_id: int, - outputs: TensorOrTensors, - inputs: TensorOrTensors, - retain_graph: bool = False, - ) -> TupleOfTensors: - ... - - @overload - def grad( - autograd_ctx_id: int, - outputs: TensorOrTensors, - inputs: TensorOrTensors, - retain_graph: bool = False, - allow_unused: bool = False, - ) -> TupleOfOptionalTensors: - ... - def grad( autograd_ctx_id: int, outputs: TensorOrTensors, @@ -111,19 +94,20 @@ def grad( retain_graph: bool = False, allow_unused: bool = False, ) -> TupleOfOptionalTensors: - """Computes and returns the sum of gradients of outputs with respect to the inputs. + """Compute and return the sum of gradients of outputs with respect to the inputs. Args: - autograd_ctx_id: The autograd context id. - outputs (sequence of Tensor): outputs of the differentiated function. - inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be returned (and not - accumulated into ``.grad``). + autograd_ctx_id (int): The autograd context id. + outputs (Tensor or sequence of Tensor): Outputs of the differentiated function. + inputs (Tensor or sequence of Tensor): Inputs w.r.t. which the gradient will be returned + (and not accumulated into ``.grad``). retain_graph (bool, optional): If :data:`False`, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to :data:`True` is not needed and often can be worked around in a much more efficient way. + (default: :data:`False`) allow_unused (bool, optional): If :data:`False`, specifying inputs that were not used when computing outputs (and therefore their grad is always zero) is an error. - Defaults to :data:`False`. + (default: :data:`False`) """ outputs = [outputs] if isinstance(outputs, torch.Tensor) else list(outputs) inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs) diff --git a/torchopt/distributed/world.py b/torchopt/distributed/world.py index 4a24f3ef..804d4b9d 100644 --- a/torchopt/distributed/world.py +++ b/torchopt/distributed/world.py @@ -14,10 +14,12 @@ # ============================================================================== """Utilities for gathering information about the world.""" +from __future__ import annotations + import atexit import functools import os -from typing import Any, Callable, Iterable, NamedTuple, Optional, TypeVar, Union +from typing import Any, Callable, Iterable, NamedTuple, TypeVar import torch.distributed.rpc as rpc from torch.distributed.elastic.multiprocessing.errors import record @@ -46,7 +48,7 @@ def default_worker_name_format( local_rank: int, # pylint: disable=unused-argument local_world_size: int, # pylint: disable=unused-argument ) -> str: - """Default worker name format.""" + """Get the default worker name format.""" return f'worker{world_rank:0{len(str(world_size))}d}' @@ -64,12 +66,12 @@ class WorldInfo(NamedTuple): @property def rank(self) -> int: - """The global world rank of the current worker.""" + """Get the global world rank of the current worker.""" return self.world_rank @property def worker_name(self) -> str: - """The name of the current worker.""" + """Get the name of the current worker.""" return _WORKER_NAME_FORMAT( world_rank=self.world_rank, world_size=self.world_size, @@ -127,34 +129,35 @@ def get_local_world_size() -> int: # pylint: disable-next=redefined-builtin,invalid-name -def get_worker_id(id: Optional[Union[str, int]] = None) -> int: +def get_worker_id(id: str | int | None = None) -> int: """Get the worker id from the given id.""" if isinstance(id, int): return id return rpc.get_worker_info(worker_name=id).id -def barrier(worker_names: Optional[Iterable[str]] = None) -> None: - r"""Synchronizes local and remote RPC processes. +def barrier(worker_names: Iterable[str] | None = None) -> None: + r"""Synchronize local and remote RPC processes. This will block until all local and remote RPC processes specified under worker_names reach this method to wait for all outstanding work to complete. Args: - worker_names: The set of workers to synchronize. If :data:`None`, all workers. + worker_names (iterable of str or None, optional): The set of workers to synchronize. + If :data:`None`, all workers. (default: :data:`None`) """ worker_names = {} if worker_names is None else set(worker_names) rpc.api._barrier(worker_names) # pylint: disable=protected-access def auto_init_rpc( - worker_init_fn: Optional[Callable[[], None]] = None, + worker_init_fn: Callable[[], None] | None = None, worker_name_format: Callable[..., str] = default_worker_name_format, *, - backend: Optional['rpc.BackendType'] = None, - rpc_backend_options: Optional['rpc.RpcBackendOptions'] = None, + backend: rpc.BackendType | None = None, + rpc_backend_options: rpc.RpcBackendOptions | None = None, ) -> Callable[[F], F]: - """Decorator to automatically initialize RPC on the decorated function.""" + """Return a decorator to automatically initialize RPC on the decorated function.""" global _WORKER_NAME_FORMAT # pylint: disable=global-statement _WORKER_NAME_FORMAT = worker_name_format @@ -204,25 +207,25 @@ def wrapped(*args, **kwargs): def on_rank(*ranks: int) -> Callable[[F], F]: - """Decorator to mark a function to be executed only on given ranks.""" + """Return a decorator to mark a function to be executed only on given ranks.""" return __on_ranks(ranks=ranks, inverse=False) def not_on_rank(*ranks) -> Callable[[F], F]: - """Decorator to mark a function to be executed only on non given ranks.""" + """Return a decorator to mark a function to be executed only on non given ranks.""" return __on_ranks(ranks=ranks, inverse=True) def rank_all(func: F) -> F: - """Decorator to mark a function to be executed on all ranks.""" + """Return a decorator to mark a function to be executed on all ranks.""" return func def rank_zero_only(func: F) -> F: - """Decorator to mark a function to be executed only on rank zero.""" + """Return a decorator to mark a function to be executed only on rank zero.""" return on_rank(0)(func) def rank_non_zero_only(func: F) -> F: - """Decorator to mark a function to be executed only on non rank zero.""" + """Return a decorator to mark a function to be executed only on non rank zero.""" return not_on_rank(0)(func) diff --git a/torchopt/hook.py b/torchopt/hook.py index 612f2177..f188415c 100644 --- a/torchopt/hook.py +++ b/torchopt/hook.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,29 +14,32 @@ # ============================================================================== """Hook utilities.""" -from typing import Callable, Optional +from __future__ import annotations + +from typing import Callable import torch from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation +from torchopt.typing import OptState, Params, Updates __all__ = ['zero_nan_hook', 'nan_to_num_hook', 'register_hook'] def zero_nan_hook(g: torch.Tensor) -> torch.Tensor: - """A zero ``nan`` hook to replace ``nan`` with zero.""" + """Replace ``nan`` with zero.""" return g.nan_to_num(nan=0.0) def nan_to_num_hook( - nan: float = 0.0, posinf: Optional[float] = None, neginf: Optional[float] = None + nan: float = 0.0, posinf: float | None = None, neginf: float | None = None ) -> Callable[[torch.Tensor], torch.Tensor]: - """Returns a ``nan`` to num hook to replace ``nan`` / ``+inf`` / ``-inf`` with the given numbers.""" + """Return a ``nan`` to num hook to replace ``nan`` / ``+inf`` / ``-inf`` with the given numbers.""" def hook(g: torch.Tensor) -> torch.Tensor: - """A hook to replace ``nan`` / ``+inf`` / ``-inf`` with the given numbers.""" + """Replace ``nan`` / ``+inf`` / ``-inf`` with the given numbers.""" return g.nan_to_num(nan=nan, posinf=posinf, neginf=neginf) return hook @@ -51,10 +54,16 @@ def register_hook(hook) -> GradientTransformation: An ``(init_fn, update_fn)`` tuple. """ - def init_fn(params): # pylint: disable=unused-argument + def init_fn(params: Params) -> OptState: # pylint: disable=unused-argument return EmptyState() - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, # pylint: disable=unused-argument + ) -> tuple[Updates, OptState]: def f(g): return g.register_hook(hook) diff --git a/torchopt/linalg/cg.py b/torchopt/linalg/cg.py index 94daee53..5456f076 100644 --- a/torchopt/linalg/cg.py +++ b/torchopt/linalg/cg.py @@ -33,8 +33,10 @@ # pylint: disable=invalid-name +from __future__ import annotations + from functools import partial -from typing import Callable, Optional, Union +from typing import Callable import torch @@ -100,14 +102,14 @@ def body_fn(value): def _isolve( _isolve_solve: Callable, - A: Union[TensorTree, Callable[[TensorTree], TensorTree]], + A: TensorTree | Callable[[TensorTree], TensorTree], b: TensorTree, - x0: Optional[TensorTree] = None, + x0: TensorTree | None = None, *, rtol: float = 1e-5, atol: float = 0.0, - maxiter: Optional[int] = None, - M: Optional[Union[TensorTree, Callable[[TensorTree], TensorTree]]] = None, + maxiter: int | None = None, + M: TensorTree | Callable[[TensorTree], TensorTree] | None = None, ) -> TensorTree: if x0 is None: x0 = pytree.tree_map(torch.zeros_like, b) @@ -133,14 +135,14 @@ def _isolve( def cg( - A: Union[TensorTree, Callable[[TensorTree], TensorTree]], + A: TensorTree | Callable[[TensorTree], TensorTree], b: TensorTree, - x0: Optional[TensorTree] = None, + x0: TensorTree | None = None, *, rtol: float = 1e-5, atol: float = 0.0, - maxiter: Optional[int] = None, - M: Optional[Union[TensorTree, Callable[[TensorTree], TensorTree]]] = None, + maxiter: int | None = None, + M: TensorTree | Callable[[TensorTree], TensorTree] | None = None, ) -> TensorTree: """Use Conjugate Gradient iteration to solve ``Ax = b``. @@ -153,30 +155,30 @@ def cg( solves converge. Args: - A: (tensor or tree of tensors or function) - 2D array or function that calculates the linear map (matrix-vector product) ``Ax`` when - called like ``A(x)``. ``A`` must represent a hermitian, positive definite matrix, and - must return array(s) with the same structure and shape as its argument. - b: (tensor or tree of tensors) - Right hand side of the linear system representing a single vector. Can be stored as an - array or Python container of array(s) with any shape. - x0: (tensor or tree of tensors, optional) - Starting guess for the solution. Must have the same structure as ``b``. - rtol: (float, optional, default: :const:`1e-5`) - Tolerances for convergence, ``norm(residual) <= max(rtol*norm(b), atol)``. We do not - implement SciPy's "legacy" behavior, so TorchOpt's tolerance will differ from SciPy - unless you explicitly pass ``atol`` to SciPy's ``cg``. - atol: (float, optional, default: :const:`0.0`) - Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``. We do not - implement SciPy's "legacy" behavior, so TorchOpt's tolerance will differ from SciPy - unless you explicitly pass ``atol`` to SciPy's ``cg``. - maxiter: (integer, optional) - Maximum number of iterations. Iteration will stop after maxiter steps even if the - specified tolerance has not been achieved. - M: (tensor or tree of tensors or function) - Pre-conditioner for ``A``. The pre-conditioner should approximate the inverse of ``A``. - Effective preconditioning dramatically improves the rate of convergence, which implies - that fewer iterations are needed to reach a given error tolerance. + A (Tensor or tree of Tensor): 2D array or function that calculates the linear map + (matrix-vector product) ``Ax`` when called like ``A(x)``. ``A`` must represent a + hermitian, positive definite matrix, and must return tensor(s) with the same structure + and shape as its argument. + b (Tensor or tree of Tensor): Right hand side of the linear system representing a single + vector. Can be stored as a tensor or Python container of tensor(s) with any shape. + x0 (Tensor, tree of Tensor, or None, optional): Starting guess for the solution. Must have + the same structure as ``b``. If :data:`None`, use zero initialization. + (default: :data:`None`) + rtol (float, optional): Tolerances for convergence, ``norm(residual) <= max(rtol*norm(b), atol)``. + We do not implement SciPy's "legacy" behavior, so TorchOpt's tolerance will differ from + SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``. (default: :const:`1e-5`) + atol (float, optional): Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``. + We do not implement SciPy's "legacy" behavior, so TorchOpt's tolerance will differ from + SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``. (default: :const:`0.0`) + maxiter (int or None, optional): Maximum number of iterations. Iteration will stop after + maxiter steps even if the specified tolerance has not been achieved. If :data:`None`, + ``10 * size`` will be used, where ``size`` is the size of the flattened input tensor(s). + (default: :data:`None`) + M (Tensor, tree of Tensor, function, or None, optional): Pre-conditioner for ``A``. The + pre-conditioner should approximate the inverse of ``A``. Effective preconditioning + dramatically improves the rate of convergence, which implies that fewer iterations are + needed to reach a given error tolerance. If :data:`None`, no pre-conditioner will be + used. (default: :data:`None`) Returns: the Conjugate Gradient (CG) linear solver diff --git a/torchopt/linalg/ns.py b/torchopt/linalg/ns.py index 4da8ef9f..c1975203 100644 --- a/torchopt/linalg/ns.py +++ b/torchopt/linalg/ns.py @@ -16,13 +16,15 @@ # pylint: disable=invalid-name +from __future__ import annotations + import functools -from typing import Callable, Optional, Union +from typing import Callable import torch from torchopt import pytree -from torchopt.linalg.utils import cat_shapes, normalize_matvec +from torchopt.linalg.utils import normalize_matvec from torchopt.typing import TensorTree @@ -33,9 +35,9 @@ def _ns_solve( A: torch.Tensor, b: torch.Tensor, maxiter: int, - alpha: Optional[float] = None, + alpha: float | None = None, ) -> torch.Tensor: - """Uses Neumann Series Matrix Inversion Approximation to solve ``Ax = b``.""" + """Use Neumann Series Matrix Inversion Approximation to solve ``Ax = b``.""" if A.ndim != 2 or A.shape[0] != A.shape[1]: raise ValueError(f'`A` must be a square matrix, but has shape: {A.shape}') @@ -57,27 +59,26 @@ def _ns_solve( def ns( - A: Union[TensorTree, Callable[[TensorTree], TensorTree]], + A: TensorTree | Callable[[TensorTree], TensorTree], b: TensorTree, - maxiter: Optional[int] = None, + maxiter: int | None = None, *, - alpha: Optional[float] = None, + alpha: float | None = None, ) -> TensorTree: - """Uses Neumann Series Matrix Inversion Approximation to solve ``Ax = b``. + """Use Neumann Series Matrix Inversion Approximation to solve ``Ax = b``. Args: - A: (tensor or tree of tensors or function) - 2D array or function that calculates the linear map (matrix-vector product) ``Ax`` when - called like ``A(x)``. ``A`` must represent a hermitian, positive definite matrix, and - must return array(s) with the same structure and shape as its argument. - b: (tensor or tree of tensors) - Right hand side of the linear system representing a single vector. Can be stored as an - array or Python container of array(s) with any shape. - maxiter: (integer, optional) - Maximum number of iterations. Iteration will stop after maxiter steps even if the - specified tolerance has not been achieved. - alpha: (float, optional) - Decay coefficient. + A (Tensor or tree of Tensor): 2D array or function that calculates the linear map + (matrix-vector product) ``Ax`` when called like ``A(x)``. ``A`` must represent a + hermitian, positive definite matrix, and must return tensor(s) with the same structure + and shape as its argument. + b (Tensor or tree of Tensor): Right hand side of the linear system representing a single + vector. Can be stored as a tensor or Python container of tensor(s) with any shape. + maxiter (int or None, optional): Maximum number of iterations. Iteration will stop after + maxiter steps even if the specified tolerance has not been achieved. If :data:`None`, + :const:`10` will be used. (default: :const:`10`) + alpha: (float or None, optional): Decay coefficient. If :data:`None`, :const:`1.0` will be + used. (default: :const:`1.0`) Returns: The Neumann Series (NS) matrix inversion approximation. @@ -111,8 +112,8 @@ def ns( return inv_A_hat_b -def _ns_inv(A: torch.Tensor, maxiter: int, alpha: Optional[float] = None): - """Uses Neumann Series iteration to solve ``A^{-1}``.""" +def _ns_inv(A: torch.Tensor, maxiter: int, alpha: float | None = None): + """Use Neumann Series iteration to solve ``A^{-1}``.""" if A.ndim != 2 or A.shape[0] != A.shape[1]: raise ValueError(f'`A` must be a square matrix, but has shape: {A.shape}') @@ -134,28 +135,27 @@ def _ns_inv(A: torch.Tensor, maxiter: int, alpha: Optional[float] = None): def ns_inv( A: TensorTree, - maxiter: Optional[int] = None, + maxiter: int | None = None, *, - alpha: Optional[float] = None, + alpha: float | None = None, ) -> TensorTree: - """Uses Neumann Series iteration to solve ``A^{-1}``. + """Use Neumann Series iteration to solve ``A^{-1}``. Args: - A: (tensor or tree of tensors or function) - 2D array or function that calculates the linear map (matrix-vector product) ``Ax`` when - called like ``A(x)``. ``A`` must represent a hermitian, positive definite matrix, and - must return array(s) with the same structure and shape as its argument. - maxiter: (integer, optional) - Maximum number of iterations. Iteration will stop after maxiter steps even if the - specified tolerance has not been achieved. - alpha: (float, optional) - Decay coefficient. + A (Tensor or tree of Tensor): 2D array or function that calculates the linear map + (matrix-vector product) ``Ax`` when called like ``A(x)``. ``A`` must represent a + hermitian, positive definite matrix, and must return tensor(s) with the same structure + and shape as its argument. + maxiter (int or None, optional): Maximum number of iterations. Iteration will stop after + maxiter steps even if the specified tolerance has not been achieved. If :data:`None`, + :const:`10` will be used. (default: :const:`10`) + alpha: (float or None, optional): Decay coefficient. If :data:`None`, :const:`1.0` will be + used. (default: :const:`1.0`) Returns: The Neumann Series (NS) matrix inversion approximation. """ if maxiter is None: - size = sum(cat_shapes(A)) - maxiter = 10 * size # copied from SciPy + maxiter = 10 return pytree.tree_map(functools.partial(_ns_inv, maxiter=maxiter, alpha=alpha), A) diff --git a/torchopt/linalg/utils.py b/torchopt/linalg/utils.py index f2440b9a..f301a624 100644 --- a/torchopt/linalg/utils.py +++ b/torchopt/linalg/utils.py @@ -14,8 +14,10 @@ # ============================================================================== """Utilities for linear algebra.""" +from __future__ import annotations + import itertools -from typing import Callable, Tuple, Union +from typing import Callable import torch @@ -23,16 +25,16 @@ from torchopt.typing import TensorTree -def cat_shapes(tree: TensorTree) -> Tuple[int, ...]: - """Concatenates the shapes of the leaves of a tree of tensors.""" +def cat_shapes(tree: TensorTree) -> tuple[int, ...]: + """Concatenate the shapes of the leaves of a tree of tensors.""" leaves = pytree.tree_leaves(tree) return tuple(itertools.chain.from_iterable(tuple(leaf.shape) for leaf in leaves)) def normalize_matvec( - matvec: Union[TensorTree, Callable[[TensorTree], TensorTree]] + matvec: TensorTree | Callable[[TensorTree], TensorTree] ) -> Callable[[TensorTree], TensorTree]: - """Normalizes an argument for computing matrix-vector product.""" + """Normalize an argument for computing matrix-vector product.""" if callable(matvec): return matvec diff --git a/torchopt/linear_solve/cg.py b/torchopt/linear_solve/cg.py index 2ffc8217..844c9407 100644 --- a/torchopt/linear_solve/cg.py +++ b/torchopt/linear_solve/cg.py @@ -33,8 +33,10 @@ # pylint: disable=invalid-name +from __future__ import annotations + import functools -from typing import Callable, Optional +from typing import Callable from torchopt import linalg from torchopt.linear_solve.utils import make_ridge_matvec @@ -47,19 +49,21 @@ def _solve_cg( matvec: Callable[[TensorTree], TensorTree], # (x) -> A @ x b: TensorTree, - ridge: Optional[float] = None, - init: Optional[TensorTree] = None, + ridge: float | None = None, + init: TensorTree | None = None, **kwargs, ) -> TensorTree: - """Solves ``A x = b`` using conjugate gradient. + """Solve ``A x = b`` using conjugate gradient. This assumes that ``A`` is a hermitian, positive definite matrix. Args: - matvec: A function that returns the product between ``A`` and a vector. - b: A tree of tensors for the right hand side of the equation. - ridge: Optional ridge regularization. - init: Optional initialization to be used by conjugate gradient. + matvec (callable): A function that returns the product between ``A`` and a vector. + b (Tensor or tree of Tensor): A tree of tensors for the right hand side of the equation. + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A x + ridge x = b``. (default: :data:`None`) + init (Tensor, tree of Tensor, or None, optional): Optional initialization to be used by + conjugate gradient. If :data:`None`, uses zero initialization. (default: :data:`None`) **kwargs: Additional keyword arguments for the conjugate gradient solver. Returns: @@ -75,13 +79,15 @@ def _solve_cg( def solve_cg(**kwargs): - """A wrapper that returns a solver function to solve ``A x = b`` using conjugate gradient. + """Return a solver function to solve ``A x = b`` using conjugate gradient. This assumes that ``A`` is a hermitian, positive definite matrix. Args: - ridge: Optional ridge regularization. Solves the equation for ``(A + ridge * I) @ x = b``. - init: Optional initialization to be used by conjugate gradient. + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A x + ridge x = b``. (default: :data:`None`) + init (Tensor, tree of Tensor, or None, optional): Optional initialization to be used by + conjugate gradient. If :data:`None`, uses zero initialization. (default: :data:`None`) **kwargs: Additional keyword arguments for the conjugate gradient solver :func:`torchopt.linalg.cg`. diff --git a/torchopt/linear_solve/inv.py b/torchopt/linear_solve/inv.py index bf36f40e..399a0ef9 100644 --- a/torchopt/linear_solve/inv.py +++ b/torchopt/linear_solve/inv.py @@ -33,8 +33,10 @@ # pylint: disable=invalid-name +from __future__ import annotations + import functools -from typing import Callable, Optional +from typing import Callable import torch @@ -49,21 +51,23 @@ def _solve_inv( matvec: Callable[[TensorTree], TensorTree], # (x) -> A @ x b: TensorTree, - ridge: Optional[float] = None, + ridge: float | None = None, ns: bool = False, **kwargs, ) -> TensorTree: - """Solves ``A x = b`` using matrix inversion. + """Solve ``A x = b`` using matrix inversion. If ``ns = False``, this assumes the matrix ``A`` is a constant matrix and will materialize it in memory. Args: - matvec: A function that returns the product between ``A`` and a vector. - b: A tensor for the right hand side of the equation. - ridge: Optional ridge regularization. Solves the equation for ``(A + ridge * I) @ x = b``. - ns: Whether to use Neumann Series matrix inversion approximation. If :data:`False`, - materialize the matrix ``A`` in memory and use :func:`torch.linalg.solve` instead. + matvec (callable): A function that returns the product between ``A`` and a vector. + b (Tensor or tree of Tensor): A tree of tensors for the right hand side of the equation. + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A x + ridge x = b``. (default: :data:`None`) + ns (bool, optional): Whether to use Neumann Series matrix inversion approximation. + If :data:`False`, materialize the matrix ``A`` in memory and use :func:`torch.linalg.solve` + instead. (default: :data:`False`) **kwargs: Additional keyword arguments for the Neumann Series matrix inversion approximation solver :func:`torchopt.linalg.ns`. @@ -88,15 +92,17 @@ def _solve_inv( def solve_inv(**kwargs): - """A wrapper that returns a solver function to solve ``A x = b`` using matrix inversion. + """Return a solver function to solve ``A x = b`` using matrix inversion. If ``ns = False``, this assumes the matrix ``A`` is a constant matrix and will materialize it in memory. Args: - ridge: Optional ridge regularization. Solves the equation for ``(A + ridge * I) @ x = b``. - ns: Whether to use Neumann Series matrix inversion approximation. If :data:`False`, - materialize the matrix ``A`` in memory and use :func:`torch.linalg.solve` instead. + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A x + ridge x = b``. (default: :data:`None`) + ns (bool, optional): Whether to use Neumann Series matrix inversion approximation. + If :data:`False`, materialize the matrix ``A`` in memory and use :func:`torch.linalg.solve` + instead. (default: :data:`False`) **kwargs: Additional keyword arguments for the Neumann Series matrix inversion approximation solver :func:`torchopt.linalg.ns`. diff --git a/torchopt/linear_solve/normal_cg.py b/torchopt/linear_solve/normal_cg.py index 3646d7f4..8d38f77a 100644 --- a/torchopt/linear_solve/normal_cg.py +++ b/torchopt/linear_solve/normal_cg.py @@ -33,8 +33,10 @@ # pylint: disable=invalid-name +from __future__ import annotations + import functools -from typing import Callable, Optional +from typing import Callable from torchopt import linalg from torchopt.linear_solve.utils import make_normal_matvec, make_ridge_matvec, make_rmatvec @@ -47,20 +49,22 @@ def _solve_normal_cg( matvec: Callable[[TensorTree], TensorTree], # (x) -> A @ x b: TensorTree, - ridge: Optional[float] = None, - init: Optional[TensorTree] = None, + ridge: float | None = None, + init: TensorTree | None = None, **kwargs, ) -> TensorTree: - """Solves the normal equation ``A^T A x = A^T b`` using conjugate gradient. + """Solve the normal equation ``A^T A x = A^T b`` using conjugate gradient. This can be used to solve ``A x = b`` using conjugate gradient when ``A`` is not hermitian, positive definite. Args: - matvec: A function that returns the product between ``A`` and a vector. - b: A tree of tensors for the right hand side of the equation. - ridge: Optional ridge regularization. Solves the equation for ``(A.T @ A + ridge * I) @ x = A.T @ b``. - init: Optional initialization to be used by normal conjugate gradient. + matvec (callable): A function that returns the product between ``A`` and a vector. + b (Tensor or tree of Tensor): A tree of tensors for the right hand side of the equation. + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A^T A x + ridge x = A^T b``. (default: :data:`None`) + init (Tensor, tree of Tensor, or None, optional): Optional initialization to be used by + conjugate gradient. If :data:`None`, uses zero initialization. (default: :data:`None`) **kwargs: Additional keyword arguments for the conjugate gradient solver :func:`torchopt.linalg.cg`. @@ -87,14 +91,16 @@ def _solve_normal_cg( def solve_normal_cg(**kwargs): - """A wrapper that returns a solver function to solve ``A^T A x = A^T b`` using conjugate gradient. + """Return a solver function to solve ``A^T A x = A^T b`` using conjugate gradient. This can be used to solve ``A x = b`` using conjugate gradient when ``A`` is not hermitian, positive definite. Args: - ridge: Optional ridge regularization. Solves the equation for ``(A.T @ A + ridge * I) @ x = A.T @ b``. - init: Optional initialization to be used by normal conjugate gradient. + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A^T A x + ridge x = A^T b``. (default: :data:`None`) + init (Tensor, tree of Tensor, or None, optional): Optional initialization to be used by + conjugate gradient. If :data:`None`, uses zero initialization. (default: :data:`None`) **kwargs: Additional keyword arguments for the conjugate gradient solver :func:`torchopt.linalg.cg`. diff --git a/torchopt/linear_solve/utils.py b/torchopt/linear_solve/utils.py index a7e93e65..f4f34e2a 100644 --- a/torchopt/linear_solve/utils.py +++ b/torchopt/linear_solve/utils.py @@ -31,7 +31,9 @@ # ============================================================================== """Utilities for linear algebra solvers.""" -from typing import Callable, Tuple +from __future__ import annotations + +from typing import Callable import functorch @@ -42,7 +44,7 @@ def make_rmatvec( matvec: Callable[[TensorTree], TensorTree], example_x: TensorTree ) -> Callable[[TensorTree], TensorTree]: - """Returns a function that computes ``rmatvec(y) = A.T @ y`` from ``matvec(x) = A @ x``.""" + """Return a function that computes ``rmatvec(y) = A.T @ y`` from ``matvec(x) = A @ x``.""" _, vjp, *_ = functorch.vjp(matvec, example_x) return lambda y: vjp(y)[0] @@ -51,10 +53,10 @@ def make_rmatvec( def make_normal_matvec( matvec: Callable[[TensorTree], TensorTree] ) -> Callable[[TensorTree], TensorTree]: - """Returns a function that computes ``normal_matvec(y) = A.T @ A @ y`` from ``matvec(x) = A @ x``.""" + """Return a function that computes ``normal_matvec(y) = A.T @ A @ y`` from ``matvec(x) = A @ x``.""" def normal_matvec(y: TensorTree) -> TensorTree: - """Computes ``A.T @ A @ y`` from ``matvec(x) = A @ x``.""" + """Compute ``A.T @ A @ y`` from ``matvec(x) = A @ x``.""" matvec_y, vjp, *_ = functorch.vjp(matvec, y) return vjp(matvec_y)[0] @@ -64,10 +66,10 @@ def normal_matvec(y: TensorTree) -> TensorTree: def make_ridge_matvec( matvec: Callable[[TensorTree], TensorTree], ridge: float = 0.0 ) -> Callable[[TensorTree], TensorTree]: - """Returns a function that computes ``ridge_matvec(y) = A.T @ A @ y + ridge * y`` from ``matvec(x) = A @ x``.""" + """Return a function that computes ``ridge_matvec(y) = A.T @ A @ y + ridge * y`` from ``matvec(x) = A @ x``.""" def ridge_matvec(y: TensorTree) -> TensorTree: - """Computes ``A.T @ A @ v + ridge * v`` from ``matvec(x) = A @ x``.""" + """Compute ``A.T @ A @ v + ridge * v`` from ``matvec(x) = A @ x``.""" return pytree.tree_add_scalar_mul(matvec(y), y, alpha=ridge) return ridge_matvec @@ -75,13 +77,13 @@ def ridge_matvec(y: TensorTree) -> TensorTree: def materialize_matvec( matvec: Callable[[TensorTree], TensorTree], x: TensorTree -) -> Tuple[ +) -> tuple[ TensorTree, Callable[[TensorTree], TensorTree], Callable[[TensorTree], TensorTree], Callable[[TensorTree], TensorTree], ]: - """Materializes the matrix ``A`` used in ``matvec(x) = A @ x``.""" + """Materialize the matrix ``A`` used in ``matvec(x) = A @ x``.""" x_flat, treespec = pytree.tree_flatten(x) shapes = tuple(t.shape for t in x_flat) diff --git a/torchopt/nn/__init__.py b/torchopt/nn/__init__.py index 57a8e802..8271ad7d 100644 --- a/torchopt/nn/__init__.py +++ b/torchopt/nn/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,7 +15,16 @@ """Base class for neural network modules that hold meta-parameters and meta-modules.""" from torchopt.diff.implicit.nn.module import ImplicitMetaGradientModule # circular reference +from torchopt.diff.zero_order.nn.module import ZeroOrderGradientModule # circular reference from torchopt.nn.module import MetaGradientModule +from torchopt.nn.stateless import reparameterize, reparametrize, swap_state -__all__ = ['MetaGradientModule', 'ImplicitMetaGradientModule'] +__all__ = [ + 'MetaGradientModule', + 'ImplicitMetaGradientModule', + 'ZeroOrderGradientModule', + 'reparametrize', + 'reparameterize', + 'swap_state', +] diff --git a/torchopt/nn/module.py b/torchopt/nn/module.py index 4a1364f1..f8804864 100644 --- a/torchopt/nn/module.py +++ b/torchopt/nn/module.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,33 +14,36 @@ # ============================================================================== """Base class for neural network modules that hold meta-parameters and meta-modules.""" +from __future__ import annotations + from collections import OrderedDict -from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Set, Tuple, Union +from typing import Any, Iterator, NamedTuple import torch import torch.nn as nn from torchopt import pytree +from torchopt.typing import TensorContainer class MetaInputsContainer(NamedTuple): """Container for parameters and modules in the constructor input arguments.""" - meta_parameters: Set[torch.Tensor] - meta_modules: Set[nn.Module] + meta_parameters: set[torch.Tensor] + meta_modules: set[nn.Module] class MetaGradientModule(nn.Module): # pylint: disable=abstract-method """Base class for neural network modules that hold meta-parameters and meta-modules.""" _meta_inputs: MetaInputsContainer - _meta_parameters: Dict[str, Optional[torch.Tensor]] - _meta_modules: Dict[str, Optional[nn.Module]] + _meta_parameters: TensorContainer + _meta_modules: dict[str, nn.Module | None] - def __new__(cls, *args, **kwargs) -> 'MetaGradientModule': - """Creates a new module instance.""" + def __new__(cls, *args, **kwargs) -> MetaGradientModule: + """Create a new module instance.""" instance = super().__new__(cls) - flat_args: List[Any] + flat_args: list[Any] flat_args = pytree.tree_leaves((args, kwargs)) # type: ignore[arg-type] meta_parameters = {x for x in flat_args if isinstance(x, torch.Tensor) and x.requires_grad} meta_modules = {x for x in flat_args if isinstance(x, nn.Module) and x.training} @@ -49,12 +52,16 @@ def __new__(cls, *args, **kwargs) -> 'MetaGradientModule': meta_modules.update(meta_module.modules()) instance._meta_inputs = MetaInputsContainer(meta_parameters, meta_modules) - instance._meta_parameters: Dict[str, Optional[torch.Tensor]] = OrderedDict() # type: ignore[misc] - instance._meta_modules: Dict[str, Optional[nn.Module]] = OrderedDict() # type: ignore[misc] + instance._meta_parameters: TensorContainer = OrderedDict() # type: ignore[misc] + instance._meta_modules: dict[str, nn.Module | None] = OrderedDict() # type: ignore[misc] return instance - def __getattr__(self, name: str) -> Union[torch.Tensor, nn.Module]: - """Gets an attribute of the module.""" + def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused-argument + """Initialize a new module instance.""" + super().__init__() + + def __getattr__(self, name: str) -> torch.Tensor | nn.Module: + """Get an attribute of the module.""" if '_parameters' in self.__dict__: _parameters = self.__dict__['_parameters'] if name in _parameters: @@ -78,8 +85,8 @@ def __getattr__(self, name: str) -> Union[torch.Tensor, nn.Module]: raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") # pylint: disable-next=too-many-branches,too-many-statements - def __setattr__(self, name: str, value: Union[torch.Tensor, nn.Module]) -> None: - """Sets an attribute of the module.""" + def __setattr__(self, name: str, value: torch.Tensor | nn.Module) -> None: + """Set an attribute of the module.""" def remove_from(*dicts_or_sets): for dict_or_set in dicts_or_sets: @@ -166,7 +173,7 @@ def remove_from(*dicts_or_sets): object.__setattr__(self, name, value) def __delattr__(self, name: str) -> None: - """Deletes an attribute of the module.""" + """Delete an attribute of the module.""" if name in self._parameters: del self._parameters[name] elif name in self._buffers: @@ -181,27 +188,26 @@ def __delattr__(self, name: str) -> None: else: object.__delattr__(self, name) - def register_parameter(self, name: str, param: Optional[torch.Tensor]) -> None: - r"""Adds a parameter to the module. + def register_parameter(self, name: str, param: torch.Tensor | None) -> None: + r"""Add a parameter to the module. The parameter can be accessed as an attribute using given name. Args: - name (string): name of the parameter. The parameter can be accessed - from this module using the given name - param (torch.Tensor or None): parameter to be added to the module. If - ``None``, then operations that run on parameters, such as :attr:`cuda`, - are ignored. If ``None``, the parameter is **not** included in the - module's :attr:`state_dict`. + name (str): The name of the parameter. The parameter can be accessed from this module + using the given name. + param (Tensor or None): The parameter to be added to the module. If :data:`None`, then + operations that run on parameters, such as ``cuda``, are ignored. If :data:`None`, + the parameter is **not** included in the module's ``state_dict``. """ if '_parameters' not in self.__dict__: raise AttributeError('cannot assign parameter before Module.__init__() call') if not isinstance(name, str): raise TypeError(f'parameter name should be a string. Got {torch.typename(name)}') if '.' in name: - raise KeyError("parameter name can't contain \".\"") + raise KeyError("parameter name can't contain '.'") if name == '': - raise KeyError("parameter name can't be empty string \"\"") + raise KeyError("parameter name can't be empty string ''") if hasattr(self, name) and name not in self._parameters: raise KeyError(f"attribute '{name}' already exists") @@ -226,18 +232,17 @@ def register_parameter(self, name: str, param: Optional[torch.Tensor]) -> None: self._parameters[name] = param # type: ignore - def register_meta_parameter(self, name: str, param: Optional[torch.Tensor]) -> None: - r"""Adds a meta-parameter to the module. + def register_meta_parameter(self, name: str, param: torch.Tensor | None) -> None: + r"""Add a meta-parameter to the module. The meta-parameter can be accessed as an attribute using given name. Args: - name (string): name of the parameter. The parameter can be accessed - from this module using the given name - param (torch.Tensor or None): parameter to be added to the module. If - ``None``, then operations that run on parameters, such as :attr:`cuda`, - are ignored. If ``None``, the parameter is **not** included in the - module's :attr:`state_dict`. + name (str): The name of the meta-parameter. The meta-parameter can be accessed from this + module using the given name. + param (Tensor or None): The meta-parameter to be added to the module. If :data:`None`, + then operations that run on meta-parameters, such as ``cuda``, are ignored. If + :data:`None`, the meta-parameter is **not** included in the module's ``state_dict``. """ if '_meta_parameters' not in self.__dict__: raise AttributeError( @@ -246,9 +251,9 @@ def register_meta_parameter(self, name: str, param: Optional[torch.Tensor]) -> N if not isinstance(name, str): raise TypeError(f'meta-parameter name should be a string. Got {torch.typename(name)}') if '.' in name: - raise KeyError("meta-parameter name can't contain \".\"") + raise KeyError("meta-parameter name can't contain '.'") if name == '': - raise KeyError("meta-parameter name can't be empty string \"\"") + raise KeyError("meta-parameter name can't be empty string ''") if hasattr(self, name) and name not in self._meta_parameters: raise KeyError(f"attribute '{name}' already exists") @@ -268,15 +273,15 @@ def register_meta_parameter(self, name: str, param: Optional[torch.Tensor]) -> N self._meta_parameters[name] = param - def add_module(self, name: str, module: Optional[nn.Module]) -> None: - r"""Adds a child module to the current module. + def add_module(self, name: str, module: nn.Module | None) -> None: + r"""Add a child module to the current module. The module can be accessed as an attribute using the given name. Args: - name (string): name of the child module. The child module can be - accessed from this module using the given name - module (Module): child module to be added to the module. + name (str): The name of the child module. The child module can be accessed from this + module using the given name + module (nn.Module or None): The child module to be added to the module. """ if not isinstance(module, nn.Module) and module is not None: raise TypeError(f'{torch.typename(module)} is not a Module subclass') @@ -285,9 +290,9 @@ def add_module(self, name: str, module: Optional[nn.Module]) -> None: if hasattr(self, name) and name not in self._modules: raise KeyError(f"attribute '{name}' already exists") if '.' in name: - raise KeyError(f"module name can't contain \".\", got: {name}") + raise KeyError(f"module name can't contain '.', got: '{name}'") if name == '': - raise KeyError("module name can't be empty string \"\"") + raise KeyError("module name can't be empty string ''") if module in self._meta_inputs.meta_modules: raise ValueError( f"cannot add module that is a meta-module to module '{name}'. " @@ -296,19 +301,19 @@ def add_module(self, name: str, module: Optional[nn.Module]) -> None: self._modules[name] = module - def register_module(self, name: str, module: Optional[nn.Module]) -> None: + def register_module(self, name: str, module: nn.Module | None) -> None: r"""Alias for :func:`add_module`.""" self.add_module(name, module) - def add_meta_module(self, name: str, meta_module: Optional[nn.Module]) -> None: - r"""Adds a child meta-module to the current module. + def add_meta_module(self, name: str, meta_module: nn.Module | None) -> None: + r"""Add a child meta-module to the current module. The meta-module can be accessed as an attribute using the given name. Args: - name (string): name of the child meta-module. The child meta-module can be - accessed from this module using the given name - meta_module (Module): child meta-module to be added to the module. + name (str): The name of the child meta-module. The child meta-module can be accessed + from this module using the given name + meta_module (nn.Module or None): The child meta-module to be added to the module. """ if not isinstance(meta_module, nn.Module) and meta_module is not None: raise TypeError(f'{torch.typename(meta_module)} is not a Module subclass') @@ -317,25 +322,25 @@ def add_meta_module(self, name: str, meta_module: Optional[nn.Module]) -> None: if hasattr(self, name) and name not in self._meta_modules: raise KeyError(f"attribute '{name}' already exists") if '.' in name: - raise KeyError(f"meta-module name can't contain \".\", got: {name}") + raise KeyError(f"meta-module name can't contain '.', got: '{name}'") if name == '': - raise KeyError("meta-module name can't be empty string \"\"") + raise KeyError("meta-module name can't be empty string ''") self._meta_modules[name] = meta_module - def register_meta_module(self, name: str, meta_module: Optional[nn.Module]) -> None: + def register_meta_module(self, name: str, meta_module: nn.Module | None) -> None: r"""Alias for :func:`add_meta_module`.""" self.add_meta_module(name, meta_module) def meta_parameters(self, recurse: bool = True) -> Iterator[torch.Tensor]: - r"""Returns an iterator over module meta-parameters. + r"""Return an iterator over module meta-parameters. This is typically passed to an optimizer. Args: - recurse (bool): if True, then yields parameters of this module and - all submodules. Otherwise, yields only meta-parameters that - are direct members of this module. + recurse (bool, optional): If :data:`True`, then yields parameters of this module and + all submodules. Otherwise, yields only meta-parameters that are direct members of + this module. (default: :data:`True`) Yields: Parameter: module meta-parameter @@ -353,14 +358,15 @@ def meta_parameters(self, recurse: bool = True) -> Iterator[torch.Tensor]: def named_meta_parameters( self, prefix: str = '', recurse: bool = True - ) -> Iterator[Tuple[str, torch.Tensor]]: - r"""Returns an iterator over module meta-parameters, yielding both the name of the meta-parameter as well as the meta-parameter itself. + ) -> Iterator[tuple[str, torch.Tensor]]: + r"""Return an iterator over module meta-parameters, yielding both the name of the meta-parameter as well as the meta-parameter itself. Args: - prefix (str): prefix to prepend to all meta-parameter names. - recurse (bool): if True, then yields meta-parameters of this module - and all submodules. Otherwise, yields only meta-parameters that - are direct members of this module. + prefix (str, optional): The prefix to prepend to all meta-parameter names. + (default: :const:`''`) + recurse (bool, optional): if :data:`True`, then yields meta-parameters of this module + and all submodules. Otherwise, yields only meta-parameters that are direct members + of this module. (default: :data:`True`) Yields: (string, Parameter): Tuple containing the name and parameter @@ -385,7 +391,7 @@ def named_meta_parameters( yield from meta_module.named_parameters(submodule_prefix, recurse) def meta_children(self) -> Iterator[nn.Module]: - r"""Returns an iterator over immediate children meta-modules. + r"""Return an iterator over immediate children meta-modules. Yields: Module: a child meta-module @@ -393,8 +399,8 @@ def meta_children(self) -> Iterator[nn.Module]: for _, module in self.named_meta_children(): yield module - def named_meta_children(self) -> Iterator[Tuple[str, nn.Module]]: - r"""Returns an iterator over immediate children meta-modules, yielding both the name of the meta-module as well as the meta-module itself. + def named_meta_children(self) -> Iterator[tuple[str, nn.Module]]: + r"""Return an iterator over immediate children meta-modules, yielding both the name of the meta-module as well as the meta-module itself. Yields: (string, Module): Tuple containing a name and child meta-module @@ -413,7 +419,7 @@ def named_meta_children(self) -> Iterator[Tuple[str, nn.Module]]: yield name, meta_module def meta_modules(self) -> Iterator[nn.Module]: - r"""Returns an iterator over all meta-modules in the network. + r"""Return an iterator over all meta-modules in the network. Yields: Module: a meta-module in the network @@ -425,15 +431,18 @@ def meta_modules(self) -> Iterator[nn.Module]: yield meta_module def named_meta_modules( - self, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True - ) -> Iterator[Tuple[str, nn.Module]]: - r"""Returns an iterator over all meta-modules in the network, yielding both the name of the meta-module as well as the meta-module itself. + self, memo: set[nn.Module] | None = None, prefix: str = '', remove_duplicate: bool = True + ) -> Iterator[tuple[str, nn.Module]]: + r"""Return an iterator over all meta-modules in the network, yielding both the name of the meta-module as well as the meta-module itself. Args: - memo: a memo to store the set of meta-modules already added to the result - prefix: a prefix that will be added to the name of the meta-module - remove_duplicate: whether to remove the duplicated meta-module instances in the result - or not + memo (set of nn.Module or None, optional): A memory to store the set of meta-modules + already added to the result. If not provided, a new set will be created. + (default: :const:`None`) + prefix (str, optional): A prefix that will be added to the name of the meta-module. + (default: :const:`''`) + remove_duplicate (bool, optional): whether to remove the duplicated meta-module + instances in the result or not. (default: :const:`True`) Yields: (string, Module): Tuple of name and meta-module diff --git a/torchopt/nn/stateless.py b/torchopt/nn/stateless.py new file mode 100644 index 00000000..9391352f --- /dev/null +++ b/torchopt/nn/stateless.py @@ -0,0 +1,104 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions for stateless module calls.""" + +from __future__ import annotations + +import contextlib +from typing import Generator, Iterable + +import torch +import torch.nn as nn + + +__all__ = ['swap_state', 'reparametrize', 'reparameterize'] + + +MISSING: torch.Tensor = object() # type: ignore[assignment] + + +def swap_state( + module: nn.Module, + named_tensors: dict[str, torch.Tensor] | Iterable[tuple[str, torch.Tensor]], + allow_missing: bool = False, +) -> dict[str, torch.Tensor]: + """Swap the module parameters and/or buffers.""" + if not isinstance(named_tensors, dict): + named_tensors = dict(named_tensors) + + submodules = {'': module} + + def get_submodule(path: str) -> nn.Module: + """Get submodules recursively.""" + try: + return submodules[path] + except KeyError: + prefix, dot, attr = path.rpartition('.') + if dot: + submodule = submodules[path] = getattr(get_submodule(prefix), attr) + else: + submodule = submodules[path] = getattr(module, attr) + return submodule + + def recursive_setattr(path: str, value: torch.Tensor) -> torch.Tensor: + """Set attribute recursively.""" + prefix, _, attr = path.rpartition('.') + mod = get_submodule(prefix) + + if allow_missing: + orig = getattr(mod, attr, MISSING) + else: + orig = getattr(mod, attr) + + # pylint: disable=protected-access + if value is MISSING: + delattr(mod, attr) + elif hasattr(mod, '_parameters') and attr in mod._parameters: + mod._parameters[attr] = value # type: ignore[assignment] + elif hasattr(mod, '_buffers') and attr in mod._buffers: + mod._buffers[attr] = value + elif hasattr(mod, '_meta_parameters') and attr in mod._meta_parameters: # type: ignore[operator] + mod._meta_parameters[attr] = value # type: ignore[operator,index] + else: + setattr(mod, attr, value) + # pylint: enable=protected-access + + return orig + + orig_named_tensors = { + name: recursive_setattr(name, tensor) for name, tensor in named_tensors.items() + } + return orig_named_tensors + + +@contextlib.contextmanager +def reparametrize( + module: nn.Module, + named_tensors: dict[str, torch.Tensor] | Iterable[tuple[str, torch.Tensor]], + allow_missing: bool = False, +) -> Generator[nn.Module, None, None]: + """Reparameterize the module parameters and/or buffers.""" + if not isinstance(named_tensors, dict): + named_tensors = dict(named_tensors) + + orig_named_tensors = {} + try: + orig_named_tensors = swap_state(module, named_tensors, allow_missing=allow_missing) + yield module + finally: + swap_state(module, orig_named_tensors, allow_missing=allow_missing) + + +reparameterize = reparametrize diff --git a/torchopt/optim/adam.py b/torchopt/optim/adam.py index 8fcdff90..640eea1d 100644 --- a/torchopt/optim/adam.py +++ b/torchopt/optim/adam.py @@ -14,7 +14,9 @@ # ============================================================================== """Adam optimizer.""" -from typing import Iterable, Tuple +from __future__ import annotations + +from typing import Iterable import torch @@ -39,7 +41,7 @@ def __init__( self, params: Iterable[torch.Tensor], lr: ScalarOrSchedule, - betas: Tuple[float, float] = (0.9, 0.999), + betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.0, *, @@ -47,28 +49,30 @@ def __init__( maximize: bool = False, use_accelerated_op: bool = False, ) -> None: - r"""The :meth:`init` function. + r"""Initialize the Adam optimizer. Args: - params: (iterable of torch.Tensor) - An iterable of :class:`torch.Tensor`\s. Specifies what tensors should be optimized. - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to - avoid dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square + root (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + eps_root (float, optional): A small constant applied to denominator inside the square + root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for + example when computing (meta-)gradients through Adam. (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. + (default: :data:`False`) """ super().__init__( params, diff --git a/torchopt/optim/adamw.py b/torchopt/optim/adamw.py index 24362d59..7db5e750 100644 --- a/torchopt/optim/adamw.py +++ b/torchopt/optim/adamw.py @@ -14,13 +14,15 @@ # ============================================================================== """AdamW optimizer.""" -from typing import Any, Callable, Iterable, Optional, Tuple, Union +from __future__ import annotations + +from typing import Callable, Iterable import torch from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import Params, ScalarOrSchedule +from torchopt.typing import OptState, Params, ScalarOrSchedule __all__ = ['AdamW'] @@ -39,46 +41,48 @@ def __init__( self, params: Iterable[torch.Tensor], lr: ScalarOrSchedule = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), + betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 1e-2, *, eps_root: float = 0.0, - mask: Optional[Union[Any, Callable[[Params], Any]]] = None, + mask: OptState | Callable[[Params], OptState] | None = None, maximize: bool = False, use_accelerated_op: bool = False, ) -> None: - r"""The :meth:`init` function. + r"""Initialize the AdamW optimizer. Args: - params: (iterable of torch.Tensor) - An iterable of :class:`torch.Tensor`\s. Specifies what tensors should be optimized. - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`1e-2`) - Strength of the weight decay regularization. Note that this weight decay is - multiplied with the learning rate. This is consistent with other frameworks such as - PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only - multiplied with the "schedule multiplier", but not the base learning rate. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to - avoid dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - mask: (default: :data:`None`) - A tree with same structure as (or a prefix of) the params PyTree, or a Callable that + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square + root (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Strength of the weight decay regularization. Note that + this weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where + the weight decay is only multiplied with the "schedule multiplier", but not the base + learning rate. (default: :const:`1e-2`) + eps_root (float, optional): A small constant applied to denominator inside the square + root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for + example when computing (meta-)gradients through Adam. (default: :const:`0.0`) + mask (tree of Tensor, callable, or None, optional): + A tree with same structure as (or a prefix of) the params pytree, or a function that returns such a pytree given the params/updates. The leaves should be booleans, :data:`True` for leaves/subtrees you want to apply the weight decay to, and :data:`False` for those you want to skip. Note that the Adam gradient - transformations are applied to all parameters. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. + transformations are applied to all parameters. (default: :data:`None`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. + (default: :data:`False`) """ super().__init__( params, diff --git a/torchopt/optim/base.py b/torchopt/optim/base.py index dc933f30..aac3a782 100644 --- a/torchopt/optim/base.py +++ b/torchopt/optim/base.py @@ -14,7 +14,9 @@ # ============================================================================== """The base class for optimizers.""" -from typing import Callable, Iterable, List, Optional, Sequence, Tuple +from __future__ import annotations + +from typing import Callable, Iterable, Sequence import torch @@ -31,14 +33,14 @@ class Optimizer: """A base class for classic optimizers that similar to :class:`torch.optim.Optimizer`.""" def __init__(self, params: Iterable[torch.Tensor], impl: GradientTransformation) -> None: - r"""The :meth:`init` function. + r"""Initialize the optimizer. Args: params (iterable of torch.Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what tensors should be optimized. impl (GradientTransformation): A low level optimizer function, it could be a optimizer - function provided by ``alias.py`` or a customized ``chain`` provided by - ``combine.py``. + function provided in :mod:`torchopt.alias` or a customized :func:`torchopt.chain`\ed + transformation. Note that using ``Optimizer(sgd())`` or ``Optimizer(chain(sgd()))`` is equivalent to :class:`torchopt.SGD`. """ @@ -46,21 +48,22 @@ def __init__(self, params: Iterable[torch.Tensor], impl: GradientTransformation) raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation') self.impl: GradientTransformation = impl - self.param_groups: List[TupleOfTensors] = [] - self.param_treespecs: List[pytree.PyTreeSpec] = [] - self.state_groups: List[OptState] = [] + self.param_groups: list[TupleOfTensors] = [] + self.param_treespecs: list[pytree.PyTreeSpec] = [] + self.state_groups: list[OptState] = [] if not isinstance(params, (list, tuple)): params = tuple(params) self.add_param_group(params) def zero_grad(self, set_to_none: bool = False) -> None: - r"""Sets the gradients of all optimized :class:`torch.Tensor`\s to zero. + r"""Set the gradients of all optimized :class:`torch.Tensor`\s to zero. The behavior is similar to :meth:`torch.optim.Optimizer.zero_grad`. Args: - set_to_none (bool): Instead of setting to zero, set the ``grads`` to :data:`None`. + set_to_none (bool, optional): Instead of setting to zero, set the ``grads`` to + :data:`None`. (default: :data:`False`) """ if set_to_none: @@ -80,26 +83,27 @@ def f(p): pytree.tree_map_(f, self.param_groups) # type: ignore[arg-type] - def state_dict(self) -> Tuple[OptState, ...]: - """Returns the state of the optimizer.""" + def state_dict(self) -> tuple[OptState, ...]: + """Return the state of the optimizer.""" return tuple(self.state_groups) def load_state_dict(self, state_dict: Sequence[OptState]) -> None: - """Loads the optimizer state. + """Load the optimizer state. Args: - state_dict: Optimizer state. Should be an object returned from a call to - :meth:`state_dict`. + state_dict (sequence of tree of Tensor): Optimizer state. Should be an object returned + from a call to :meth:`state_dict`. """ self.state_groups[:] = list(state_dict) - def step(self, closure: Optional[Callable[[], torch.Tensor]] = None) -> Optional[torch.Tensor]: - """Performs a single optimization step. + def step(self, closure: Callable[[], torch.Tensor] | None = None) -> torch.Tensor | None: + """Perform a single optimization step. The behavior is similar to :meth:`torch.optim.Optimizer.step`. Args: - closure (callable, optional): A closure that reevaluates the model and returns the loss. + closure (callable or None, optional): A closure that reevaluates the model and returns + the loss. Optional for most optimizers. (default: :data:`None`) """ loss = None if closure is not None: @@ -120,7 +124,7 @@ def f(p): return loss def add_param_group(self, params: Params) -> None: - """Add a param group to the optimizer's :attr:`param_groups`.""" + """Add a param group to the optimizer's ``param_groups``.""" flat_params: TupleOfTensors flat_params, params_treespec = pytree.tree_flatten_as_tuple(params) self.param_groups.append(flat_params) diff --git a/torchopt/optim/func/base.py b/torchopt/optim/func/base.py index b3125d19..9dce3412 100644 --- a/torchopt/optim/func/base.py +++ b/torchopt/optim/func/base.py @@ -14,7 +14,7 @@ # ============================================================================== """Functional optimizer wrappers.""" -from typing import Optional +from __future__ import annotations import torch @@ -41,26 +41,27 @@ class FuncOptimizer: # pylint: disable=too-few-public-methods """ def __init__(self, impl: GradientTransformation, *, inplace: bool = False) -> None: - """The :meth:`init` function. + r"""Initialize the functional optimizer wrapper. Args: impl (GradientTransformation): A low level optimizer function, it could be a optimizer - function provided by `alias.py` or a customized `chain` provided by `combine.py`. - inplace (optional): (default: :data:`False`) - The default value of ``inplace`` for each optimization update. + function provided in :mod:`torchopt.alias` or a customized :func:`torchopt.chain`\ed + transformation. + inplace (bool, optional): The default value of ``inplace`` for each optimization update. + (default: :data:`False`) """ if not isinstance(impl, GradientTransformation): raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation') self.impl: GradientTransformation = impl - self.optim_state: Optional[OptState] = UninitializedState() + self.optim_state: OptState | None = UninitializedState() self.inplace: bool = bool(inplace) def step( self, loss: torch.Tensor, params: Params, - inplace: Optional[bool] = None, + inplace: bool | None = None, ) -> Params: r"""Compute the gradients of loss to the network parameters and update network parameters. @@ -69,13 +70,12 @@ def step( gradients and update the network parameters without modifying tensors in-place. Args: - loss: (torch.Tensor) - loss that is used to compute the gradients to network parameters. - params: (tree of torch.Tensor) - An tree of :class:`torch.Tensor`\s. Specifies what tensors should be optimized. - inplace (optional): (default: :data:`None`) - Whether to update the parameters in-place. If :data:`None`, use the default value - specified in the constructor. + loss (Tensor): The loss that is used to compute the gradients to network parameters. + params (tree of Tensor): An tree of :class:`torch.Tensor`\s. Specifies what tensors + should be optimized. + inplace (bool or None, optional): Whether to update the parameters in-place. If + :data:`None`, use the default value specified in the constructor. + (default: :data:`None`) """ if isinstance(self.optim_state, UninitializedState): self.optim_state = self.impl.init(params) diff --git a/torchopt/optim/meta/adam.py b/torchopt/optim/meta/adam.py index 9340b513..bd9804b9 100644 --- a/torchopt/optim/meta/adam.py +++ b/torchopt/optim/meta/adam.py @@ -14,7 +14,7 @@ # ============================================================================== """Differentiable Adam optimizer.""" -from typing import Tuple +from __future__ import annotations import torch.nn as nn @@ -39,7 +39,7 @@ def __init__( self, module: nn.Module, lr: ScalarOrSchedule = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), + betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.0, *, @@ -48,31 +48,29 @@ def __init__( maximize: bool = False, use_accelerated_op: bool = False, ) -> None: - """The :meth:`init` function. + """Initialize the meta-Adam optimizer. Args: - module: (nn.Module) - A network whose parameters should be optimized. - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to - avoid dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - moment_requires_grad: (default: :data:`True`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square + root (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + eps_root (float, optional): A small constant applied to denominator inside the square + root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for + example when computing (meta-)gradients through Adam. (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. + (default: :data:`False`) """ super().__init__( module, diff --git a/torchopt/optim/meta/adamw.py b/torchopt/optim/meta/adamw.py index 70f3a80a..c8a8ef9c 100644 --- a/torchopt/optim/meta/adamw.py +++ b/torchopt/optim/meta/adamw.py @@ -14,13 +14,15 @@ # ============================================================================== """Differentiable AdamW optimizer.""" -from typing import Any, Callable, Optional, Tuple, Union +from __future__ import annotations + +from typing import Callable import torch.nn as nn from torchopt import alias from torchopt.optim.meta.base import MetaOptimizer -from torchopt.typing import Params, ScalarOrSchedule +from torchopt.typing import OptState, Params, ScalarOrSchedule __all__ = ['MetaAdamW'] @@ -39,50 +41,48 @@ def __init__( self, module: nn.Module, lr: ScalarOrSchedule = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), + betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 1e-2, *, eps_root: float = 0.0, - mask: Optional[Union[Any, Callable[[Params], Any]]] = None, + mask: OptState | Callable[[Params], OptState] | None = None, moment_requires_grad: bool = False, maximize: bool = False, use_accelerated_op: bool = False, ) -> None: - """The :meth:`init` function. + """Initialize the meta-AdamW optimizer. Args: - module: (nn.Module) - A network whose parameters should be optimized. - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`1e-2`) - Strength of the weight decay regularization. Note that this weight decay is - multiplied with the learning rate. This is consistent with other frameworks such as - PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only - multiplied with the "schedule multiplier", but not the base learning rate. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to - avoid dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - mask: (default: :data:`None`) - A tree with same structure as (or a prefix of) the params PyTree, or a Callable that + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square + root (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Strength of the weight decay regularization. Note that + this weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where + the weight decay is only multiplied with the "schedule multiplier", but not the base + learning rate. (default: :const:`1e-2`) + eps_root (float, optional): A small constant applied to denominator inside the square + root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for + example when computing (meta-)gradients through Adam. (default: :const:`0.0`) + mask (tree of Tensor, callable, or None, optional): + A tree with same structure as (or a prefix of) the params pytree, or a function that returns such a pytree given the params/updates. The leaves should be booleans, :data:`True` for leaves/subtrees you want to apply the weight decay to, and :data:`False` for those you want to skip. Note that the Adam gradient - transformations are applied to all parameters. - moment_requires_grad: (default: :data:`False`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. + transformations are applied to all parameters. (default: :data:`None`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. + (default: :data:`False`) """ super().__init__( module, diff --git a/torchopt/optim/meta/base.py b/torchopt/optim/meta/base.py index 5993ecc1..c5c9ad73 100644 --- a/torchopt/optim/meta/base.py +++ b/torchopt/optim/meta/base.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,14 +14,16 @@ # ============================================================================== """The base class for differentiable meta-optimizers.""" -from typing import Dict, List, Optional, Sequence, Tuple +from __future__ import annotations + +from typing import Sequence import torch import torch.nn as nn from torchopt import pytree from torchopt.base import UninitializedState -from torchopt.typing import GradientTransformation, OptState, TupleOfTensors +from torchopt.typing import GradientTransformation, ModuleTensorContainers, OptState, TupleOfTensors from torchopt.update import apply_updates from torchopt.utils import extract_module_containers @@ -33,14 +35,13 @@ class MetaOptimizer: """The base class for high-level differentiable optimizers.""" def __init__(self, module: nn.Module, impl: GradientTransformation) -> None: - """The :meth:`init` function. + r"""Initialize the meta-optimizer. Args: - module: (nn.Module) - A network whose parameters should be optimized. - impl: (GradientTransformation) - A low level optimizer function, it could be a optimizer function provided by - ``alias.py`` or a customized ``chain`` provided by ``combine.py``. + module (nn.Module): A network whose parameters should be optimized. + impl (GradientTransformation): A low level optimizer function, it could be a optimizer + function provided in :mod:`torchopt.alias` or a customized :func:`torchopt.chain`\ed + transformation. Note that using ``MetaOptimizer(sgd(moment_requires_grad=True))`` or ``MetaOptimizer(chain(sgd(moment_requires_grad=True)))`` is equivalent to :class:`torchopt.MetaSGD`. @@ -49,8 +50,8 @@ def __init__(self, module: nn.Module, impl: GradientTransformation) -> None: raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation') self.impl: GradientTransformation = impl - self.param_containers_groups: List[Tuple[Dict[str, Optional[torch.Tensor]], ...]] = [] - self.state_groups: List[OptState] = [] + self.param_containers_groups: list[ModuleTensorContainers] = [] + self.state_groups: list[OptState] = [] self.add_param_group(module) @@ -62,8 +63,8 @@ def step(self, loss: torch.Tensor) -> None: # pylint: disable=too-many-locals gradients and update the network parameters without modifying tensors in-place. Args: - loss: (torch.Tensor) - The loss that is used to compute the gradients to the network parameters. + loss (torch.Tensor): The loss that is used to compute the gradients to the network + parameters. """ # Step parameter only for i, (param_container, state) in enumerate( @@ -87,21 +88,19 @@ def step(self, loss: torch.Tensor) -> None: # pylint: disable=too-many-locals ) self.state_groups[i] = new_state flat_new_params = apply_updates(flat_params, updates, inplace=False) - new_params: Tuple[ - Dict[str, Optional[torch.Tensor]], ... - ] = pytree.tree_unflatten( # type: ignore[assignment] + new_params: ModuleTensorContainers = pytree.tree_unflatten( # type: ignore[assignment] container_treespec, flat_new_params ) for container, new_param in zip(param_container, new_params): container.update(new_param) def add_param_group(self, module: nn.Module) -> None: - """Add a param group to the optimizer's :attr:`state_groups`.""" + """Add a param group to the optimizer's ``state_groups``.""" params_container = extract_module_containers(module, with_buffers=False)[0] self.param_containers_groups.append(params_container) self.state_groups.append(UninitializedState()) - def state_dict(self) -> Tuple[OptState, ...]: + def state_dict(self) -> tuple[OptState, ...]: """Extract the references of the optimizer states. Note that the states are references, so any in-place operations will change the states diff --git a/torchopt/optim/meta/rmsprop.py b/torchopt/optim/meta/rmsprop.py index 47c3e983..3aff20e1 100644 --- a/torchopt/optim/meta/rmsprop.py +++ b/torchopt/optim/meta/rmsprop.py @@ -47,33 +47,29 @@ def __init__( nesterov: bool = False, maximize: bool = False, ) -> None: - """The :meth:`init` function. + """Initialize the meta-RMSProp optimizer. Args: - module: (nn.Module) - A network whose parameters should be optimized. - lr: (default: :const:`1e-2`) - This is a fixed global scaling factor. - alpha: (default: :const:`0.99`) - Smoothing constant, the decay used to track the magnitude of previous gradients. - eps: (default: :const:`1e-8`) - A small numerical constant to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - centered: (default: :data:`False`) - If :data:`True`, use the variance of the past gradients to rescale the latest - gradients. - initial_scale: (default: :data:`0.0`) - Initialization of accumulators tracking the magnitude of previous updates. PyTorch - uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When reproducing results from a - paper, verify the value used by the authors. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-2`) + alpha (float, optional): Smoothing constant, the decay used to track the magnitude of + previous gradients. (default: :const:`0.99`) + eps (float, optional): A small numerical constant to avoid dividing by zero when + rescaling. (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + momentum (float, optional): The decay rate used by the momentum term. The momentum is + not used when it is set to :const:`0.0`. (default: :const:`0.0`) + centered (bool, optional): If :data:`True`, use the variance of the past gradients to + rescale the latest gradients. (default: :data:`False`) + initial_scale (float, optional): Initialization of accumulators tracking the magnitude + of previous updates. PyTorch uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When + reproducing results from a paper, verify the value used by the authors. + (default: :data:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) """ super().__init__( module, diff --git a/torchopt/optim/meta/sgd.py b/torchopt/optim/meta/sgd.py index f46158a6..476ed9d6 100644 --- a/torchopt/optim/meta/sgd.py +++ b/torchopt/optim/meta/sgd.py @@ -44,26 +44,23 @@ def __init__( moment_requires_grad: bool = True, maximize: bool = False, ) -> None: - """The :meth:`init` function. + """Initialize the meta-SGD optimizer. Args: - module: (nn.Module) - A network whose parameters should be optimized. - lr: This is a fixed global scaling factor. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - dampening: (default: :const:`0.0`) - Dampening for momentum. - nesterov: (default: :const:`False`) - Whether to use Nesterov momentum. - moment_requires_grad: (default: :data:`True`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable): This is a fixed global scaling factor or a learning rate + scheduler. + momentum (float, optional): The decay rate used by the momentum term. The momentum is + not used when it is set to :const:`0.0`. (default: :const:`0.0`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + dampening (float, optional): Dampening for momentum. (default: :const:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) """ super().__init__( module, diff --git a/torchopt/optim/rmsprop.py b/torchopt/optim/rmsprop.py index dc649722..5c4e536f 100644 --- a/torchopt/optim/rmsprop.py +++ b/torchopt/optim/rmsprop.py @@ -49,33 +49,30 @@ def __init__( nesterov: bool = False, maximize: bool = False, ) -> None: - r"""The `init` function. + r"""Initialize the RMSProp optimizer. Args: - params: (iterable of torch.Tensor) - An iterable of :class:`torch.Tensor`\s. Specifies what Tensors should be optimized. - lr: (default: :const:`1e-2`) - This is a fixed global scaling factor. - alpha: (default: :const:`0.99`) - Smoothing constant, the decay used to track the magnitude of previous gradients. - eps: (default: :const:`1e-8`) - A small numerical constant to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - centered: (default: :data:`False`) - If :data:`True`, use the variance of the past gradients to rescale the latest - gradients. - initial_scale: (default: :data:`0.0`) - Initialization of accumulators tracking the magnitude of previous updates. PyTorch - uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When reproducing results from a - paper, verify the value used by the authors. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-2`) + alpha (float, optional): Smoothing constant, the decay used to track the magnitude of + previous gradients. (default: :const:`0.99`) + eps (float, optional): A small numerical constant to avoid dividing by zero when + rescaling. (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + momentum (float, optional): The decay rate used by the momentum term. The momentum is + not used when it is set to :const:`0.0`. (default: :const:`0.0`) + centered (bool, optional): If :data:`True`, use the variance of the past gradients to + rescale the latest gradients. (default: :data:`False`) + initial_scale (float, optional): Initialization of accumulators tracking the magnitude + of previous updates. PyTorch uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When + reproducing results from a paper, verify the value used by the authors. + (default: :data:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) """ super().__init__( params, diff --git a/torchopt/optim/sgd.py b/torchopt/optim/sgd.py index d83786ae..3da9595a 100644 --- a/torchopt/optim/sgd.py +++ b/torchopt/optim/sgd.py @@ -45,23 +45,24 @@ def __init__( nesterov: bool = False, maximize: bool = False, ) -> None: - r"""The :meth:`init` function. + r"""Initialize the SGD optimizer. Args: - params: (iterable of torch.Tensor) - An iterable of :class:`torch.Tensor`\s. Specifies what tensors should be optimized. - lr: This is a fixed global scaling factor. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - dampening: (default: :const:`0.0`) - Dampening for momentum. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable): This is a fixed global scaling factor or a learning rate + scheduler. + momentum (float, optional): The decay rate used by the momentum term. The momentum is + not used when it is set to :const:`0.0`. (default: :const:`0.0`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + dampening (float, optional): Dampening for momentum. (default: :const:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) """ super().__init__( params, diff --git a/torchopt/pytree.py b/torchopt/pytree.py index 0308b825..d3b2d181 100644 --- a/torchopt/pytree.py +++ b/torchopt/pytree.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,9 +14,11 @@ # ============================================================================== """The PyTree utilities.""" +from __future__ import annotations + import functools import operator -from typing import Callable, List, Optional, Tuple +from typing import Callable import optree import optree.typing as typing # pylint: disable=unused-import @@ -47,19 +49,20 @@ def tree_flatten_as_tuple( tree: PyTree[T], - is_leaf: Optional[Callable[[T], bool]] = None, + is_leaf: Callable[[T], bool] | None = None, *, none_is_leaf: bool = False, namespace: str = '', -) -> Tuple[Tuple[T, ...], PyTreeSpec]: +) -> tuple[tuple[T, ...], PyTreeSpec]: """Flatten a pytree to a tuple of leaves and a PyTreeSpec. Args: - tree: The pytree to flatten. - is_leaf: A function that returns :data:`True` if a given node is a leaf. - none_is_leaf: If :data:`True`, None is considered a leaf rather than a internal node with no - children. - namespace: The namespace of custom tree node types. + tree (pytree): The pytree to flatten. + is_leaf (callable or None, optional): An optionally specified function that returns + :data:`True` if a given node is a leaf. (default: :data:`None`) + none_is_leaf (bool, optional): If :data:`True`, :data:`None` is considered a leaf rather + than a internal node with no children. (default: :data:`False`) + namespace (str, optional): The namespace of custom tree node types. (default: :const:`''`) Returns: A tuple of (leaves, treespec). @@ -84,12 +87,12 @@ def acc_matmul(*args: T) -> T: def tree_pos(tree: PyTree[T]) -> PyTree[T]: - """Applies `operator.pos` over leaves.""" + """Apply :func:`operator.pos` over leaves.""" return tree_map(operator.pos, tree) def tree_neg(tree: PyTree[T]) -> PyTree[T]: - """Applies `operator.neg` over leaves.""" + """Apply :func:`operator.neg` over leaves.""" return tree_map(operator.neg, tree) @@ -99,9 +102,9 @@ def tree_add(*trees: PyTree[T]) -> PyTree[T]: def tree_add_scalar_mul( - tree_x: TensorTree, tree_y: TensorTree, alpha: Optional[Scalar] = None + tree_x: TensorTree, tree_y: TensorTree, alpha: Scalar | None = None ) -> TensorTree: - """Computes tree_x + alpha * tree_y.""" + """Compute ``tree_x + alpha * tree_y``.""" if alpha is None: return tree_map(lambda x, y: x.add(y), tree_x, tree_y) return tree_map(lambda x, y: x.add(y, alpha=alpha), tree_x, tree_y) @@ -113,9 +116,9 @@ def tree_sub(minuend_tree: PyTree[T], subtrahend_tree: PyTree[T]) -> PyTree[T]: def tree_sub_scalar_mul( - tree_x: TensorTree, tree_y: TensorTree, alpha: Optional[Scalar] = None + tree_x: TensorTree, tree_y: TensorTree, alpha: Scalar | None = None ) -> TensorTree: - """Computes tree_x - alpha * tree_y.""" + """Compute ``tree_x - alpha * tree_y``.""" if alpha is None: return tree_map(lambda x, y: x.sub(y), tree_x, tree_y) return tree_map(lambda x, y: x.sub(y, alpha=alpha), tree_x, tree_y) @@ -142,7 +145,7 @@ def tree_truediv(dividend_tree: PyTree[T], divisor_tree: PyTree[T]) -> PyTree[T] def _vdot_real_kernel(x: torch.Tensor, y: torch.Tensor) -> float: - """Computes dot(x.conj(), y).real.""" + """Compute ``dot(x.conj(), y).real``.""" x = x.contiguous().view(-1) y = y.contiguous().view(-1) vdot = torch.dot(x.real, y.real).item() @@ -152,7 +155,7 @@ def _vdot_real_kernel(x: torch.Tensor, y: torch.Tensor) -> float: def tree_vdot_real(tree_x: TensorTree, tree_y: TensorTree) -> float: - """Computes dot(tree_x.conj(), tree_y).real.sum().""" + """Compute ``dot(tree_x.conj(), tree_y).real.sum()``.""" leaves_x, treespec = tree_flatten(tree_x) leaves_y = treespec.flatten_up_to(tree_y) return sum(map(_vdot_real_kernel, leaves_x, leaves_y)) # type: ignore[arg-type] @@ -167,7 +170,7 @@ def tree_wait(future_tree: PyTree[Future[T]]) -> PyTree[T]: return tree_unflatten(treespec, results) -if rpc.is_available(): +if rpc.is_available(): # pragma: no cover def tree_as_rref(tree: PyTree[T]) -> PyTree[RRef[T]]: r"""Convert a tree of local objects to a tree of :class:`RRef`\s.""" @@ -190,4 +193,4 @@ def tree_local_value(rref_tree: PyTree[RRef[T]]) -> PyTree[T]: __all__.extend(['tree_as_rref', 'tree_to_here']) -del Callable, List, Optional, Tuple, optree, rpc, Scalar, T, RRef +del Callable, optree, rpc, Scalar, T, RRef diff --git a/torchopt/schedule/polynomial.py b/torchopt/schedule/polynomial.py index 8d2c2056..d54dbf17 100644 --- a/torchopt/schedule/polynomial.py +++ b/torchopt/schedule/polynomial.py @@ -49,21 +49,20 @@ def polynomial_schedule( transition_steps: int, transition_begin: int = 0, ) -> Schedule: - """Constructs a schedule with polynomial transition from init to end value. + """Construct a schedule with polynomial transition from init to end value. Args: - init_value: Initial value for the scalar to be annealed. - end_value: End value of the scalar to be annealed. - power: The power of the polynomial used to transition from ``init`` to ``end``. - transition_steps: - Number of steps over which annealing takes place, the scalar starts changing at - ``transition_begin`` steps and completes the transition by - ``transition_begin + transition_steps`` steps. - If ``transition_steps <= 0``, then the entire annealing process is disabled and the - value is held fixed at ``init_value``. - transition_begin: - Must be *positive*. After how many steps to start annealing (before this many steps the - scalar value is held fixed at ``init_value``). + init_value (float or Tensor): Initial value for the scalar to be annealed. + end_value (float or Tensor): End value of the scalar to be annealed. + power (float or Tensor): The power of the polynomial used to transition from ``init`` to + ``end``. + transition_steps (int): Number of steps over which annealing takes place, the scalar starts + changing at ``transition_begin`` steps and completes the transition by + ``transition_begin + transition_steps`` steps. If ``transition_steps <= 0``, then the + entire annealing process is disabled and the value is held fixed at ``init_value``. + transition_begin (int, optional): Must be *positive*. After how many steps to start + annealing (before this many steps the scalar value is held fixed at ``init_value``). + (default: :const:`0`) Returns: schedule: diff --git a/torchopt/transform/__init__.py b/torchopt/transform/__init__.py index 07c1a8e9..7006090f 100644 --- a/torchopt/transform/__init__.py +++ b/torchopt/transform/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,7 +31,7 @@ # ============================================================================== """Preset transformations.""" -from torchopt.transform.add_decayed_weights import add_decayed_weights +from torchopt.transform.add_decayed_weights import add_decayed_weights, masked from torchopt.transform.nan_to_num import nan_to_num from torchopt.transform.scale import scale from torchopt.transform.scale_by_adam import scale_by_accelerated_adam, scale_by_adam @@ -46,6 +46,7 @@ 'scale', 'scale_by_schedule', 'add_decayed_weights', + 'masked', 'scale_by_adam', 'scale_by_accelerated_adam', 'scale_by_rms', diff --git a/torchopt/transform/add_decayed_weights.py b/torchopt/transform/add_decayed_weights.py index 700e9c7b..14745766 100644 --- a/torchopt/transform/add_decayed_weights.py +++ b/torchopt/transform/add_decayed_weights.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -32,19 +32,21 @@ # ============================================================================== """Preset transformations for adding weight decay to updates.""" -from typing import Any, Callable, NamedTuple, Optional, Union +from __future__ import annotations + +from typing import Any, Callable, NamedTuple from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation, identity -from torchopt.transform.utils import tree_map_flat -from torchopt.typing import Params +from torchopt.transform.utils import tree_map_flat, tree_map_flat_ +from torchopt.typing import OptState, Params, Updates __all__ = ['masked', 'add_decayed_weights'] class MaskedState(NamedTuple): - """Maintains inner transform state for masked transformations.""" + """Maintain inner transform state for masked transformations.""" inner_state: Any @@ -59,7 +61,7 @@ class MaskedNode(NamedTuple): def masked( inner: GradientTransformation, - mask: Union[Any, Callable[[Params], Any]], + mask: OptState | Callable[[Params], OptState] | None = None, ) -> GradientTransformation: """Mask updates so only some are transformed, the rest are passed through. @@ -75,11 +77,12 @@ def masked( of :data:`True`. Args: - inner: Inner transformation to mask. - mask: A tree with same structure as (or a prefix of) the params tree, or a Callable that - returns such a tree given the params/updates. The leaves should be booleans, :data:`True` - for leaves/subtrees you want to apply the transformation to, and :data:`False` for those - you want to skip. The mask must be static for the gradient transformation to be jit-compilable. + inner (GradientTransformation): Inner transformation to mask. + mask (tree of Tensor, callable, or None, optional): A tree with same structure as (or a + prefix of) the params tree, or a function that returns such a tree given the + params/updates. The leaves should be booleans, :data:`True` for leaves/subtrees you want + to apply the transformation to, and :data:`False` for those you want to skip. + (default: :data:`None`) Returns: A :class:`GradientTransformation` wrapping ``inner``. @@ -89,18 +92,17 @@ def masked( def _masked_flat( inner: GradientTransformation, - mask: Union[Any, Callable[[Params], Any]], + mask: OptState | Callable[[Params], OptState] | None = None, ) -> GradientTransformation: return _masked(inner, mask, already_flattened=True) def _masked( inner: GradientTransformation, - mask: Union[Any, Callable[[Params], Any]], + mask: OptState | Callable[[Params], OptState] | None = None, *, already_flattened: bool = False, ) -> GradientTransformation: - if already_flattened: tree_map = tree_map_flat else: @@ -109,12 +111,18 @@ def _masked( def tree_mask(params, mask_tree): return tree_map(lambda p, m: p if m else MaskedNode(), params, mask_tree) - def init_fn(params): + def init_fn(params: Params) -> OptState: mask_tree = mask(params) if callable(mask) else mask masked_params = tree_mask(params, mask_tree) return MaskedState(inner_state=inner.init(masked_params)) - def update_fn(updates, state, params=None, inplace=True): # pylint: disable=unused-argument + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, + inplace: bool = True, + ) -> tuple[Updates, OptState]: mask_tree = mask(updates) if callable(mask) else mask masked_updates = tree_mask(updates, mask_tree) masked_params = None if params is None else tree_mask(params, mask_tree) @@ -124,7 +132,7 @@ def update_fn(updates, state, params=None, inplace=True): # pylint: disable=unu ) new_updates = tree_map( - lambda new_u, old_u, m: new_u if m else old_u, new_masked_updates, updates, mask_tree + lambda old_u, new_u, m: new_u if m else old_u, updates, new_masked_updates, mask_tree ) return new_updates, MaskedState(inner_state=new_inner_state) @@ -140,16 +148,17 @@ def update_fn(updates, state, params=None, inplace=True): # pylint: disable=unu def add_decayed_weights( weight_decay: float = 0.0, - mask: Optional[Union[Any, Callable[[Params], Any]]] = None, + mask: OptState | Callable[[Params], OptState] | None = None, ) -> GradientTransformation: """Add parameter scaled by `weight_decay`. Args: - weight_decay: a scalar weight decay rate. - mask: a tree with same structure as (or a prefix of) the params tree, or a Callable that - returns such a pytree given the params/updates. The leaves should be booleans, - :data:`True` for leaves/subtrees you want to apply the transformation to, and - :data:`False` for those you want to skip. + weight_decay (float, optional): A scalar weight decay rate. (default: :const:`0.0`) + mask (tree of Tensor, callable, or None, optional): A tree with same structure as (or a + prefix of) the params tree, or a function that returns such a tree given the + params/updates. The leaves should be booleans, :data:`True` for leaves/subtrees you want + to apply the transformation to, and :data:`False` for those you want to skip. + (default: :data:`None`) Returns: An (init_fn, update_fn) tuple. @@ -163,7 +172,7 @@ def add_decayed_weights( def _add_decayed_weights_flat( weight_decay: float = 0.0, - mask: Optional[Union[Any, Callable[[Params], Any]]] = None, + mask: OptState | Callable[[Params], OptState] | None = None, ) -> GradientTransformation: return _add_decayed_weights( weight_decay=weight_decay, @@ -174,11 +183,12 @@ def _add_decayed_weights_flat( def _add_decayed_weights( weight_decay: float = 0.0, - mask: Optional[Union[Any, Callable[[Params], Any]]] = None, + mask: OptState | Callable[[Params], OptState] | None = None, *, already_flattened: bool = False, ) -> GradientTransformation: - if not 0.0 <= weight_decay: # pylint: disable=unneeded-not + # pylint: disable-next=unneeded-not + if not 0.0 <= weight_decay: # pragma: no cover raise ValueError(f'Invalid weight_decay value: {weight_decay}') if weight_decay == 0.0 and mask is None: @@ -186,13 +196,21 @@ def _add_decayed_weights( if already_flattened: tree_map = tree_map_flat + tree_map_ = tree_map_flat_ else: tree_map = pytree.tree_map # type: ignore[assignment] + tree_map_ = pytree.tree_map_ # type: ignore[assignment] - def init_fn(params): # pylint: disable=unused-argument + def init_fn(params: Params) -> OptState: # pylint: disable=unused-argument return AddDecayedWeightsState() - def update_fn(updates, state, params=None, inplace=True): # pylint: disable=unused-argument + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, + inplace: bool = True, + ) -> tuple[Updates, OptState]: assert params is not None, ( 'Parameters are required for weight decay. ' 'Call `update(updates, state, params=params)` instead.' @@ -205,12 +223,15 @@ def f(g, p): return g.add_(p, alpha=weight_decay) return g.add_(p.data, alpha=weight_decay) + updates = tree_map_(f, updates, params) + else: def f(g, p): return g.add(p, alpha=weight_decay) - updates = tree_map(f, updates, params) + updates = tree_map(f, updates, params) + return updates, state # If mask is not `None`, apply mask to the gradient transformation. diff --git a/torchopt/transform/nan_to_num.py b/torchopt/transform/nan_to_num.py index 11890c1b..804f8219 100644 --- a/torchopt/transform/nan_to_num.py +++ b/torchopt/transform/nan_to_num.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,25 +14,34 @@ # ============================================================================== """Preset transformations that replaces updates with non-finite values to the given numbers.""" -from typing import Optional +from __future__ import annotations from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation +from torchopt.typing import OptState, Params, Updates def nan_to_num( - nan: float = 0.0, posinf: Optional[float] = None, neginf: Optional[float] = None + nan: float = 0.0, + posinf: float | None = None, + neginf: float | None = None, ) -> GradientTransformation: - """Replaces updates with values ``nan`` / ``+inf`` / ``-inf`` to the given numbers. + """Replace updates with values ``nan`` / ``+inf`` / ``-inf`` to the given numbers. Returns: An ``(init_fn, update_fn)`` tuple. """ - def init_fn(params): # pylint: disable=unused-argument + def init_fn(params: Params) -> OptState: # pylint: disable=unused-argument return EmptyState() - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: if inplace: def f(g): diff --git a/torchopt/transform/scale.py b/torchopt/transform/scale.py index 828b4b2f..639c903e 100644 --- a/torchopt/transform/scale.py +++ b/torchopt/transform/scale.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,9 +31,12 @@ # ============================================================================== """Preset transformation for scaling updates by learning rate.""" +from __future__ import annotations + from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation -from torchopt.transform.utils import tree_map_flat +from torchopt.transform.utils import tree_map_flat, tree_map_flat_ +from torchopt.typing import OptState, Params, Updates __all__ = ['scale'] @@ -46,7 +49,7 @@ def scale(step_size: float) -> GradientTransformation: """Scale updates by some fixed scalar ``step_size``. Args: - step_size: A scalar corresponding to a fixed scaling factor for updates. + step_size (float): A scalar corresponding to a fixed scaling factor for updates. Returns: An ``(init_fn, update_fn)`` tuple. @@ -58,27 +61,42 @@ def _scale_flat(step_size: float) -> GradientTransformation: return _scale(step_size=step_size, already_flattened=True) -def _scale(step_size: float, *, already_flattened: bool = False) -> GradientTransformation: +def _scale( + step_size: float, + *, + already_flattened: bool = False, +) -> GradientTransformation: if already_flattened: tree_map = tree_map_flat + tree_map_ = tree_map_flat_ else: tree_map = pytree.tree_map # type: ignore[assignment] + tree_map_ = pytree.tree_map_ # type: ignore[assignment] - def init_fn(params): # pylint: disable=unused-argument + def init_fn(params: Params) -> OptState: # pylint: disable=unused-argument return ScaleState() - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: if inplace: def f(g): return g.mul_(step_size) + updates = tree_map_(f, updates) + else: def f(g): return g.mul(step_size) - updates = tree_map(f, updates) + updates = tree_map(f, updates) + return updates, state return GradientTransformation(init_fn, update_fn) diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py index f0065712..36f30be9 100644 --- a/torchopt/transform/scale_by_adam.py +++ b/torchopt/transform/scale_by_adam.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,6 +33,8 @@ # pylint: disable=invalid-name +from __future__ import annotations + from typing import NamedTuple import torch @@ -41,13 +43,13 @@ from torchopt.accelerated_op import AdamOp from torchopt.base import GradientTransformation from torchopt.transform.utils import inc_count, tree_map_flat, update_moment -from torchopt.typing import SequenceOfTensors, Updates +from torchopt.typing import OptState, Params, Updates __all__ = ['scale_by_adam', 'scale_by_accelerated_adam'] -TRIPLE_PYTREE_SPEC = pytree.tree_structure((0, 1, 2)) # type: ignore[arg-type] +TRIPLE_PYTREE_SPEC = pytree.tree_structure((0, 1, 2), none_is_leaf=True) # type: ignore[arg-type] class ScaleByAdamState(NamedTuple): @@ -55,14 +57,20 @@ class ScaleByAdamState(NamedTuple): mu: Updates nu: Updates - count: SequenceOfTensors # type: ignore + count: OptState -def _bias_correction(moment, decay, count, *, already_flattened=False): +def _bias_correction( + moment: Updates, + decay: float, + count: OptState, + *, + already_flattened: bool = False, +) -> Updates: """Perform bias correction. This becomes a no-op as count goes to infinity.""" def f(t, c): # pylint: disable=invalid-name - return t.div(1 - decay**c) + return t.div(1 - pow(decay, c)) if already_flattened: return tree_map_flat(f, moment, count) @@ -82,17 +90,17 @@ def scale_by_adam( [Kingma et al, 2014](https://arxiv.org/abs/1412.6980) Args: - b1: (default: :const:`0.9`) - Decay rate for the exponentially weighted average of grads. - b2: (default: :const:`0.999`) - Decay rate for the exponentially weighted average of squared grads. - eps: (default: :const:`1e-8`) - Term added to the denominator to improve numerical stability. - eps_root: (default: :const:`0.0`) - Term added to the denominator inside the square-root to improve + b1 (float, optional): Decay rate for the exponentially weighted average of grads. + (default: :const:`0.9`) + b2 (float, optional): Decay rate for the exponentially weighted average of squared grads. + (default: :const:`0.999`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-8`) + eps_root (float, optional): Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling. - moment_requires_grad: (default: :data:`False`) - If :data:`True`, states will be created with flag `requires_grad = True`. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag + ``requires_grad = True``. (default: :data:`False`) Returns: An (init_fn, update_fn) tuple. @@ -134,11 +142,11 @@ def _scale_by_adam( already_flattened: bool = False, ) -> GradientTransformation: # pylint: disable=unneeded-not - if not 0.0 <= eps: + if not 0.0 <= eps: # pragma: no cover raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= b1 < 1.0: + if not 0.0 <= b1 < 1.0: # pragma: no cover raise ValueError(f'Invalid beta parameter at index 0: {b1}') - if not 0.0 <= b2 < 1.0: + if not 0.0 <= b2 < 1.0: # pragma: no cover raise ValueError(f'Invalid beta parameter at index 1: {b2}') # pylint: enable=unneeded-not @@ -147,7 +155,7 @@ def _scale_by_adam( else: tree_map = pytree.tree_map # type: ignore[assignment] - def init_fn(params): + def init_fn(params: Params) -> OptState: zero = tree_map( # count init lambda t: torch.zeros(1, dtype=torch.int64, device=t.device).squeeze_(), params ) @@ -159,7 +167,13 @@ def init_fn(params): ) return ScaleByAdamState(mu=mu, nu=nu, count=zero) - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: mu = update_moment.impl( # type: ignore[attr-defined] updates, state.mu, b1, order=1, inplace=inplace, already_flattened=already_flattened ) @@ -206,17 +220,17 @@ def scale_by_accelerated_adam( [Kingma et al, 2014](https://arxiv.org/abs/1412.6980) Args: - b1: (default: :const:`0.9`) - Decay rate for the exponentially weighted average of grads. - b2: (default: :const:`0.999`) - Decay rate for the exponentially weighted average of squared grads. - eps: (default: :const:`1e-8`) - Term added to the denominator to improve numerical stability. - eps_root: (default: :const:`0.0`) - Term added to the denominator inside the square-root to improve + b1 (float, optional): Decay rate for the exponentially weighted average of grads. + (default: :const:`0.9`) + b2 (float, optional): Decay rate for the exponentially weighted average of squared grads. + (default: :const:`0.999`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-8`) + eps_root (float, optional): Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling. - moment_requires_grad: (default: :data:`False`) - If :data:`True`, states will be created with flag `requires_grad = True`. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag + ``requires_grad = True``. (default: :data:`False`) Returns: An (init_fn, update_fn) tuple. @@ -258,19 +272,24 @@ def _scale_by_accelerated_adam( already_flattened: bool = False, ) -> GradientTransformation: # pylint: disable=unneeded-not - if not 0.0 <= eps: + if not 0.0 <= eps: # pragma: no cover raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= b1 < 1.0: + if not 0.0 <= b1 < 1.0: # pragma: no cover raise ValueError(f'Invalid beta parameter at index 0: {b1}') - if not 0.0 <= b2 < 1.0: + if not 0.0 <= b2 < 1.0: # pragma: no cover raise ValueError(f'Invalid beta parameter at index 1: {b2}') # pylint: enable=unneeded-not if already_flattened: tree_map = tree_map_flat - # pylint: disable-next=unused-argument - def update_fn(updates, state, *, params=None, inplace=True): + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: count_inc = inc_count.impl(updates, state.count, already_flattened=True) # type: ignore[attr-defined] op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace) @@ -282,11 +301,16 @@ def update_fn(updates, state, *, params=None, inplace=True): else: tree_map = pytree.tree_map # type: ignore[assignment] - # pylint: disable-next=unused-argument - def update_fn(updates, state, *, params=None, inplace=True): + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: count_inc = inc_count.impl(updates, state.count, already_flattened=False) # type: ignore[attr-defined] - treespec = pytree.tree_structure(updates) + treespec = pytree.tree_structure(updates, none_is_leaf=True) op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace) out = pytree.tree_map(op, state.mu, state.nu, updates, count_inc) @@ -297,7 +321,7 @@ def update_fn(updates, state, *, params=None, inplace=True): new_mu, new_nu, new_updates = pytree.tree_transpose(treespec, TRIPLE_PYTREE_SPEC, out) # type: ignore[misc] return new_updates, ScaleByAdamState(mu=new_mu, nu=new_nu, count=count_inc) - def init_fn(params): + def init_fn(params: Params) -> OptState: zero = tree_map( # count init lambda t: torch.zeros(1, dtype=torch.int64, device=t.device).squeeze_(), params ) diff --git a/torchopt/transform/scale_by_rms.py b/torchopt/transform/scale_by_rms.py index 3451fafe..7a0c8c20 100644 --- a/torchopt/transform/scale_by_rms.py +++ b/torchopt/transform/scale_by_rms.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,14 +31,16 @@ # ============================================================================== """Preset transformations for scaling updates by exponential root mean-squared (RMS).""" +from __future__ import annotations + from typing import NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation -from torchopt.transform.utils import tree_map_flat, update_moment -from torchopt.typing import Updates +from torchopt.transform.utils import tree_map_flat, tree_map_flat_, update_moment +from torchopt.typing import OptState, Params, Updates __all__ = ['scale_by_rms'] @@ -51,7 +53,9 @@ class ScaleByRmsState(NamedTuple): def scale_by_rms( - alpha: float = 0.9, eps: float = 1e-8, initial_scale: float = 0.0 + alpha: float = 0.9, + eps: float = 1e-8, + initial_scale: float = 0.0, ) -> GradientTransformation: """Rescale updates by the root of the exp. moving avg of the square. @@ -59,12 +63,11 @@ def scale_by_rms( [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) Args: - alpha: (default: :const:`0.9`) - Decay rate for the exponentially weighted average of squared grads. - eps: (default: :const:`1e-8`) - Term added to the denominator to improve numerical stability. - initial_scale: (default: :const:`0.0`) - Initial value for second moment + alpha (float, optional): Decay rate for the exponentially weighted average of squared grads. + (default: :const:`0.9`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-8`) + initial_scale (float, optional): Initial value for second moment. (default: :const:`0.0`) Returns: An (init_fn, update_fn) tuple. @@ -78,7 +81,9 @@ def scale_by_rms( def _scale_by_rms_flat( - alpha: float = 0.9, eps: float = 1e-8, initial_scale: float = 0.0 + alpha: float = 0.9, + eps: float = 1e-8, + initial_scale: float = 0.0, ) -> GradientTransformation: return _scale_by_rms( alpha=alpha, @@ -96,22 +101,30 @@ def _scale_by_rms( already_flattened: bool = False, ) -> GradientTransformation: # pylint: disable=unneeded-not - if not 0.0 <= alpha: + if not 0.0 <= alpha: # pragma: no cover raise ValueError(f'Invalid alpha value: {alpha}') - if not 0.0 <= eps: + if not 0.0 <= eps: # pragma: no cover raise ValueError(f'Invalid epsilon value: {eps}') # pylint: enable=unneeded-not if already_flattened: tree_map = tree_map_flat + tree_map_ = tree_map_flat_ else: tree_map = pytree.tree_map # type: ignore[assignment] + tree_map_ = pytree.tree_map_ # type: ignore[assignment] - def init_fn(params): + def init_fn(params: Params) -> OptState: nu = tree_map(lambda n: torch.full_like(n, initial_scale), params) # second moment return ScaleByRmsState(nu=nu) - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: nu = update_moment.impl( # type: ignore[attr-defined] updates, state.nu, alpha, order=2, inplace=inplace, already_flattened=already_flattened ) @@ -121,12 +134,15 @@ def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable= def f(g, n): # pylint: disable=invalid-name return g.div_(n.sqrt().add_(eps)) + updates = tree_map_(f, updates, nu) + else: def f(g, n): # pylint: disable=invalid-name return g.div(n.sqrt().add(eps)) - updates = tree_map(f, updates, nu) + updates = tree_map(f, updates, nu) + return updates, ScaleByRmsState(nu=nu) return GradientTransformation(init_fn, update_fn) diff --git a/torchopt/transform/scale_by_schedule.py b/torchopt/transform/scale_by_schedule.py index 49b6abb7..d6e3b0fa 100644 --- a/torchopt/transform/scale_by_schedule.py +++ b/torchopt/transform/scale_by_schedule.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,21 +31,23 @@ # ============================================================================== """Preset transformation for scaling updates by learning rate schedules.""" +from __future__ import annotations + from typing import NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation -from torchopt.transform.utils import inc_count, tree_map_flat -from torchopt.typing import Schedule, SequenceOfTensors +from torchopt.transform.utils import inc_count, tree_map_flat, tree_map_flat_ +from torchopt.typing import OptState, Params, Schedule, SequenceOfTensors, Updates __all__ = ['scale_by_schedule'] class ScaleByScheduleState(NamedTuple): - """Maintains count for scale scheduling.""" + """Maintain count for scale scheduling.""" count: SequenceOfTensors # type: ignore @@ -54,9 +56,8 @@ def scale_by_schedule(step_size_fn: Schedule) -> GradientTransformation: """Scale updates using a custom schedule for the ``step_size``. Args: - step_size_fn: - A function that takes an update count as input and proposes the ``step_size`` to - multiply the updates by. + step_size_fn (callable): A function that takes an update count as input and proposes the + ``step_size`` to multiply the updates by. Returns: An ``(init_fn, update_fn)`` tuple. @@ -69,33 +70,46 @@ def _scale_by_schedule_flat(step_size_fn: Schedule) -> GradientTransformation: def _scale_by_schedule( - step_size_fn: Schedule, *, already_flattened: bool = False + step_size_fn: Schedule, + *, + already_flattened: bool = False, ) -> GradientTransformation: if already_flattened: tree_map = tree_map_flat + tree_map_ = tree_map_flat_ else: tree_map = pytree.tree_map # type: ignore[assignment] + tree_map_ = pytree.tree_map_ # type: ignore[assignment] - def init_fn(params): + def init_fn(params: Params) -> OptState: zero = tree_map( # count init lambda t: torch.zeros(1, dtype=torch.int64, device=t.device).squeeze_(), params ) return ScaleByScheduleState(count=zero) - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: if inplace: def f(g, c): # pylint: disable=invalid-name step_size = step_size_fn(c) return g.mul_(step_size) + updates = tree_map_(f, updates, state.count) + else: def f(g, c): # pylint: disable=invalid-name step_size = step_size_fn(c) return g.mul(step_size) - updates = tree_map(f, updates, state.count) + updates = tree_map(f, updates, state.count) + return ( updates, ScaleByScheduleState( diff --git a/torchopt/transform/scale_by_stddev.py b/torchopt/transform/scale_by_stddev.py index 37138566..228ed707 100644 --- a/torchopt/transform/scale_by_stddev.py +++ b/torchopt/transform/scale_by_stddev.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,14 +33,16 @@ # pylint: disable=invalid-name +from __future__ import annotations + from typing import NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation -from torchopt.transform.utils import tree_map_flat, update_moment -from torchopt.typing import Updates +from torchopt.transform.utils import tree_map_flat, tree_map_flat_, update_moment +from torchopt.typing import OptState, Params, Updates __all__ = ['scale_by_stddev'] @@ -54,7 +56,9 @@ class ScaleByRStdDevState(NamedTuple): def scale_by_stddev( - alpha: float = 0.9, eps: float = 1e-8, initial_scale: float = 0.0 + alpha: float = 0.9, + eps: float = 1e-8, + initial_scale: float = 0.0, ) -> GradientTransformation: """Rescale updates by the root of the centered exponential moving average of squares. @@ -62,12 +66,11 @@ def scale_by_stddev( [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) Args: - alpha: (default: :const:`0.9`) - Decay rate for the exponentially weighted average of squared grads. - eps: (default: :const:`1e-8`) - Term added to the denominator to improve numerical stability. - initial_scale: (default: :const:`0.0`) - Initial value for second moment + alpha (float, optional): Decay rate for the exponentially weighted average of squared grads. + (default: :const:`0.9`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-8`) + initial_scale (float, optional): Initial value for second moment. (default: :const:`0.0`) Returns: An (init_fn, update_fn) tuple. @@ -81,7 +84,9 @@ def scale_by_stddev( def _scale_by_stddev_flat( - alpha: float = 0.9, eps: float = 1e-8, initial_scale: float = 0.0 + alpha: float = 0.9, + eps: float = 1e-8, + initial_scale: float = 0.0, ) -> GradientTransformation: return _scale_by_stddev( alpha=alpha, @@ -99,23 +104,31 @@ def _scale_by_stddev( already_flattened: bool = False, ) -> GradientTransformation: # pylint: disable=unneeded-not - if not 0.0 <= alpha: + if not 0.0 <= alpha: # pragma: no cover raise ValueError(f'Invalid alpha value: {alpha}') - if not 0.0 <= eps: + if not 0.0 <= eps: # pragma: no cover raise ValueError(f'Invalid epsilon value: {eps}') # pylint: enable=unneeded-not if already_flattened: tree_map = tree_map_flat + tree_map_ = tree_map_flat_ else: tree_map = pytree.tree_map # type: ignore[assignment] + tree_map_ = pytree.tree_map_ # type: ignore[assignment] - def init_fn(params): + def init_fn(params: Params) -> OptState: mu = tree_map(torch.zeros_like, params) # first moment nu = tree_map(lambda n: torch.full_like(n, initial_scale), params) # second moment return ScaleByRStdDevState(mu=mu, nu=nu) - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: mu = update_moment.impl( # type: ignore[attr-defined] updates, state.mu, alpha, order=1, inplace=inplace, already_flattened=already_flattened ) @@ -128,12 +141,15 @@ def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable= def f(g, m, n): return g.div_(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) + updates = tree_map_(f, updates, mu, nu) + else: def f(g, m, n): return g.div(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) - updates = tree_map(f, updates, mu, nu) + updates = tree_map(f, updates, mu, nu) + return updates, ScaleByRStdDevState(mu=mu, nu=nu) return GradientTransformation(init_fn, update_fn) diff --git a/torchopt/transform/trace.py b/torchopt/transform/trace.py index 1d741d04..03d2441d 100644 --- a/torchopt/transform/trace.py +++ b/torchopt/transform/trace.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,21 +33,23 @@ # pylint: disable=invalid-name +from __future__ import annotations + from typing import NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation, identity -from torchopt.transform.utils import tree_map_flat -from torchopt.typing import Params +from torchopt.transform.utils import tree_map_flat, tree_map_flat_ +from torchopt.typing import OptState, Params, Updates __all__ = ['trace'] class TraceState(NamedTuple): - """Holds an aggregation of past updates.""" + """Hold an aggregation of past updates.""" trace: Params @@ -65,14 +67,12 @@ def trace( Both are frequently found in the optimization literature. Args: - momentum: (default: :const:`0.9`) - The decay rate for the trace of past updates. - dampening: (default: :const:`0.0`) - Dampening for momentum. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - moment_requires_grad: (default: :data:`False`) - If :data:`True`, states will be created with flag `requires_grad = True`. + momentum (float, optional): The decay rate for the trace of past updates. + (default: :const:`0.9`) + dampening (float, optional): Dampening for momentum. (default: :const:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag + ``requires_grad = True``. (default: :data:`False`) Returns: An (init_fn, update_fn) tuple. @@ -110,9 +110,9 @@ def _trace( already_flattened: bool = False, ) -> GradientTransformation: # pylint: disable=unneeded-not - if not 0.0 <= momentum: + if not 0.0 <= momentum: # pragma: no cover raise ValueError(f'Invalid momentum value: {momentum}') - if nesterov and (momentum <= 0.0 or dampening != 0.0): + if nesterov and (momentum <= 0.0 or dampening != 0.0): # pragma: no cover raise ValueError('Nesterov momentum requires a momentum and zero dampening') # pylint: enable=unneeded-not @@ -121,10 +121,12 @@ def _trace( if already_flattened: tree_map = tree_map_flat + tree_map_ = tree_map_flat_ else: tree_map = pytree.tree_map # type: ignore[assignment] + tree_map_ = pytree.tree_map_ # type: ignore[assignment] - def init_fn(params): + def init_fn(params: Params) -> OptState: return TraceState( trace=tree_map( lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params @@ -133,7 +135,13 @@ def init_fn(params): first_call = True - def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: nonlocal first_call if nesterov: @@ -148,7 +156,8 @@ def f2(g, t): return g.add_(t, alpha=momentum) new_trace = tree_map(f1, updates, state.trace) - updates = tree_map(f2, updates, new_trace) + updates = tree_map_(f2, updates, new_trace) + else: def f1(g, t): @@ -161,19 +170,21 @@ def f2(g, t): new_trace = tree_map(f1, updates, state.trace) updates = tree_map(f2, updates, new_trace) + else: if inplace: def f(g, t): if first_call: - return t.add(g) + return t.add_(g) return t.mul_(momentum).add_(g, alpha=1.0 - dampening) def copy_(g, t): return g.copy_(t) new_trace = tree_map(f, updates, state.trace) - updates = tree_map(copy_, updates, new_trace) + updates = tree_map_(copy_, updates, new_trace) + else: def f(g, t): diff --git a/torchopt/transform/utils.py b/torchopt/transform/utils.py index 497df44e..77ba58ca 100644 --- a/torchopt/transform/utils.py +++ b/torchopt/transform/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,8 +31,10 @@ # ============================================================================== """Utilities for the preset transformations.""" +from __future__ import annotations + from collections import deque -from typing import Any, Callable, Iterable, List +from typing import Any, Callable, Sequence import torch @@ -46,7 +48,12 @@ INT64_MAX = torch.iinfo(torch.int64).max -def tree_map_flat(func: Callable, *flat_args: Any, none_is_leaf: bool = False) -> List[Any]: +def tree_map_flat( + func: Callable, + flat_arg: Sequence[Any], + *flat_args: Any, + none_is_leaf: bool = False, +) -> Sequence[Any]: """Apply a function to each element of a flattened list.""" if none_is_leaf: fn = func @@ -55,13 +62,16 @@ def tree_map_flat(func: Callable, *flat_args: Any, none_is_leaf: bool = False) - def fn(x, *xs): return func(x, *xs) if x is not None else None - return list(map(fn, *flat_args)) + return flat_arg.__class__(map(fn, flat_arg, *flat_args)) # type: ignore[call-arg] def tree_map_flat_( - func: Callable, flat_arg: Iterable[Any], *flat_args: Any, none_is_leaf: bool = False -) -> Iterable[Any]: - """Apply a function to each element of a flattened list.""" + func: Callable, + flat_arg: Sequence[Any], + *flat_args: Any, + none_is_leaf: bool = False, +) -> Sequence[Any]: + """Apply a function to each element of a flattened list and return the original list.""" if none_is_leaf: fn = func else: @@ -75,51 +85,93 @@ def fn(x, *xs): def inc_count(updates: Updates, count: TensorTree) -> TensorTree: - """Increments int counter by one. + """Increment int counter by one. Returns: A counter incremented by one, or :data:`INT64_MAX` if the maximum precision is reached. """ - return _inc_count(updates=updates, count=count, already_flattened=False) + return _inc_count( + updates=updates, + count=count, + already_flattened=False, + ) def _inc_count_flat(updates: Updates, count: TensorTree) -> TensorTree: - return _inc_count(updates=updates, count=count, already_flattened=True) + return _inc_count( + updates=updates, + count=count, + already_flattened=True, + ) def _inc_count( - updates: Updates, count: TensorTree, *, already_flattened: bool = False + updates: Updates, + count: TensorTree, + *, + already_flattened: bool = False, ) -> TensorTree: def f(c, g): # pylint: disable=invalid-name return c + (c != INT64_MAX).to(torch.int64) if g is not None else c if already_flattened: - return tree_map_flat(f, count, updates) - return pytree.tree_map(f, count, updates) + return tree_map_flat(f, count, updates, none_is_leaf=True) + return pytree.tree_map(f, count, updates, none_is_leaf=True) inc_count.flat = _inc_count_flat # type: ignore[attr-defined] inc_count.impl = _inc_count # type: ignore[attr-defined] -def update_moment(updates, moments, decay, *, order, inplace=True): +def update_moment( + updates: Updates, + moments: TensorTree, + decay: float, + *, + order: int, + inplace: bool = True, +) -> TensorTree: """Compute the exponential moving average of the ``order``-th moment.""" return _update_moment( - updates, moments, decay, order=order, inplace=inplace, already_flattened=False + updates, + moments, + decay, + order=order, + inplace=inplace, + already_flattened=False, ) -def _update_moment_flat(updates, moments, decay, *order, inplace=True): +def _update_moment_flat( + updates: Updates, + moments: TensorTree, + decay: float, + *, + order: int, + inplace: bool = True, +) -> TensorTree: return _update_moment( - updates, moments, decay, order=order, inplace=inplace, already_flattened=True + updates, + moments, + decay, + order=order, + inplace=inplace, + already_flattened=True, ) -def _update_moment(updates, moments, decay, *, order, inplace=True, already_flattened=False): +def _update_moment( + updates: Updates, + moments: TensorTree, + decay: float, + *, + order: int, + inplace: bool = True, + already_flattened=False, +) -> TensorTree: assert order in (1, 2) if inplace: - if order == 2: def f(g, t): @@ -131,7 +183,6 @@ def f(g, t): return t.mul_(decay).add_(g, alpha=1 - decay) if g is not None else t else: - if order == 2: def f(g, t): @@ -143,7 +194,7 @@ def f(g, t): return t.mul(decay).add_(g, alpha=1 - decay) if g is not None else t if already_flattened: - return tree_map_flat(f, updates, moments) + return tree_map_flat(f, updates, moments, none_is_leaf=True) return pytree.tree_map(f, updates, moments, none_is_leaf=True) diff --git a/torchopt/typing.py b/torchopt/typing.py index a7499a99..2075dc62 100644 --- a/torchopt/typing.py +++ b/torchopt/typing.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,8 @@ # ============================================================================== """Typing utilities.""" -from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, Union +import abc +from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union from typing_extensions import TypeAlias # Python 3.10+ from typing_extensions import Protocol, runtime_checkable # Python 3.8+ @@ -24,7 +25,6 @@ from torch import Tensor from torch.distributions import Distribution from torch.futures import Future -from torch.types import Device from torchopt.base import ( ChainedGradientTransformation, @@ -59,6 +59,8 @@ 'SequenceOfOptionalTensors', 'OptionalTensorOrOptionalTensors', 'OptionalTensorTree', + 'TensorContainer', + 'ModuleTensorContainers', 'Future', 'LinearSolver', 'Device', @@ -70,6 +72,8 @@ T = TypeVar('T') +Device: TypeAlias = Union[torch.device, str, int] + Scalar: TypeAlias = Union[float, int, bool] Numeric: TypeAlias = Union[Tensor, Scalar] @@ -90,17 +94,21 @@ OptionalTensorOrOptionalTensors = Union[OptionalTensor, SequenceOfOptionalTensors] OptionalTensorTree: TypeAlias = PyTreeTypeVar('OptionalTensorTree', OptionalTensor) # type: ignore[valid-type] +TensorContainer = Dict[str, Optional[Tensor]] +ModuleTensorContainers = Tuple[TensorContainer, ...] + # Parameters are arbitrary nests of `torch.Tensor`. Params: TypeAlias = TensorTree Updates: TypeAlias = Params # Gradient updates are of the same type as parameters. OptState: TypeAlias = TensorTree # States are arbitrary nests of `torch.Tensor`. -if rpc.is_available(): +if rpc.is_available(): # pragma: no cover from torch.distributed.rpc import RRef # pylint: disable=ungrouped-imports,unused-import __all__.extend(['RRef']) -else: - RRef = None # type: ignore[misc,assignment] # pylint: disable=invalid-name +else: # pragma: no cover + # pylint: disable-next=invalid-name + RRef = None # type: ignore[misc,assignment] # solver(matvec, b) -> solution LinearSolver: TypeAlias = Callable[[Callable[[TensorTree], TensorTree], TensorTree], TensorTree] @@ -116,12 +124,13 @@ class Samplable(Protocol): # pylint: disable=too-few-public-methods """Abstract protocol class that supports sampling.""" + @abc.abstractmethod def sample( self, sample_shape: Size = Size() # pylint: disable=unused-argument ) -> Union[Tensor, Sequence[Numeric]]: # pylint: disable-next=line-too-long - """Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.""" - raise NotImplementedError + """Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.""" + raise NotImplementedError # pragma: no cover Samplable.register(Distribution) diff --git a/torchopt/update.py b/torchopt/update.py index 85e93673..9485896b 100644 --- a/torchopt/update.py +++ b/torchopt/update.py @@ -39,7 +39,7 @@ def apply_updates(params: Params, updates: Updates, *, inplace: bool = True) -> Params: - """Applies an update to the corresponding parameters. + """Apply an update to the corresponding parameters. This is a utility functions that applies an update to a set of parameters, and then returns the updated parameters to the caller. As an example, the update may be a gradient transformed by a @@ -48,11 +48,11 @@ def apply_updates(params: Params, updates: Updates, *, inplace: bool = True) -> :func:`tree_map` (e.g. if you want to manipulate updates in custom ways before applying them). Args: - params: A tree of parameters. - updates: - A tree of updates, the tree structure and the shape of the leaf nodes must match that - of ``params``. - inplace: If :data:`True`, will update params in a inplace manner. + params (tree of Tensor): A tree of parameters. + updates (tree of Tensor): A tree of updates, the tree structure and the shape of the leaf + nodes must match that of ``params``. + inplace (bool, optional): If :data:`True`, will update params in a inplace manner. + (default: :data:`True`) Returns: Updated parameters, with same structure, shape and type as ``params``. diff --git a/torchopt/utils.py b/torchopt/utils.py index f60bc6d6..12adb214 100644 --- a/torchopt/utils.py +++ b/torchopt/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,21 +14,11 @@ # ============================================================================== """Utilities for TorchOpt.""" +from __future__ import annotations + import copy import itertools -from typing import ( - TYPE_CHECKING, - Dict, - List, - NamedTuple, - Optional, - Sequence, - Set, - Tuple, - Union, - cast, - overload, -) +from typing import TYPE_CHECKING, NamedTuple, Sequence, cast, overload from typing_extensions import Literal # Python 3.8+ from typing_extensions import TypeAlias # Python 3.10+ @@ -36,10 +26,10 @@ import torch.nn as nn from torchopt import pytree -from torchopt.typing import Device, OptState, TensorTree +from torchopt.typing import Device, ModuleTensorContainers, OptState, TensorContainer, TensorTree -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from torchopt.optim.meta.base import MetaOptimizer @@ -56,32 +46,30 @@ class ModuleState(NamedTuple): """Container for module state.""" - params: Tuple[Dict[str, torch.Tensor], ...] - buffers: Tuple[Dict[str, torch.Tensor], ...] - visual_contents: Optional[Dict] = None + params: tuple[dict[str, torch.Tensor], ...] + buffers: tuple[dict[str, torch.Tensor], ...] + visual_contents: dict | None = None detach_buffers: bool = False CopyMode: TypeAlias = Literal['reference', 'copy', 'deepcopy', 'ref', 'clone', 'deepclone'] -def stop_gradient(target: Union[TensorTree, ModuleState, nn.Module, 'MetaOptimizer']) -> None: +def stop_gradient(target: ModuleState | nn.Module | MetaOptimizer | TensorTree) -> None: """Stop the gradient for the input object. - Since a tensor use :attr:`grad_fn` to connect itself with the previous computation graph, the + Since a tensor use ``grad_fn`` to connect itself with the previous computation graph, the backpropagated gradient will flow over the tensor and continue flow to the tensors that is - connected by :attr:`grad_fn`. Some algorithms requires manually detaching tensors from the + connected by ``grad_fn``. Some algorithms requires manually detaching tensors from the computation graph. Note that the :func:`stop_gradient` operation is in-place. Args: - target: The target that to be detached from the computation graph, it could be a - :class:`nn.Module`, :class:`torchopt.MetaOptimizer`, state of the - :class:`torchopt.MetaOptimizer`, or just a plain list of tensors. - inplace: If :data:`True`, the target will be detached in-place. if :data:`Frue`, this - function will return a detached copy of the target. The in-place operation is fast and - memory efficient but may raise backpropagation error. + target (ModuleState, nn.Module, MetaOptimizer, or tree of Tensor): The target that to be + detached from the computation graph, it could be a :class:`nn.Module`, + :class:`torchopt.MetaOptimizer`, state of the :class:`torchopt.MetaOptimizer`, or just + a plain list of tensors. """ # pylint: disable-next=import-outside-toplevel from torchopt.optim.meta.base import MetaOptimizer @@ -108,67 +96,72 @@ def extract_state_dict( target: nn.Module, *, by: CopyMode = 'reference', - device: Device = None, + device: Device | None = None, with_buffers: bool = True, enable_visual: bool = False, visual_prefix: str = '', -) -> ModuleState: +) -> ModuleState: # pragma: no cover ... @overload def extract_state_dict( - target: 'MetaOptimizer', + target: MetaOptimizer, *, by: CopyMode = 'reference', - device: Device = None, + device: Device | None = None, with_buffers: bool = True, enable_visual: bool = False, visual_prefix: str = '', -) -> Tuple[OptState, ...]: +) -> tuple[OptState, ...]: # pragma: no cover ... # pylint: disable-next=too-many-branches,too-many-locals def extract_state_dict( - target: Union[nn.Module, 'MetaOptimizer'], + target: nn.Module | MetaOptimizer, *, by: CopyMode = 'reference', - device: Device = None, + device: Device | None = None, with_buffers: bool = True, detach_buffers: bool = False, enable_visual: bool = False, visual_prefix: str = '', -) -> Union[ModuleState, Tuple[OptState, ...]]: +) -> ModuleState | tuple[OptState, ...]: """Extract target state. - Since a tensor use :attr:`grad_fn` to connect itself with the previous computation graph, the + Since a tensor use ``grad_fn`` to connect itself with the previous computation graph, the backpropagated gradient will flow over the tensor and continue flow to the tensors that is - connected by :attr:`grad_fn`. Some algorithms requires manually detaching tensors from the + connected by ``grad_fn``. Some algorithms requires manually detaching tensors from the computation graph. Note that the extracted state is a reference, which means any in-place operator will affect the target that the state is extracted from. Args: - target: It could be a :class:`nn.Module` or :class:`torchopt.MetaOptimizer`. - by: The extract policy of tensors in the target. + target (nn.Module or MetaOptimizer): It could be a :class:`nn.Module` or + :class:`torchopt.MetaOptimizer`. + by (str, optional): The extract policy of tensors in the target. (default: :const:`'reference'`) - :const:`'reference'`: The extracted tensors will be references to the original tensors. - :const:`'copy'`: The extracted tensors will be clones of the original tensors. This - makes the copied tensors have :attr:`grad_fn` to be a ```` function - points to the original tensors. + makes the copied tensors have ``grad_fn`` to be a ```` function points + to the original tensors. - :const:`'deepcopy'`: The extracted tensors will be deep-copied from the original tensors. The deep-copied tensors will detach from the original computation graph. - device: If specified, move the extracted state to the specified device. - with_buffers: Extract buffer together with parameters, this argument is only used if the - input target is :class:`nn.Module`. - detach_buffers: Whether to detach the reference to the buffers, this argument is only used - if the input target is :class:`nn.Module` and ``by='reference'``. - enable_visual: Add additional annotations, which could be used in computation graph - visualization. Currently, this flag only has effect on :class:`nn.Module` but we will - support :class:`torchopt.MetaOptimizer` later. - visual_prefix: Prefix for the visualization annotations. + device (Device or None, optional): If specified, move the extracted state to the specified + device. (default: :const:`None`) + with_buffers (bool, optional): Extract buffer together with parameters, this argument is + only used if the input target is :class:`nn.Module`. (default: :const:`True`) + detach_buffers (bool, optional): Whether to detach the reference to the buffers, this + argument is only used if the input target is :class:`nn.Module` and ``by='reference'``. + (default: :const:`False`) + enable_visual (bool, optional): Add additional annotations, which could be used in + computation graph visualization. Currently, this flag only has effect on + :class:`nn.Module` but we will support :class:`torchopt.MetaOptimizer` later. + (default: :const:`False`) + visual_prefix (str, optional): Prefix for the visualization annotations. + (default: :const:`''`) Returns: State extracted of the input object. @@ -191,10 +184,10 @@ def clone(t: torch.Tensor) -> torch.Tensor: def clone_detach_(t: torch.Tensor) -> torch.Tensor: if isinstance(t, nn.Parameter): - return nn.Parameter(t.clone().detach_(), requires_grad=t.requires_grad).to( - device=target_device + return nn.Parameter( + t.clone().to(device=target_device).detach_(), requires_grad=t.requires_grad ) - return t.clone().detach_().to(device=target_device).requires_grad_(t.requires_grad) + return t.clone().to(device=target_device).detach_().requires_grad_(t.requires_grad) else: @@ -228,9 +221,9 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor: else: visual_contents = None - params: List[Dict[str, torch.Tensor]] = [] - buffers: List[Dict[str, torch.Tensor]] = [] - memo: Set[nn.Module] = set() + params: list[dict[str, torch.Tensor]] = [] + buffers: list[dict[str, torch.Tensor]] = [] + memo: set[nn.Module] = set() def update_params(container): if len(container) > 0: @@ -287,15 +280,12 @@ def get_variable(t): def extract_module_containers( module: nn.Module, with_buffers: bool = True -) -> Tuple[ - Tuple[Dict[str, Optional[torch.Tensor]], ...], - Tuple[Dict[str, Optional[torch.Tensor]], ...], -]: +) -> tuple[ModuleTensorContainers, ModuleTensorContainers]: """Extract the references to the containers of parameters and buffers from a module.""" if isinstance(module, nn.Module): - params: List[Dict[str, Optional[torch.Tensor]]] = [] - buffers: List[Dict[str, Optional[torch.Tensor]]] = [] - memo: Set[nn.Module] = set() + params: list[TensorContainer] = [] + buffers: list[TensorContainer] = [] + memo: set[nn.Module] = set() def update_container(container, items): if len(items) > 0: @@ -319,8 +309,8 @@ def update_container(container, items): def recover_state_dict( - target: Union[nn.Module, 'MetaOptimizer'], - state: Union[ModuleState, Sequence[OptState]], + target: nn.Module | MetaOptimizer, + state: ModuleState | Sequence[OptState], ) -> None: """Recover state. @@ -330,8 +320,8 @@ def recover_state_dict( modified. Args: - target: Target that need to recover. - state: The recovering state. + target (nn.Module or MetaOptimizer): Target that need to recover. + state (ModuleState or sequence of tree of Tensor): The recovering state. """ # pylint: disable-next=import-outside-toplevel from torchopt.optim.meta.base import MetaOptimizer @@ -347,10 +337,7 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor: return nn.Parameter(t.clone().detach_(), requires_grad=t.requires_grad) return t.clone().detach_().requires_grad_(t.requires_grad) - buffers = cast( - Tuple[Dict[str, torch.Tensor], ...], - pytree.tree_map(clone_detach_, buffers), # type: ignore[arg-type] - ) + buffers = pytree.tree_map(clone_detach_, buffers) # type: ignore[assignment,arg-type] for tgt, src in itertools.chain( zip(params_containers, params), @@ -370,19 +357,19 @@ def module_clone( *, by: CopyMode = 'reference', detach_buffers: bool = False, - device: Device = None, -) -> nn.Module: + device: Device | None = None, +) -> nn.Module: # pragma: no cover ... @overload def module_clone( - target: 'MetaOptimizer', + target: MetaOptimizer, *, by: CopyMode = 'reference', detach_buffers: bool = False, - device: Device = None, -) -> 'MetaOptimizer': + device: Device | None = None, +) -> MetaOptimizer: # pragma: no cover ... @@ -392,34 +379,36 @@ def module_clone( *, by: CopyMode = 'reference', detach_buffers: bool = False, - device: Device = None, -) -> TensorTree: + device: Device | None = None, +) -> TensorTree: # pragma: no cover ... # pylint: disable-next=too-many-locals def module_clone( - target: Union[nn.Module, 'MetaOptimizer', TensorTree], + target: nn.Module | MetaOptimizer | TensorTree, *, by: CopyMode = 'reference', detach_buffers: bool = False, - device: Device = None, -) -> Union[nn.Module, 'MetaOptimizer', TensorTree]: + device: Device | None = None, +) -> nn.Module | MetaOptimizer | TensorTree: """Clone a module. Args: - target: The target to be cloned. - by: The extract policy of tensors in the target. + target (nn.Module, MetaOptimizer, or tree of Tensor): The target to be cloned. + by (str, optional): The extract policy of tensors in the target. (default: :const:`'reference'`) - :const:`'reference'`: The extracted tensors will be references to the original tensors. - :const:`'copy'`: The extracted tensors will be clones of the original tensors. This - makes the copied tensors have :attr:`grad_fn` to be a ```` function - points to the original tensors. + makes the copied tensors have ``grad_fn`` to be a ```` function points + to the original tensors. - :const:`'deepcopy'`: The extracted tensors will be deep-copied from the original tensors. The deep-copied tensors will detach from the original computation graph. - detach_buffers: Whether to detach the reference to the buffers, this argument is only used - if the input target is :class:`nn.Module` and ``by='reference'``. - device: If specified, move the cloned module to the specified device. + detach_buffers (bool, optional): Whether to detach the reference to the buffers, this + argument is only used if the input target is :class:`nn.Module` and ``by='reference'``. + (default: :const:`False`) + device (Device or None, optional): If specified, move the cloned module to the specified + device. (default: :const:`None`) Returns: The cloned module. @@ -463,10 +452,10 @@ def clone(t: torch.Tensor) -> torch.Tensor: def clone_detach_(t: torch.Tensor) -> torch.Tensor: if isinstance(t, nn.Parameter): - return nn.Parameter(t.clone().detach_(), requires_grad=t.requires_grad).to( - device=target_device + return nn.Parameter( + t.clone().to(device=target_device).detach_(), requires_grad=t.requires_grad ) - return t.clone().detach_().to(device=target_device).requires_grad_(t.requires_grad) + return t.clone().to(device=target_device).detach_().requires_grad_(t.requires_grad) else: @@ -491,13 +480,34 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor: return pytree.tree_map(replicate, cast(TensorTree, target)) +@overload +def module_detach_(target: ModuleState) -> ModuleState: # pragma: no cover + ... + + +@overload +def module_detach_(target: nn.Module) -> nn.Module: # pragma: no cover + ... + + +@overload +def module_detach_(target: MetaOptimizer) -> MetaOptimizer: # pragma: no cover + ... + + +@overload +def module_detach_(target: TensorTree) -> TensorTree: # pragma: no cover + ... + + def module_detach_( - target: Union[TensorTree, ModuleState, nn.Module, 'MetaOptimizer'] -) -> Union[TensorTree, ModuleState, nn.Module, 'MetaOptimizer']: + target: ModuleState | nn.Module | MetaOptimizer | TensorTree, +) -> ModuleState | nn.Module | MetaOptimizer | TensorTree: """Detach a module from the computation graph. Args: - target: The target to be detached. + target (ModuleState, nn.Module, MetaOptimizer, or tree of Tensor): The + target to be detached. Returns: The detached module. diff --git a/torchopt/version.py b/torchopt/version.py index 6d66f945..b8136a22 100644 --- a/torchopt/version.py +++ b/torchopt/version.py @@ -14,7 +14,7 @@ # ============================================================================== """TorchOpt: a high-performance optimizer library built upon PyTorch.""" -__version__ = '0.6.0' +__version__ = '0.7.0' __license__ = 'Apache License, Version 2.0' __author__ = 'TorchOpt Contributors' __release__ = False diff --git a/torchopt/visual.py b/torchopt/visual.py index 25a66ada..7afe65a4 100644 --- a/torchopt/visual.py +++ b/torchopt/visual.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,13 +17,13 @@ # ============================================================================== """Computation graph visualization.""" -import warnings +from __future__ import annotations + from collections import namedtuple -from typing import Generator, Iterable, Mapping, Optional, Union, cast +from typing import Generator, Iterable, Mapping, cast import torch from graphviz import Digraph -from pkg_resources import parse_version from torchopt.typing import TensorOrTensors from torchopt.utils import ModuleState @@ -39,7 +39,7 @@ def get_fn_name(fn, show_attrs, max_attr_chars): - """Returns function name.""" + """Return function name.""" name = str(type(fn).__name__) if not show_attrs: return name @@ -73,25 +73,24 @@ def truncate(s): # pylint: disable=invalid-name # pylint: disable-next=too-many-branches,too-many-statements,too-many-locals def make_dot( var: TensorOrTensors, - params: Optional[ - Union[ - Mapping[str, torch.Tensor], - ModuleState, - Generator, - Iterable[Union[Mapping[str, torch.Tensor], ModuleState, Generator]], - ] - ] = None, + params: ( + Mapping[str, torch.Tensor] + | ModuleState + | Generator + | Iterable[Mapping[str, torch.Tensor] | ModuleState | Generator] + | None + ) = None, show_attrs: bool = False, show_saved: bool = False, max_attr_chars: int = 50, ) -> Digraph: - """Produces Graphviz representation of PyTorch autograd graph. + """Produce Graphviz representation of PyTorch autograd graph. If a node represents a backward function, it is gray. Otherwise, the node represents a tensor and is either blue, orange, or green: - **Blue** - Reachable leaf tensors that requires grad (tensors whose :attr:`grad` fields will be + Reachable leaf tensors that requires grad (tensors whose ``grad`` fields will be populated during :meth:`backward`). - **Orange** Saved tensors of custom autograd functions as well as those saved by built-in backward @@ -102,24 +101,17 @@ def make_dot( If any output is a view, we represent its base tensor with a dark green node. Args: - var: Output tensor. - params: ([dict of (name, tensor) or state_dict]) - Parameters to add names to node that requires grad. - show_attrs: Whether to display non-tensor attributes of backward nodes - (Requires PyTorch version >= 1.9) - show_saved: Whether to display saved tensor nodes that are not by custom autograd - functions. Saved tensor nodes for custom functions, if present, are always displayed. - (Requires PyTorch version >= 1.9) - max_attr_chars: If ``show_attrs`` is :data:`True`, sets max number of characters to display - for any given attribute. + var (Tensor or sequence of Tensor): Output tensor. + params: (dict[str, Tensor], ModuleState, iterable of tuple[str, Tensor], or None, optional): + Parameters to add names to node that requires grad. (default: :data:`None`) + show_attrs (bool, optional): Whether to display non-tensor attributes of backward nodes. + (default: :data:`False`) + show_saved (bool, optional): Whether to display saved tensor nodes that are not by custom + autograd functions. Saved tensor nodes for custom functions, if present, are always + displayed. (default: :data:`False`) + max_attr_chars (int, optional): If ``show_attrs`` is :data:`True`, sets max number of + characters to display for any given attribute. (default: :const:`50`) """ - if parse_version(torch.__version__) < parse_version('1.9') and (show_attrs or show_saved): - warnings.warn( - 'make_dot: showing grad_fn attributes and saved variables ' - 'requires PyTorch version >= 1.9. (This does NOT apply to ' - 'saved tensors saved by custom autograd functions.)' - ) - param_map = {} if params is not None: @@ -138,16 +130,16 @@ def make_dot( else: param_map.update({v: k for k, v in cast(Mapping, param).items()}) - node_attr = dict( - style='filled', - shape='box', - align='left', - fontsize='10', - ranksep='0.1', - height='0.2', - fontname='monospace', - ) - dot = Digraph(node_attr=node_attr, graph_attr=dict(size='12,12')) + node_attr = { + 'style': 'filled', + 'shape': 'box', + 'align': 'left', + 'fontsize': '10', + 'ranksep': '0.1', + 'height': '0.2', + 'fontname': 'monospace', + } + dot = Digraph(node_attr=node_attr, graph_attr={'size': '12,12'}) seen = set() def size_to_str(size): diff --git a/tutorials/1_Functional_Optimizer.ipynb b/tutorials/1_Functional_Optimizer.ipynb index 3d70eb62..07a8aeb8 100644 --- a/tutorials/1_Functional_Optimizer.ipynb +++ b/tutorials/1_Functional_Optimizer.ipynb @@ -18,7 +18,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In this tutorial, we will introduce how TorchOpt can be treated as functional optimizer to conduct normal optimization with functional programing style. We will also illustrate how to conduct differentiable optimization with functional programing in PyTorch." + "In this tutorial, we will introduce how TorchOpt can be treated as functional optimizer to conduct normal optimization with functional programming style. We will also illustrate how to conduct differentiable optimization with functional programming in PyTorch." ] }, { @@ -70,7 +70,7 @@ "source": [ "### 1.1 Original JAX implementation\n", "\n", - "The first example is JAX implementation coupled with [Optax](https://github.com/deepmind/optax), which belongs to functional programing style." + "The first example is JAX implementation coupled with [Optax](https://github.com/deepmind/optax), which belongs to functional programming style." ] }, { @@ -391,7 +391,7 @@ "source": [ "## 2. Differentiable Optimization with Functional Optimizer\n", "\n", - "Coupled with functional optimizer, you can conduct differentiable optimization by setting the `inplace` flag as `False` in update and `apply_updates` function. (which might be helpful for meta-learning algorithm implementation with functional programing style). \n", + "Coupled with functional optimizer, you can conduct differentiable optimization by setting the `inplace` flag as `False` in update and `apply_updates` function. (which might be helpful for meta-learning algorithm implementation with functional programming style). \n", "\n", "Note that `torchopt.SGD` and `torchopt.Adam` do not support differentiable optimization. Refer to the Meta-Optimizer notebook for PyTorch-like differentiable optimizers." ] diff --git a/tutorials/2_Visualization.ipynb b/tutorials/2_Visualization.ipynb index 3141f522..11c68bec 100644 --- a/tutorials/2_Visualization.ipynb +++ b/tutorials/2_Visualization.ipynb @@ -18,7 +18,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In [PyTorch](https://pytorch.org), if the attribute `requires_grad` a tensor is `True`, the computation graph will be created if we use the tensor to do any operations. The computation graph is implemented likes link-list -- `Tensor`s are nodes and they are linked by their attribute `gran_fn`. [PyTorchViz](https://github.com/szagoruyko/pytorchviz) is a Python package that uses [Graphviz](https://graphviz.org) as a backend for plotting computation graphs. TorchOpt use PyTorchViz as the blueprint and provide more easy-to-use visualization functions on the premise of supporting all its functions." + "In [PyTorch](https://pytorch.org), if the attribute `requires_grad` of a tensor is `True`, the computation graph will be created if we use the tensor to do any operations. The computation graph is implemented like link-list -- `Tensor`s are nodes and they are linked by their attribute `gran_fn`. [PyTorchViz](https://github.com/szagoruyko/pytorchviz) is a Python package that uses [Graphviz](https://graphviz.org) as a backend for plotting computation graphs. TorchOpt use PyTorchViz as the blueprint and provide more easy-to-use visualization functions on the premise of supporting all its functions." ] }, { diff --git a/tutorials/3_Meta_Optimizer.ipynb b/tutorials/3_Meta_Optimizer.ipynb index d50ace2d..4a09836c 100644 --- a/tutorials/3_Meta_Optimizer.ipynb +++ b/tutorials/3_Meta_Optimizer.ipynb @@ -112,7 +112,7 @@ "# Low-level API\n", "optim = torchopt.MetaOptimizer(net, torchopt.sgd(lr=1.0))\n", "\n", - "# High level API\n", + "# High-level API\n", "optim = torchopt.MetaSGD(net, lr=1.0)" ] }, @@ -274,7 +274,7 @@ "\n", "We observe that how to reinitialize the inner-loop parameter in a new bi-level process vary in different meta-learning algorithms. For instance, in algorithm like Model-Agnostic Meta-Learning (MAML) ([arXiv:1703.03400](https://arxiv.org/abs/1703.03400)), every time a new task comes, we need to reset the parameters to the initial ones. In other cases such as Meta-Gradient Reinforcement Learning (MGRL) ([arXiv:1805.09801](https://arxiv.org/abs/1805.09801)), the inner-loop network parameter just inherit previous updated parameter to continue the new bi-level process.\n", "\n", - "We provide the `torchopt.extract_state_dict` and `torchopt.recover_state_dict` functions to extract and restore the state of network and optimizer. By default, the extracted state dictionary is a reference (this design is for accumulating gradient of multi-task batch training, MAML for example). You can also set `by='copy'` to extract the copy of state dictionary or set `by='deepcopy'` to have a detached copy." + "We provide the `torchopt.extract_state_dict` and `torchopt.recover_state_dict` functions to extract and restore the state of network and optimizer. By default, the extracted state dictionary is a reference (this design is for accumulating gradient of multi-task batch training, MAML for example). You can also set `by='copy'` to extract the copy of the state dictionary or set `by='deepcopy'` to have a detached copy." ] }, { @@ -303,7 +303,7 @@ "# If set `detach_buffers=True`, the parameters are referenced as references while buffers are detached copies\n", "init_net_state = torchopt.extract_state_dict(net, by='reference', detach_buffers=True)\n", "\n", - "# Set `copy` to get the copy of state dictionary\n", + "# Set `copy` to get the copy of the state dictionary\n", "init_net_state_copy = torchopt.extract_state_dict(net, by='copy')\n", "init_optim_state_copy = torchopt.extract_state_dict(optim, by='copy')\n", "\n", @@ -680,7 +680,7 @@ "source": [ "**2. Get `Trying to backward through the graph a second time` error when conducting multiple meta-optimization.**\n", "\n", - "Please refer to the tutorial notebook [Stop Gradient](4_Stop_Gradient.ipynb) for more guidances." + "Please refer to the tutorial notebook [Stop Gradient](4_Stop_Gradient.ipynb) for more guidance." ] } ], diff --git a/tutorials/5_Implicit_Differentiation.ipynb b/tutorials/5_Implicit_Differentiation.ipynb index c2913101..f8258fcc 100644 --- a/tutorials/5_Implicit_Differentiation.ipynb +++ b/tutorials/5_Implicit_Differentiation.ipynb @@ -48,20 +48,15 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "0cdaac49-4b94-4900-9bb5-a39057ac8b21", "metadata": {}, "source": [ "## 1. Functional API\n", "\n", - "The basic functional API is `torchopt.diff.implicit.custom_root`, which is used as the decorator for the forward process implicit gradient procedures. Users are required to implement the stationary conditions for the inner-loop process, which will be used as the input of custom_root decorator. We show the pseudo code in the following part." - ] - }, - { - "cell_type": "markdown", - "id": "c0b4400b-a491-4f07-926c-c421ac5a2069", - "metadata": {}, - "source": [ + "The basic functional API is `torchopt.diff.implicit.custom_root`, which is used as the decorator for the forward process implicit gradient procedures. Users are required to implement the stationary conditions for the inner-loop process, which will be used as the input of custom_root decorator. We show the pseudo code in the following part.\n", + "\n", "```python\n", "# Functional API for implicit gradient\n", "def stationary(params, meta_params, data):\n", @@ -128,6 +123,7 @@ "# backpropogate, in this case we want to backpropogate to the initial parameters so we set it as 1.\n", "# You can also set argnums as (1, 2) if you want to backpropogate through multiple meta-parameters\n", "\n", + "\n", "# Here we pass argnums=1 to the custom_root. That means we want to compute the gradient of\n", "# optimal_params w.r.t. the 1-indexed argument in inner_solver, i.e., params.\n", "# torchopt.linear_solve.solve_normal_cg specify that we use the conjugate gradient based linear solver\n", @@ -246,6 +242,7 @@ "fmodel, meta_params = functorch.make_functional(model)\n", "data = (x, y, fmodel)\n", "\n", + "\n", "# Clone function for parameters\n", "def clone(params):\n", " cloned = []\n", @@ -334,6 +331,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "c92e67ea-b220-4a14-a1ea-4eb3c5f52b6b", "metadata": {}, @@ -375,7 +373,7 @@ "meta_params, data = ..., ...\n", "inner_net = InnerNet()\n", "\n", - "# Solve for inner-loop process related with the meta-parameters\n", + "# Solve for inner-loop process related to the meta-parameters\n", "optimal_inner_net = inner_net.solve(meta_params, *data)\n", "\n", "# Get outer-loss and solve for meta-gradient\n", diff --git a/tutorials/6_Zero_Order_Differentiation.ipynb b/tutorials/6_Zero_Order_Differentiation.ipynb index c8d1e551..968f6b6c 100644 --- a/tutorials/6_Zero_Order_Differentiation.ipynb +++ b/tutorials/6_Zero_Order_Differentiation.ipynb @@ -23,7 +23,7 @@ "source": [ "When the inner-loop process is non-differentiable or one wants to eliminate the heavy computation burdens in the previous two modes (brought by Hessian), one can choose ZD. ZD typically gets gradients based on zero-order estimation, such as finite-difference, or Evolutionary Strategy.\n", "\n", - "TorchOpt offers API for ES-based differentiation. Instead of optimizing the objective $F$, ES optimizes a Gaussion smoothing objective defined as $\\tilde{f}_{\\sigma} (\\theta) = \\mathbb{E}_{{z} \\sim \\mathcal{N}( {0}, {I}_d )} [ f ({\\theta} + \\sigma \\, z) ]$, where $\\sigma$ denotes precision. The gradient of such objective is $\\nabla_\\theta \\tilde{f}_{\\sigma} (\\theta) = \\frac{1}{\\sigma} \\mathbb{E}_{{z} \\sim \\mathcal{N}( {0}, {I}_d )} [ f({\\theta} + \\sigma \\, z) \\cdot z ]$. Refer to [ES-MAML](https://arxiv.org/pdf/1910.01215.pdf) for more details." + "TorchOpt offers API for ES-based differentiation. Instead of optimizing the objective $f (\\boldsymbol{\\theta}): \\mathbb{R}^n \\to \\mathbb{R}$, ES optimizes a Gaussion smoothing objective defined as $\\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\mathbb{E}_{\\boldsymbol{z} \\sim \\mathcal{N}( 0, {I}_d )} [ f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) ]$, where $\\sigma$ denotes precision. The gradient of such objective is $\\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\frac{1}{\\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\mathcal{N}( 0, {I}_d )} [ f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) \\cdot \\boldsymbol{z} ]$. Refer to [ES-MAML](https://arxiv.org/pdf/1910.01215.pdf) for more details." ] }, { @@ -58,49 +58,64 @@ "\n", "The basic functional API is `torchopt.diff.zero_order.zero_order`, which is used as the decorator for the forward process zero-order gradient procedures. Users are required to implement the noise sampling function, which will be used as the input of zero_order decorator. Here we show the specific meaning for each parameter used in the decorator.\n", "\n", - "- `distribution` for noise sampling distribution\n", - "- `method` for different kind of algorithms, we support `'naive'` ([ES-RL](https://arxiv.org/abs/1703.03864)), `'forward'` ([Forward-FD](http://proceedings.mlr.press/v80/choromanski18a/choromanski18a.pdf)), and `'antithetic'` ([antithetic](https://d1wqtxts1xzle7.cloudfront.net/75609515/coredp2011_1web-with-cover-page-v2.pdf?Expires=1670215467&Signature=RfP~mQhhhI7aGknwXbRBgSggFrKuNTPYdyUSdMmfTxOa62QoOJAm-Xhr3F1PLyjUQc2JVxmKIKGGuyYvyfCTpB31dfmMtuVQxZMWVF-SfErTN05SliC93yjA1x1g2kjhn8bkBFdQqGl~1RQSKnhj88BakgSeDNzyCxwbD5VgR89BXRs4YIK5RBIKYtgLhoyz5jar7wHS3TJhRzs3WNeTIAjAmLqJ068oGFZ0Jr7maGquTe3w~8LEEIprJ6cyCMc6b1UUJkmwjNq0RLTVbxgFjfi4Z9kyxyJB9IOS1J25OOON4jfwh5JlXS7MVskuONUyHJim1TQ8OwCraKlBsQLPQw__&Key-Pair-Id=APKAJLOHF5GGSLRBV4ZA)).\n", + "- `distribution` for noise sampling distribution. The distribution $\\lambda$ should be spherical symmetric and with a constant variance of $1$ for each element. I.e.:\n", + "\n", + " - Spherical symmetric: $\\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ \\boldsymbol{z} ] = \\boldsymbol{0}$.\n", + " - Constant variance of $1$ for each element: $\\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ {\\lvert z_i \\rvert}^2 ] = 1$.\n", + " - For example, the standard multi-dimensional normal distribution $\\mathcal{N} (\\boldsymbol{0}, \\boldsymbol{1})$.\n", + "\n", + "- `method` for different kind of algorithms, we support `'naive'` ([ES-RL](https://arxiv.org/abs/1703.03864)), `'forward'` ([Forward-FD](http://proceedings.mlr.press/v80/choromanski18a/choromanski18a.pdf)), and `'antithetic'` ([antithetic](https://arxiv.org/abs/1803.07055)).\n", + "\n", + " $$\n", + " \\begin{align*}\n", + " \\text{naive} \\qquad & \\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\frac{1}{\\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) \\cdot \\boldsymbol{z} ] \\\\\n", + " \\text{forward} \\qquad & \\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\frac{1}{\\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ ( f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) - f (\\boldsymbol{\\theta}) ) \\cdot \\boldsymbol{z} ] \\\\\n", + " \\text{antithetic} \\qquad & \\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\frac{1}{2 \\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ (f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) - f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) ) \\cdot \\boldsymbol{z} ]\n", + " \\end{align*}\n", + " $$\n", + "\n", "- `argnums` specifies which parameter we want to trace the meta-gradient.\n", - "- `sigma` is for precision.\n", "- `num_samples` specifies how many times we want to conduct the sampling.\n", + "- `sigma` is for precision. This is the scaling factor for the sampling distribution.\n", + "\n", + "We show the pseudo code in the following part.\n", "\n", - "We show the pseudo code in the following part." - ] - }, - { - "cell_type": "markdown", - "id": "c0b4400b-a491-4f07-926c-c421ac5a2069", - "metadata": {}, - "source": [ "```python\n", "# Functional API for zero-order differentiation\n", "# 1. Customize the noise distribution via a distribution class\n", "class Distribution:\n", - " def sample(self, sample_shape = torch.Size()):\n", - " # sampling function for noise\n", + " def sample(self, sample_shape=torch.Size()):\n", + " # Sampling function for noise\n", + " # NOTE: The distribution should be spherical symmetric and with a constant variance of 1.\n", + " ...\n", " return noise_batch\n", "\n", "distribution = Distribution()\n", "\n", "# 2. Customize the noise distribution via a sampling function\n", - "def distribution(sample_shape = torch.Size()):\n", - " # sampling function for noise\n", + "def distribution(sample_shape=torch.Size()):\n", + " # Sampling function for noise\n", + " # NOTE: The distribution should be spherical symmetric and with a constant variance of 1.\n", + " ...\n", " return noise_batch\n", "\n", "# 3. Distribution can also be an instance of `torch.distributions.Distribution`, e.g., `torch.distributions.Normal(...)`\n", "distribution = torch.distributions.Normal(loc=0, scale=1)\n", "\n", "# Decorator that wraps the function\n", - "@torchopt.diff.zero_order(distribution=distribution, method='naive', argnums=0, sigma=0.01, num_samples=100)\n", + "@torchopt.diff.zero_order(distribution=distribution, method='naive', argnums=0, num_samples=100, sigma=0.01)\n", "def forward(params, data):\n", " # Forward optimization process for params\n", - " return output\n", + " ...\n", + " return objective # the returned tensor should be a scalar tensor\n", "\n", "# Define params and get data\n", "params, data = ..., ...\n", - "loss = forward(params, data)\n", "\n", - "meta_grads = torch.autograd.grad(loss, params)\n", + "# Forward pass\n", + "loss = forward(params, data)\n", + "# Backward pass using zero-order differentiation\n", + "grads = torch.autograd.grad(loss, params)\n", "```" ] }, @@ -122,57 +137,56 @@ "name": "stdout", "output_type": "stream", "text": [ - "001: tensor(0.0269, grad_fn=)\n", - "002: tensor(0.0246, grad_fn=)\n", - "003: tensor(0.0225, grad_fn=)\n", - "004: tensor(0.0205, grad_fn=)\n", - "005: tensor(0.0187, grad_fn=)\n", - "006: tensor(0.0171, grad_fn=)\n", - "007: tensor(0.0156, grad_fn=)\n", - "008: tensor(0.0144, grad_fn=)\n", - "009: tensor(0.0134, grad_fn=)\n", - "010: tensor(0.0128, grad_fn=)\n", - "011: tensor(0.0122, grad_fn=)\n", + "001: tensor(0.0265, grad_fn=)\n", + "002: tensor(0.0243, grad_fn=)\n", + "003: tensor(0.0222, grad_fn=)\n", + "004: tensor(0.0202, grad_fn=)\n", + "005: tensor(0.0184, grad_fn=)\n", + "006: tensor(0.0170, grad_fn=)\n", + "007: tensor(0.0157, grad_fn=)\n", + "008: tensor(0.0146, grad_fn=)\n", + "009: tensor(0.0137, grad_fn=)\n", + "010: tensor(0.0130, grad_fn=)\n", + "011: tensor(0.0123, grad_fn=)\n", "012: tensor(0.0118, grad_fn=)\n", - "013: tensor(0.0120, grad_fn=)\n", - "014: tensor(0.0117, grad_fn=)\n", - "015: tensor(0.0117, grad_fn=)\n", - "016: tensor(0.0118, grad_fn=)\n", - "017: tensor(0.0121, grad_fn=)\n", - "018: tensor(0.0117, grad_fn=)\n", + "013: tensor(0.0114, grad_fn=)\n", + "014: tensor(0.0111, grad_fn=)\n", + "015: tensor(0.0111, grad_fn=)\n", + "016: tensor(0.0111, grad_fn=)\n", + "017: tensor(0.0113, grad_fn=)\n", + "018: tensor(0.0115, grad_fn=)\n", "019: tensor(0.0118, grad_fn=)\n", - "020: tensor(0.0118, grad_fn=)\n", - "021: tensor(0.0115, grad_fn=)\n", - "022: tensor(0.0117, grad_fn=)\n", - "023: tensor(0.0117, grad_fn=)\n", - "024: tensor(0.0116, grad_fn=)\n", - "025: tensor(0.0113, grad_fn=)\n" + "020: tensor(0.0120, grad_fn=)\n", + "021: tensor(0.0121, grad_fn=)\n", + "022: tensor(0.0121, grad_fn=)\n", + "023: tensor(0.0122, grad_fn=)\n", + "024: tensor(0.0122, grad_fn=)\n", + "025: tensor(0.0122, grad_fn=)\n" ] } ], "source": [ "torch.random.manual_seed(0)\n", "\n", - "fmodel, params = functorch.make_functional(torch.nn.Linear(32, 1))\n", + "fmodel, params = functorch.make_functional(nn.Linear(32, 1))\n", "x = torch.randn(64, 32) * 0.1\n", - "y = torch.randn(64) * 0.1\n", + "y = torch.randn(64, 1) * 0.1\n", "distribution = torch.distributions.Normal(loc=0, scale=1)\n", "\n", "\n", - "@torchopt.diff.zero_order.zero_order(\n", - " distribution=distribution, method='forward', argnums=0, sigma=0.01, num_samples=1000\n", + "@torchopt.diff.zero_order(\n", + " distribution=distribution, method='forward', argnums=0, num_samples=100, sigma=0.01\n", ")\n", "def forward_process(params, fn, x, y):\n", " y_pred = fn(params, x)\n", - " loss = torch.mean((y - y_pred) ** 2)\n", + " loss = F.mse_loss(y_pred, y)\n", " return loss\n", "\n", "\n", "optimizer = torchopt.adam(lr=0.01)\n", - "opt_state = optimizer.init(params)\n", + "opt_state = optimizer.init(params) # init optimizer\n", "\n", "for i in range(25):\n", - " opt_state = optimizer.init(params) # init optimizer\n", " loss = forward_process(params, fmodel, x, y) # compute loss\n", "\n", " grads = torch.autograd.grad(loss, params) # compute gradients\n", @@ -181,6 +195,136 @@ "\n", " print(f'{i + 1:03d}: {loss!r}')" ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "db723f6b", + "metadata": {}, + "source": [ + "## 2. OOP API\n", + "\n", + "The basic OOP API is the class `ZeroOrderGradientModule`. We make the network as an `nn.Module` following a classical PyTorch style. Users need to define the forward process zero-order gradient procedures `forward()` and a noise sampling function `sample()`. Here we show the specific meaning for each parameter used in the class.\n", + "\n", + "- `method` for different kind of algorithms, we support `'naive'` ([ES-RL](https://arxiv.org/abs/1703.03864)), `'forward'` ([Forward-FD](http://proceedings.mlr.press/v80/choromanski18a/choromanski18a.pdf)), and `'antithetic'` ([antithetic](https://d1wqtxts1xzle7.cloudfront.net/75609515/coredp2011_1web-with-cover-page-v2.pdf?Expires=1670215467&Signature=RfP~mQhhhI7aGknwXbRBgSggFrKuNTPYdyUSdMmfTxOa62QoOJAm-Xhr3F1PLyjUQc2JVxmKIKGGuyYvyfCTpB31dfmMtuVQxZMWVF-SfErTN05SliC93yjA1x1g2kjhn8bkBFdQqGl~1RQSKnhj88BakgSeDNzyCxwbD5VgR89BXRs4YIK5RBIKYtgLhoyz5jar7wHS3TJhRzs3WNeTIAjAmLqJ068oGFZ0Jr7maGquTe3w~8LEEIprJ6cyCMc6b1UUJkmwjNq0RLTVbxgFjfi4Z9kyxyJB9IOS1J25OOON4jfwh5JlXS7MVskuONUyHJim1TQ8OwCraKlBsQLPQw__&Key-Pair-Id=APKAJLOHF5GGSLRBV4ZA)).\n", + "- `num_samples` specifies how many times we want to conduct the sampling.\n", + "- `sigma` is for precision. This is the scaling factor for the sampling distribution.\n", + "\n", + "We show the pseudo code in the following part.\n", + "\n", + "```python\n", + "from torchopt.nn import ZeroOrderGradientModule\n", + "\n", + "# Inherited from the class ZeroOrderGradientModule\n", + "# Optionally specify the `method` and/or `num_samples` and/or `sigma` used for sampling\n", + "class Net(ZeroOrderGradientModule, method='naive', num_samples=100, sigma=0.01):\n", + " def __init__(self, ...):\n", + " ...\n", + "\n", + " def forward(self, batch):\n", + " # Forward process\n", + " ...\n", + " return objective # the returned tensor should be a scalar tensor\n", + "\n", + " def sample(self, sample_shape=torch.Size()):\n", + " # Generate a batch of noise samples\n", + " # NOTE: The distribution should be spherical symmetric and with a constant variance of 1.\n", + " ...\n", + " return noise_batch\n", + "\n", + "# Get model and data\n", + "net = Net(...)\n", + "data = ...\n", + "\n", + "# Forward pass\n", + "loss = Net(data)\n", + "# Backward pass using zero-order differentiation\n", + "grads = torch.autograd.grad(loss, net.parameters())\n", + "```" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "b53524f5", + "metadata": {}, + "source": [ + "Here we reimplement the functional API example above with the OOP API." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ecc5730c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "001: tensor(0.0201, grad_fn=)\n", + "002: tensor(0.0181, grad_fn=)\n", + "003: tensor(0.0167, grad_fn=)\n", + "004: tensor(0.0153, grad_fn=)\n", + "005: tensor(0.0142, grad_fn=)\n", + "006: tensor(0.0133, grad_fn=)\n", + "007: tensor(0.0125, grad_fn=)\n", + "008: tensor(0.0119, grad_fn=)\n", + "009: tensor(0.0116, grad_fn=)\n", + "010: tensor(0.0114, grad_fn=)\n", + "011: tensor(0.0112, grad_fn=)\n", + "012: tensor(0.0112, grad_fn=)\n", + "013: tensor(0.0113, grad_fn=)\n", + "014: tensor(0.0116, grad_fn=)\n", + "015: tensor(0.0118, grad_fn=)\n", + "016: tensor(0.0121, grad_fn=)\n", + "017: tensor(0.0123, grad_fn=)\n", + "018: tensor(0.0125, grad_fn=)\n", + "019: tensor(0.0127, grad_fn=)\n", + "020: tensor(0.0127, grad_fn=)\n", + "021: tensor(0.0125, grad_fn=)\n", + "022: tensor(0.0123, grad_fn=)\n", + "023: tensor(0.0120, grad_fn=)\n", + "024: tensor(0.0118, grad_fn=)\n", + "025: tensor(0.0117, grad_fn=)\n" + ] + } + ], + "source": [ + "torch.random.manual_seed(0)\n", + "\n", + "\n", + "class Net(torchopt.nn.ZeroOrderGradientModule, method='forward', num_samples=100, sigma=0.01):\n", + " def __init__(self, dim):\n", + " super().__init__()\n", + " self.fc = nn.Linear(dim, 1)\n", + " self.distribution = torch.distributions.Normal(loc=0, scale=1)\n", + "\n", + " def forward(self, x, y):\n", + " y_pred = self.fc(x)\n", + " loss = F.mse_loss(y_pred, y)\n", + " return loss\n", + "\n", + " def sample(self, sample_shape=torch.Size()):\n", + " return self.distribution.sample(sample_shape)\n", + "\n", + "\n", + "x = torch.randn(64, 32) * 0.1\n", + "y = torch.randn(64, 1) * 0.1\n", + "net = Net(dim=32)\n", + "\n", + "\n", + "optimizer = torchopt.Adam(net.parameters(), lr=0.01)\n", + "\n", + "for i in range(25):\n", + " loss = net(x, y) # compute loss\n", + "\n", + " optimizer.zero_grad()\n", + " loss.backward() # backward pass\n", + " optimizer.step() # update network parameters\n", + "\n", + " print(f'{i + 1:03d}: {loss!r}')" + ] } ], "metadata": { pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy