docs: add more details on implicit differentiation #143

Merged
merged 13 commits into from
Mar 9, 2023
docs: update implicit gradient docs
XuehaiPan committed Mar 6, 2023
commit edcede2fe1cf77363bbd250c5a9e5d6d5be59136
44 changes: 32 additions & 12 deletions docs/source/implicit_diff/implicit_diff.rst
@@ -10,15 +10,29 @@ Implicit Differentiation
:width: 80%
:align: center

Implicit differentiation is the task of differentiating through the solution of an optimization problem satisfying a mapping function :math:`Opt` capturing the optimality conditions of the problem. The simplest example is to differentiate through the solution of a minimization problem with respect to its inputs.

Implicit differentiation is the task of differentiating through the solution of an optimization problem satisfying a mapping function :math:`T` capturing the optimality conditions of the problem.
The simplest example is to differentiate through the solution of a minimization problem with respect to its inputs.
Namely, given

.. math::

\boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}) \triangleq \underset{\boldsymbol{\theta}}{\mathop{\operatorname{argmin}}} ~ \mathcal{L}^{\text{in}} (\boldsymbol{\phi},\boldsymbol{\theta}).

By treating the solution :math:`\boldsymbol{\theta}^{\prime}` as an implicit function of :math:`\boldsymbol{\phi}`, the idea of implicit differentiation is to directly get analytical best-response derivatives :math:`\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})` by the implicit function theorem. This is suitable for algorithms such as `IMAML <https://arxiv.org/abs/1909.04630>`_ and `DEQ <https://arxiv.org/abs/1909.01377>`_ when the inner-level optimality conditions :math:`Opt` is defined by :math:`\left. \frac{ \partial \mathcal{L}^{\text{in}} (\boldsymbol{\phi}, \boldsymbol{\theta})}{\partial \boldsymbol{\theta}} \right\rvert_{\boldsymbol{\theta} = \boldsymbol{\theta}^{\prime}} = 0` (e.g., the function :math:`F` in the figure means the solution is obtained by unrolled gradient fixed point update) or reaches some stationary point, optimality conditions :math:`Opt` can be sometimes defined by the root of :math:`T` such that :math:`T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}) = 0`.
By treating the solution :math:`\boldsymbol{\theta}^{\prime}` as an implicit function of :math:`\boldsymbol{\phi}`, the idea of implicit differentiation is to directly get analytical best-response derivatives :math:`\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})` by the implicit function theorem.
This is suitable for algorithms such as `IMAML <https://arxiv.org/abs/1909.04630>`_ and `DEQ <https://arxiv.org/abs/1909.01377>`_, where the inner-level optimality condition :math:`T` is defined by :math:`\left. \frac{ \partial \mathcal{L}^{\text{in}} (\boldsymbol{\phi}, \boldsymbol{\theta})}{\partial \boldsymbol{\theta}} \right\rvert_{\boldsymbol{\theta} = \boldsymbol{\theta}^{\prime}} = \boldsymbol{0}` or where the inner solver reaches some stationary point.
In both cases, the optimality conditions can be expressed as a root of :math:`T`, i.e., :math:`T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}) = \boldsymbol{0}`.

For example, the function :math:`F` in the figure indicates that the solution is obtained by an unrolled gradient-based fixed-point update:

.. math::

\boldsymbol{\theta}_{k + 1} = F (\boldsymbol{\phi}, \boldsymbol{\theta}_k) = \boldsymbol{\theta}_k - \alpha \nabla_{\boldsymbol{\theta}_k} \mathcal{L}^{\text{in}} (\boldsymbol{\phi}, \boldsymbol{\theta}_k),

and the optimality conditions can be determined by the `KKT conditions <https://en.wikipedia.org/wiki/Karush%E2%80%93Kuhn%E2%80%93Tucker_conditions>`_ of the problem:

.. math::

\boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}) = F (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})) \ \Longrightarrow \ T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})) = \left. \frac{ \partial \mathcal{L}^{\text{in}} (\boldsymbol{\phi}, \boldsymbol{\theta})}{\partial \boldsymbol{\theta}} \right\rvert_{\boldsymbol{\theta} = \boldsymbol{\theta}^{\prime}} = \boldsymbol{0}.
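The fixed-point relation above can be checked numerically on a toy problem. The following is a minimal pure-Python sketch, not TorchOpt code: the scalar inner loss :math:`\mathcal{L}^{\text{in}} (\phi, \theta) = \frac{1}{2} (\theta - 1)^2 + \frac{1}{2} \phi \theta^2` and the helper names ``inner_grad``, ``fixed_point_step``, and ``solve_inner`` are assumptions chosen only for illustration.

```python
# Toy check that the fixed point of the unrolled gradient update F
# is a root of the optimality condition T = dL_in/dtheta.
# Assumed inner loss: L_in(phi, theta) = 0.5 * (theta - 1)^2 + 0.5 * phi * theta^2


def inner_grad(phi, theta):
    """T(phi, theta): derivative of the assumed inner loss w.r.t. theta."""
    return (theta - 1.0) + phi * theta


def fixed_point_step(phi, theta, alpha=0.1):
    """One gradient step: F(phi, theta) = theta - alpha * dL_in/dtheta."""
    return theta - alpha * inner_grad(phi, theta)


def solve_inner(phi, theta0=0.0, alpha=0.1, num_steps=500):
    """Iterate F until (approximate) convergence and return theta'."""
    theta = theta0
    for _ in range(num_steps):
        theta = fixed_point_step(phi, theta, alpha)
    return theta


phi = 2.0
theta_star = solve_inner(phi)

# At the fixed point, theta' = F(phi, theta'), so T(phi, theta') should vanish.
residual = inner_grad(phi, theta_star)

# For this assumed loss the minimizer is known in closed form: 1 / (1 + phi).
closed_form = 1.0 / (1.0 + phi)
```

For this particular loss the iteration is a contraction for the chosen step size, so ``theta_star`` agrees with the closed-form minimizer and ``residual`` is zero up to floating-point error.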

Custom Solvers
--------------
@@ -27,25 +41,31 @@ Custom Solvers

torchopt.diff.implicit.custom_root

Let :math:`T: \mathbb{R}^n \times \mathbb{R}^d \rightarrow \mathbb{R}^d` be a user-provided mapping, capturing the optimality conditions of a problem. An optimal solution, denoted :math:`\boldsymbol{\theta}^{\prime}(\boldsymbol{\phi})`, should be a root of :math:`T` :
Let :math:`T (\boldsymbol{\phi}, \boldsymbol{\theta}): \mathbb{R}^n \times \mathbb{R}^d \to \mathbb{R}^d` be a user-provided mapping function that captures the optimality conditions of a problem.
An optimal solution, denoted :math:`\boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})`, should be a root of :math:`T`:

.. math::
T(\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi}))=0 \text {. }

We can see :math:`\boldsymbol{\theta}^{\prime}(\boldsymbol{\phi})` as an implicitly defined function of :math:`\boldsymbol{\phi} \in \mathbb{R}^n`, i.e., :math:`\boldsymbol{\theta}^{\prime}: \mathbb{R}^n \rightarrow \mathbb{R}^d`.
More precisely, from the implicit function theorem, we know that for :math:`\left(\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0\right)` satisfying :math:`T\left(\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0\right)=0` with a continuously differentiable :math:`T`,
if the Jacobian :math:`\nabla_{\boldsymbol{\theta}^{\prime}} T` evaluated at :math:`\left(\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0\right)` is a square invertible matrix, then there exists a function :math:`\boldsymbol{\theta}^{\prime}(\cdot)` defined on a neighborhood of :math:`\boldsymbol{\phi}_0` such that :math:`\boldsymbol{\theta}^{\prime}(\boldsymbol{\phi}_0)=\boldsymbol{\theta}^{\prime}_0`.
Furthermore, for all :math:`\boldsymbol{\phi}` in this neighborhood, we have that :math:`T(\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0)=0` and :math:`\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi})` exists. Using the chain rule, the Jacobian :math:`\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi})` satisfies
T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi})) = \boldsymbol{0}.

We can see :math:`\boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})` as an implicitly defined function of :math:`\boldsymbol{\phi} \in \mathbb{R}^n`, i.e., :math:`\boldsymbol{\theta}^{\prime}: \mathbb{R}^n \rightarrow \mathbb{R}^d`.
More precisely, from the `implicit function theorem <https://en.wikipedia.org/wiki/Implicit_function_theorem>`_, we know that for :math:`(\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0)` satisfying :math:`T (\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0) = \boldsymbol{0}` with a continuously differentiable :math:`T`, if the Jacobian :math:`\nabla_{\boldsymbol{\theta}^{\prime}} T` evaluated at :math:`(\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0)` is a square invertible matrix, then there exists a function :math:`\boldsymbol{\theta}^{\prime} (\cdot)` defined on a neighborhood of :math:`\boldsymbol{\phi}_0` such that :math:`\boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}_0) = \boldsymbol{\theta}^{\prime}_0`.
Furthermore, for all :math:`\boldsymbol{\phi}` in this neighborhood, we have that :math:`T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})) = \boldsymbol{0}` and :math:`\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})` exists. Using the chain rule, the Jacobian :math:`\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})` satisfies:

.. math::
\nabla_{\boldsymbol{\theta}^{\prime}} T(\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi})) \nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi})+\nabla_{\boldsymbol{\phi}} T(\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi}))=0 .

\frac{d T}{d \boldsymbol{\phi}} = \underbrace{\nabla_{\boldsymbol{\theta}^{\prime}} T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi}))}_{\frac{\partial T}{\partial \boldsymbol{\theta}^{\prime}}} \underbrace{\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})}_{\frac{d \boldsymbol{\theta}^{\prime}}{d \boldsymbol{\phi}}} + \underbrace{\nabla_{\boldsymbol{\phi}} T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}))}_{\frac{\partial T}{\partial \boldsymbol{\phi}}} = \boldsymbol{0}. \qquad ( T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}) = \boldsymbol{0} = \text{const})
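
As a small worked instance of this identity, consider a scalar inner loss :math:`\mathcal{L}^{\text{in}} (\phi, \theta) = \frac{1}{2} (\theta - 1)^2 + \frac{1}{2} \phi \theta^2`. This loss is an assumption chosen for illustration and does not come from TorchOpt. Its optimality condition and the resulting derivative are:

.. math::

    \begin{aligned}
        T (\phi, \theta) & = \frac{\partial \mathcal{L}^{\text{in}}}{\partial \theta} = (\theta - 1) + \phi \theta, \qquad \frac{\partial T}{\partial \theta^{\prime}} = 1 + \phi, \qquad \frac{\partial T}{\partial \phi} = \theta^{\prime}, \\
        (1 + \phi) \frac{d \theta^{\prime}}{d \phi} + \theta^{\prime} & = 0 \ \Longrightarrow \ \frac{d \theta^{\prime}}{d \phi} = - \frac{\theta^{\prime}}{1 + \phi} = - \frac{1}{(1 + \phi)^2},
    \end{aligned}

which agrees with differentiating the closed-form solution :math:`\theta^{\prime} (\phi) = 1 / (1 + \phi)` directly.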

Computing :math:`\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi})` therefore boils down to the resolution of the linear system of equations

.. math::
\underbrace{-\nabla_{\boldsymbol{\theta}^{\prime}} T(\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi}))}_{A \in \mathbb{R}^{d \times d}} \underbrace{\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi})}_{J \in \mathbb{R}^{d \times n}}=\underbrace{\nabla_{\boldsymbol{\phi}} T(\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi}))}_{B \in \mathbb{R}^{d \times n}} \text {. }

TorchOpt provides the :func:`custom_root` decorators, for easily adding implicit differentiation on top of any existing solver (also called forward optimization). :func:`custom_root` requires users to define the stationary conditions for the problem solution (e.g., KKT conditions) and will automatically calculate the gradient for backward gradient computation. Here is an example of the :func:`custom_root` decorators, which is also the **functional API** for implicit gradient.
\underbrace{\nabla_{\boldsymbol{\theta}^{\prime}} T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi}))}_{A \in \mathbb{R}^{d \times d}} \underbrace{\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})}_{J \in \mathbb{R}^{d \times n}} = \underbrace{- \nabla_{\boldsymbol{\phi}} T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}))}_{B \in \mathbb{R}^{d \times n}}.
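
To make the linear system :math:`A J = B` concrete, here is a minimal pure-Python sketch on a two-dimensional quadratic problem. It is not how TorchOpt implements the solve: the inner loss :math:`\mathcal{L}^{\text{in}} (\boldsymbol{\phi}, \boldsymbol{\theta}) = \frac{1}{2} \boldsymbol{\theta}^{\top} Q \boldsymbol{\theta} - \boldsymbol{\phi}^{\top} \boldsymbol{\theta}`, the matrix ``Q``, and all helper names are assumptions made for illustration. With this loss, :math:`T = Q \boldsymbol{\theta} - \boldsymbol{\phi}`, so :math:`A = Q` and :math:`B = - \nabla_{\boldsymbol{\phi}} T = I`.

```python
# Implicit Jacobian d(theta')/d(phi) from the linear system A J = B,
# compared against finite differences through an unrolled inner solver.
# Assumed inner loss: L_in(phi, theta) = 0.5 * theta^T Q theta - phi^T theta

Q = [[2.0, 0.5], [0.5, 1.0]]  # symmetric positive definite (assumed)


def T(phi, theta):
    """Optimality condition T(phi, theta) = Q @ theta - phi."""
    return [sum(Q[i][k] * theta[k] for k in range(2)) - phi[i] for i in range(2)]


def solve_inner(phi, alpha=0.1, num_steps=2000):
    """Inner solver: plain gradient descent on the assumed loss."""
    theta = [0.0, 0.0]
    for _ in range(num_steps):
        grad = T(phi, theta)
        theta = [theta[i] - alpha * grad[i] for i in range(2)]
    return theta


def solve_2x2(A, B):
    """Solve A @ J = B for J using the explicit 2x2 inverse."""
    (a, b), (c, d) = A
    det = a * d - b * c
    inv = [[d / det, -b / det], [-c / det, a / det]]
    return [[sum(inv[i][k] * B[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]


def finite_difference_jacobian(phi, eps=1e-6):
    """Approximate d(theta')/d(phi) by re-solving with perturbed phi."""
    base = solve_inner(phi)
    J_fd = [[0.0, 0.0], [0.0, 0.0]]
    for j in range(2):
        shifted = list(phi)
        shifted[j] += eps
        perturbed = solve_inner(shifted)
        for i in range(2):
            J_fd[i][j] = (perturbed[i] - base[i]) / eps
    return J_fd


phi = [1.0, -2.0]
A = Q                          # dT/dtheta evaluated at the solution
B = [[1.0, 0.0], [0.0, 1.0]]   # -dT/dphi = -(-I) = I
J = solve_2x2(A, B)            # implicit Jacobian, no unrolling needed
J_fd = finite_difference_jacobian(phi)
```

The implicit Jacobian ``J`` only needs derivatives of :math:`T` at the solution, whereas ``J_fd`` re-runs the inner solver once per perturbed input. In practice the system is usually solved with an iterative linear solver rather than an explicit inverse.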

TorchOpt provides a decorator function :func:`custom_root` for easily adding implicit differentiation on top of any existing inner optimization solver (also called forward optimization).
The :func:`custom_root` decorator requires users to define the stationary conditions for the problem solution (e.g., `KKT conditions <https://en.wikipedia.org/wiki/Karush%E2%80%93Kuhn%E2%80%93Tucker_conditions>`_) and automatically computes the implicit gradient during the backward pass.

Here is an example of the :func:`custom_root` decorator, which is also the **functional API** for implicit gradients.

.. code-block:: python
