docs(implicit_diff): implicit differentiation integration #73

Merged
merged 25 commits into main from docs/implicit_gradient on Sep 22, 2022
Changes from 18 commits

Commits (25)
14dbb31
docs: init implicit differentiation integration
Benjamin-eecs Sep 9, 2022
0d76585
fix: linear solve docs error remains
Benjamin-eecs Sep 9, 2022
a9cc777
feat(tutorials): add implicit differentiation
Benjamin-eecs Sep 10, 2022
672005a
fix(tutorials): update torchopt import
Benjamin-eecs Sep 11, 2022
6cdd1f0
docs: pass api docs
Benjamin-eecs Sep 11, 2022
9a261f9
docs: pass api docs
Benjamin-eecs Sep 11, 2022
e915753
docs: pass api docs
Benjamin-eecs Sep 11, 2022
d5564b7
fix(implicit): remove argument
JieRen98 Sep 11, 2022
a47beb0
docs: update `custom_root` docstring
XuehaiPan Sep 13, 2022
e4f512f
Merge branch 'main' into docs/implicit_gradient
XuehaiPan Sep 13, 2022
5cf9018
docs: update colab links
Benjamin-eecs Sep 15, 2022
37298d5
Merge branch 'main' into docs/implicit_gradient
Benjamin-eecs Sep 22, 2022
4c0b69b
Merge branch 'main' into docs/implicit_gradient
XuehaiPan Sep 22, 2022
ae89467
docs(implicit): update docstrings for `custom_root`
XuehaiPan Sep 22, 2022
4a36212
docs(CHANGELOG): update CHANGELOG.md
XuehaiPan Sep 22, 2022
623324b
docs(CHANGELOG): update CHANGELOG.md
XuehaiPan Sep 22, 2022
059fc79
docs(implicit): update tutorial
XuehaiPan Sep 22, 2022
84e06b2
docs(implicit): update docstrings
XuehaiPan Sep 22, 2022
df764cf
docs(README): update future plan
XuehaiPan Sep 22, 2022
8b6a945
chore: update gitignore
Benjamin-eecs Sep 22, 2022
a043f7b
chore: update makefile
Benjamin-eecs Sep 22, 2022
956a780
docs: update dictionary
XuehaiPan Sep 22, 2022
504d699
Merge branch 'main' into docs/implicit_gradient
XuehaiPan Sep 22, 2022
d881334
docs(implicit): update docstrings
XuehaiPan Sep 22, 2022
37ea557
fix(implicit): fix has_aux when result is single tensor
XuehaiPan Sep 22, 2022
5 changes: 3 additions & 2 deletions CHANGELOG.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

+- Add API documentation and tutorial for implicit gradients by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@JieRen98](https://github.com/JieRen98) and [@XuehaiPan](https://github.com/XuehaiPan) in [#73](https://github.com/metaopt/torchopt/pull/73).
- Add wrapper class for functional optimizers and examples of `functorch` integration by [@vmoens](https://github.com/vmoens) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#6](https://github.com/metaopt/torchopt/pull/6).
- Implicit differentiation support by [@JieRen98](https://github.com/JieRen98) and [@waterhorse1](https://github.com/waterhorse1) and [@XuehaiPan](https://github.com/XuehaiPan) in [#41](https://github.com/metaopt/torchopt/pull/41).

@@ -22,8 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-- Fix `None` in module containers by [@XuehaiPan](https://github.com/XuehaiPan)
-- Fix backward errors when using inplace `sqrt_` and `add_` by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@JieRen98](https://github.com/JieRen98) and [@XuehaiPan](https://github.com/XuehaiPan)
+- Fix `None` in module containers by [@XuehaiPan](https://github.com/XuehaiPan).
+- Fix backward errors when using inplace `sqrt_` and `add_` by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@JieRen98](https://github.com/JieRen98) and [@XuehaiPan](https://github.com/XuehaiPan).
- Fix LR scheduling by [@XuehaiPan](https://github.com/XuehaiPan) in [#76](https://github.com/metaopt/torchopt/pull/76).
- Fix the step count tensor (`shape=(1,)`) can change the shape of the scalar updates (`shape=()`) by [@XuehaiPan](https://github.com/XuehaiPan) in [#71](https://github.com/metaopt/torchopt/pull/71).
12 changes: 6 additions & 6 deletions README.md
@@ -296,11 +296,11 @@ If you find TorchOpt useful, please cite it in your publications.

```bibtex
@software{TorchOpt,
  author = {Jie Ren and Xidong Feng and Bo Liu and Xuehai Pan and Luo Mai and Yaodong Yang},
  title = {TorchOpt},
  year = {2022},
  publisher = {GitHub},
  journal = {GitHub repository},
-  howpublished = {\url{https://github.com/metaopt/torchopt}},
+  howpublished = {\url{https://github.com/metaopt/torchopt}}
}
```
2 changes: 1 addition & 1 deletion docs/requirements.txt
@@ -11,7 +11,7 @@ sphinx-copybutton
sphinx-rtd-theme
sphinxcontrib-katex
sphinxcontrib-bibtex
-sphinx-autodoc-typehints
+sphinx-autodoc-typehints >= 1.19.2
IPython
ipykernel
pandoc
36 changes: 36 additions & 0 deletions docs/source/api/api.rst
@@ -131,6 +131,40 @@ Differentiable Meta-RMSProp Optimizer

------

+Implicit differentiation
+========================
+
+.. currentmodule:: torchopt._src.implicit_diff
+
+.. autosummary::
+
+    custom_root
+
+Custom solvers
+~~~~~~~~~~~~~~
+
+.. autofunction:: custom_root
+
+------
+
+Linear system solving
+=====================
+
+.. currentmodule:: torchopt._src.linear_solve
+
+.. autosummary::
+
+    solve_cg
+    solve_normal_cg
+
+Indirect solvers
+~~~~~~~~~~~~~~~~
+
+.. autofunction:: solve_cg
+.. autofunction:: solve_normal_cg
+
+------
+
Optimizer Hooks
===============

@@ -147,6 +181,8 @@ Hook
.. autofunction:: register_hook
.. autofunction:: zero_nan_hook

+------
+
Gradient Transformation
=======================

48 changes: 35 additions & 13 deletions docs/source/index.rst
@@ -44,31 +44,31 @@ We provide a `conda <https://github.com/conda/conda>`_ environment recipe to ins


.. toctree::
    :caption: Getting Started
    :maxdepth: 1

    torchopt101/torchopt-101.rst


.. toctree::
    :caption: Examples
    :maxdepth: 1

    examples/MAML.rst


.. toctree::
    :caption: Developer Documentation
    :maxdepth: 1

    developer/contributing.rst
    developer/contributor.rst

.. toctree::
    :caption: API Documentation
    :maxdepth: 2

    api/api.rst

The Team
--------

@@ -97,3 +97,25 @@ License
-------

TorchOpt is licensed under the Apache 2.0 License.
+
+Citing
+------
+
+If you find TorchOpt useful, please cite it in your publications.
+
+.. code-block:: bibtex
+
+    @software{TorchOpt,
+      author = {Jie Ren and Xidong Feng and Bo Liu and Xuehai Pan and Luo Mai and Yaodong Yang},
+      title = {TorchOpt},
+      year = {2022},
+      publisher = {GitHub},
+      journal = {GitHub repository},
+      howpublished = {\url{https://github.com/metaopt/torchopt}}
+    }
+
+
+Indices and tables
+==================
+
+* :ref:`genindex`
4 changes: 4 additions & 0 deletions docs/source/spelling_wordlist.txt
@@ -78,3 +78,7 @@ Loshchilov
pytree
booleans
subtrees
+optimality
+argnums
+matvec
+Hermitian
9 changes: 5 additions & 4 deletions docs/source/torchopt101/torchopt-101.rst
@@ -3,7 +3,8 @@ Get Started with Jupyter Notebook

In this tutorial, we will use Google Colaboratory to show you the most basic usages of TorchOpt.

-- 1: `Functional Optimizer <https://colab.research.google.com/drive/1yfi-ETyIptlIM7WFYWF_IFhX4WF3LldP?usp=sharing>`_
-- 2: `Visualization <https://colab.research.google.com/drive/1Uoo2epqZKmJNQOiO0EU8DGd33AVKBlAq?usp=sharing>`_
-- 3: `Meta Optimizer <https://colab.research.google.com/drive/1lo9q2gQz073urYln-4Yub5s8APUoHvQJ?usp=sharing>`_
-- 4: `Stop Gradient <https://colab.research.google.com/drive/1jp_oPHIG6aaQMYGNxG72FSuWjABk1DHo?usp=sharing>`_
+- 1: `Functional Optimizer <https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/1_Functional_Optimizer.ipynb>`_
+- 2: `Visualization <https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/2_Visualization.ipynb>`_
+- 3: `Meta Optimizer <https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/3_Meta_Optimizer.ipynb>`_
+- 4: `Stop Gradient <https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/4_Stop_Gradient.ipynb>`_
+- 5: `Implicit Differentiation <https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/5_Implicit_Differentiation.ipynb>`_
31 changes: 21 additions & 10 deletions torchopt/_src/implicit_diff.py
@@ -346,27 +346,39 @@ def custom_root(
    argnums: Union[int, Tuple[int, ...]] = 0,
    has_aux: bool = False,
    solve: Callable = linear_solve.solve_normal_cg(),
-    reference_signature: Optional[Union[inspect.Signature, Callable]] = None,
) -> Callable[[Callable], Callable]:
    """Decorator for adding implicit differentiation to a root solver.

+    This wrapper should be used as a decorator:
+
+    .. code-block:: python
+
+        def optimality_fun(params, ...):
+            ...
+
+        @custom_root(optimality_fun, argnums=argnums)
+        def solver_fun(params, arg1, arg2, ...):
+            ...
+            return optimal_params
+
+    The first argument to ``optimality_fun`` and ``solver_fun`` is preserved as ``params``.
+    The ``argnums`` argument refers to the indices of the variables in ``solver_fun``'s signature.
+    For example, setting ``argnums=(1, 2)`` will compute the gradient of ``optimal_params`` with
+    respect to ``arg1`` and ``arg2`` in the above example.
+
    Args:
        optimality_fun: (callable)
            An equation function, ``optimality_fun(params, *args)``. The invariant is
            ``optimality_fun(sol, *args) == 0`` at the solution / root of ``sol``.
        argnums: (int or tuple of int, default: :const:`0`)
-            Specifies arguments to compute gradients with respect to.
+            Specifies arguments to compute gradients with respect to. The ``argnums`` can be an
+            integer or a tuple of integers, which respect to the zero-based indices of the arguments
+            of the ``solver_fun(params, *args)`` function. The argument ``params`` is included
+            for the counting, while it is indexed as ``argnums=0``.
        has_aux: (default: :data:`False`)
            Whether the decorated solver function returns auxiliary data.
        solve: (callable, optional, default: :func:`linear_solve.solve_normal_cg`)
            a linear solver of the form ``solve(matvec, b)``.
-        reference_signature: (function signature, optional)
-            Function signature (i.e. arguments and keyword arguments), with which the solver and
-            optimality functions are expected to agree. Defaults to ``optimality_fun``. It is
-            required that solver and optimality functions share the same input signature, but both
-            might be defined in such a way that the signature correspondence is ambiguous (e.g. if
-            both accept catch-all ``**kwargs``). To satisfy ``custom_root``'s requirement, any
-            function with an unambiguous signature can be provided here.

    Returns:
        A solver function decorator, i.e., ``custom_root(optimality_fun)(solver_fun)``.
@@ -383,5 +395,4 @@ def custom_root(
        solve=solve,
        argnums=argnums,
        has_aux=has_aux,
-        reference_signature=reference_signature,
    )
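The decorated-solver pattern documented above can be sketched end to end. The snippet below is a minimal, hypothetical example and not part of this PR: the import path `torchopt._src.implicit_diff` is taken from the diff, while the ridge-regression setup and the names `optimality_fun`, `ridge_solver`, `w_init`, and `lam` are illustrative assumptions.

```python
import torch

# Module path as shown in this diff; treat it as an assumption about the layout.
from torchopt._src.implicit_diff import custom_root


def optimality_fun(w, X, y, lam):
    # Stationarity condition of ridge regression:
    # grad_w [0.5 * ||X @ w - y||^2 + 0.5 * lam * ||w||^2] == 0 at the optimum.
    return X.t() @ (X @ w - y) + lam * w


# argnums=(3,) selects `lam`, following the zero-based counting described in
# the updated docstring: `w_init` is index 0, so `lam` is index 3.
@custom_root(optimality_fun, argnums=(3,))
def ridge_solver(w_init, X, y, lam):
    # Inner solve in closed form; `w_init` only keeps the params-first signature.
    n_features = X.shape[1]
    return torch.linalg.solve(X.t() @ X + lam * torch.eye(n_features), X.t() @ y)


X = torch.randn(20, 5)
y = torch.randn(20)
lam = torch.tensor(1.0, requires_grad=True)

w_star = ridge_solver(torch.zeros(5), X, y, lam)
w_star.norm().backward()  # implicit gradient of the solution w.r.t. `lam`
print(lam.grad)
```

Because the backward pass goes through the optimality condition rather than through the solver's internals, the closed-form `torch.linalg.solve` could be swapped for any iterative routine without changing the computed gradient.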
11 changes: 5 additions & 6 deletions torchopt/_src/linalg.py
@@ -183,13 +183,12 @@ def cg(

    Args:
        A: (tensor or tree of tensors or function)
-            2D array or function that calculates the linear map (matrix-vector
-            product) ``Ax`` when called like ``A(x)``. ``A`` must represent a
-            hermitian, positive definite matrix, and must return array(s) with the
-            same structure and shape as its argument.
+            2D array or function that calculates the linear map (matrix-vector product) ``Ax`` when
+            called like ``A(x)``. ``A`` must represent a Hermitian, positive definite matrix, and
+            must return array(s) with the same structure and shape as its argument.
        b: (tensor or tree of tensors)
-            Right hand side of the linear system representing a single vector. Can be
-            stored as an array or Python container of array(s) with any shape.
+            Right hand side of the linear system representing a single vector. Can be stored as an
+            array or Python container of array(s) with any shape.
        x0: (tensor or tree of tensors, optional)
            Starting guess for the solution. Must have the same structure as ``b``.
        rtol: (float, optional, default: :const:`1e-5`)
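As a usage sketch of the `cg` signature documented above — assuming the module path `torchopt._src.linalg` from this diff and that `cg` returns the solution directly (both assumptions, not confirmed by the PR):

```python
import torch

from torchopt._src.linalg import cg  # path per this diff; assumed importable

# Build a real symmetric (hence Hermitian) positive-definite matrix.
M = torch.randn(6, 6)
A_mat = M @ M.t() + 6.0 * torch.eye(6)
b = torch.randn(6)

# `A` may be passed as a matrix-vector product closure, per the docstring.
x = cg(lambda v: A_mat @ v, b, rtol=1e-6)

print(torch.allclose(A_mat @ x, b, atol=1e-4))  # residual check
```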
6 changes: 3 additions & 3 deletions torchopt/_src/linear_solve.py
@@ -104,8 +104,8 @@ def _solve_normal_cg(
) -> TensorTree:
    """Solves the normal equation ``A^T A x = A^T b`` using conjugate gradient.

-    This can be used to solve Ax=b using conjugate gradient when A is not
-    hermitian, positive definite.
+    This can be used to solve ``A x = b`` using conjugate gradient when ``A`` is not Hermitian,
+    positive definite.

    Args:
        matvec: product between ``A`` and a vector.
@@ -136,5 +136,5 @@ def normal_matvec(x):


def solve_normal_cg(**kwargs):
-    """Wrapper for `solve_normal_cg`."""
+    """Wrapper for :func:`solve_normal_cg`."""
    return functools.partial(_solve_normal_cg, **kwargs)
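A short sketch of the factory above — per the wrapper, `solve_normal_cg(**kwargs)` returns a `functools.partial` with the `solve(matvec, b)` shape that `custom_root`'s `solve` argument expects; the system below is an illustrative assumption:

```python
import torch

from torchopt._src.linear_solve import solve_normal_cg  # path per this diff

# A well-conditioned but non-symmetric matrix, so plain CG does not apply.
A_mat = torch.randn(8, 8) + 8.0 * torch.eye(8)
b = torch.randn(8)

solve = solve_normal_cg()          # -> solve(matvec, b)
x = solve(lambda v: A_mat @ v, b)  # solves A^T A x = A^T b via CG

print((A_mat @ x - b).norm())  # residual should be small
```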