docs(implicit_diff): implicit differentiation integration #73


Merged
merged 25 commits into main from docs/implicit_gradient on Sep 22, 2022
Changes from 1 commit
Commits (25)
14dbb31
docs: init implicit differentiation integration
Benjamin-eecs Sep 9, 2022
0d76585
fix: linear solve docs error remains
Benjamin-eecs Sep 9, 2022
a9cc777
feat(tutorials): add implicit differentiation
Benjamin-eecs Sep 10, 2022
672005a
fix(tutorials): update torchopt import
Benjamin-eecs Sep 11, 2022
6cdd1f0
docs: pass api docs
Benjamin-eecs Sep 11, 2022
9a261f9
docs: pass api docs
Benjamin-eecs Sep 11, 2022
e915753
docs: pass api docs
Benjamin-eecs Sep 11, 2022
d5564b7
fix(implicit): remove argument
JieRen98 Sep 11, 2022
a47beb0
docs: update `custom_root` docstring
XuehaiPan Sep 13, 2022
e4f512f
Merge branch 'main' into docs/implicit_gradient
XuehaiPan Sep 13, 2022
5cf9018
docs: update colab links
Benjamin-eecs Sep 15, 2022
37298d5
Merge branch 'main' into docs/implicit_gradient
Benjamin-eecs Sep 22, 2022
4c0b69b
Merge branch 'main' into docs/implicit_gradient
XuehaiPan Sep 22, 2022
ae89467
docs(implicit): update docstrings for `custom_root`
XuehaiPan Sep 22, 2022
4a36212
docs(CHANGELOG): update CHANGELOG.md
XuehaiPan Sep 22, 2022
623324b
docs(CHANGELOG): update CHANGELOG.md
XuehaiPan Sep 22, 2022
059fc79
docs(implicit): update tutorial
XuehaiPan Sep 22, 2022
84e06b2
docs(implicit): update docstrings
XuehaiPan Sep 22, 2022
df764cf
docs(README): update future plan
XuehaiPan Sep 22, 2022
8b6a945
chore: update gitignore
Benjamin-eecs Sep 22, 2022
a043f7b
chore: update makefile
Benjamin-eecs Sep 22, 2022
956a780
docs: update dictionary
XuehaiPan Sep 22, 2022
504d699
Merge branch 'main' into docs/implicit_gradient
XuehaiPan Sep 22, 2022
d881334
docs(implicit): update docstrings
XuehaiPan Sep 22, 2022
37ea557
fix(implicit): fix has_aux when result is single tensor
XuehaiPan Sep 22, 2022
docs(implicit): update tutorial
XuehaiPan committed Sep 22, 2022
commit 059fc79930398c9ae1fb2ab9ea818e7bd8d9f10e
136 changes: 67 additions & 69 deletions tutorials/5_Implicit_Differentiation.ipynb
@@ -26,7 +26,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 1,
"id": "8f13ae67-e328-409f-84a8-1fc425c03a66",
"metadata": {},
"outputs": [],
@@ -35,7 +35,6 @@
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import jax\n",
"\n",
"import torchopt\n",
"from torchopt import implicit_diff, sgd"
@@ -48,17 +47,7 @@
"source": [
"## 1. Basic API\n",
"\n",
"The basic API is **implicit_diff**, which is used as the decorator for the forward process implicit gradient procedures."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9a0a5b5a-8189-446e-ac9c-0594470df3c9",
"metadata": {},
"outputs": [],
"source": [
"from torchopt import implicit_diff, sgd"
"The basic API is `torchopt.implicit_diff`, which is used as the decorator for the forward process implicit gradient procedures."
]
},
{
@@ -70,62 +59,68 @@
"For IMAML, the inner-loop objective is described by the following equation.\n",
"\n",
"$$\n",
"\\mathcal{A} l g^{\\star}\\left(\\boldsymbol{\\theta}, \\mathcal{D}_{i}^{\\operatorname{tr}}\\right)=\\underset{\\phi^{\\prime} \\in \\Phi}{\\operatorname{argmin}} \\mathcal{L}\\left(\\boldsymbol{\\phi}^{\\prime}, \\mathcal{D}_{i}^{\\operatorname{tr}}\\right)+\\frac{\\lambda}{2}\\left\\|\\boldsymbol{\\phi}^{\\prime}-\\boldsymbol{\\theta}\\right\\|^{2}\n",
"{\\mathcal{Alg}}^{\\star} \\left( \\boldsymbol{\\theta}, \\mathcal{D}_{i}^{\\text{tr}} \\right) = \\underset{\\phi' \\in \\Phi}{\\operatorname{\\arg \\min}} ~ G \\left( \\boldsymbol{\\phi}', \\boldsymbol{\\theta} \\right) \\triangleq \\mathcal{L} \\left( \\boldsymbol{\\phi}', \\mathcal{D}_{i}^{\\text{tr}} \\right) + \\frac{\\lambda}{2} {\\left\\| \\boldsymbol{\\phi}' - \\boldsymbol{\\theta} \\right\\|}^{2}\n",
"$$\n",
"\n",
"According to this function, we can define the forward function **inner_solver**, where we solve this equation based on sufficient gradient descents. For such inner-loop process, the optimality condition is that the gradient w.r.t inner-loop parameter is 0.\n",
"According to this function, we can define the forward function `inner_solver`, where we solve this equation based on sufficient gradient descents. For such inner-loop process, the optimality condition is that the gradient w.r.t inner-loop parameter is $0$.\n",
"\n",
"$$\n",
"\\left.\\nabla_{\\boldsymbol{\\phi}^{\\prime}} G\\left(\\boldsymbol{\\phi}^{\\prime}, \\boldsymbol{\\theta}\\right)\\right|_{\\phi^{\\prime}=\\boldsymbol{\\phi}}=0\n",
"{\\left. \\nabla_{\\boldsymbol{\\phi}'} G \\left( \\boldsymbol{\\phi}', \\boldsymbol{\\theta} \\right) \\right|}_{\\boldsymbol{\\phi}' = \\boldsymbol{\\phi}^{\\star}} = 0\n",
"$$\n",
"\n",
"Thus we can define the optimality function by defining **imaml_objective** and make it first-order gradient w.r.t the inner-loop parameter as 0. We achieve so by calling out **functorch.grad(imaml_objective, argnums=0)**. Finally, the forward function is decorated by the **@implicit_diff** and the optimalit condition we define."
"Thus we can define the optimality function by defining `imaml_objective` and make it first-order gradient w.r.t the inner-loop parameter as $0$. We achieve so by calling out `functorch.grad(imaml_objective, argnums=0)`. Finally, the forward function is decorated by the `@implicit_diff.custom_root` and the optimality condition we define."
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 2,
"id": "8d623b2f-48ee-4df6-a2ce-cf306b4c9067",
"metadata": {},
"outputs": [],
"source": [
"# Optimality function\n",
"def imaml_objective(optimal_params, init_params, data):\n",
" x, y, f = data\n",
" y_pred = f(optimal_params, x)\n",
" regularisation_loss = 0\n",
" x, y, fmodel = data\n",
" y_pred = fmodel(optimal_params, x)\n",
" regularization_loss = 0.0\n",
" for p1, p2 in zip(optimal_params, init_params):\n",
" regularisation_loss += 0.5 * torch.sum((p1.view(-1) - p2.view(-1))**2)\n",
" loss = F.mse_loss(y_pred, y) + regularisation_loss\n",
" return loss \n",
" regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n",
" loss = F.mse_loss(y_pred, y) + regularization_loss\n",
" return loss\n",
"\n",
"# Optimality Condition is: the gradient w.r.t inner-loop optimal params is 0 (we achieve so by specifying argnums=0 in functorch.grad)\n",
"# the argnums=1 specify which meta-parameter we want to backpropogate, in this case we want to backpropogate to the initial parameters\n",
"# so we set it as 1.\n",
"# You can also set argnums as (1,2) if you want to backpropogate through multiple meta parameters\n",
"# Optimality Condition is: the gradient w.r.t inner-loop optimal params is 0 (we achieve so by\n",
"# specifying argnums=0 in functorch.grad) the argnums=1 specify which meta-parameter we want to\n",
"# backpropogate, in this case we want to backpropogate to the initial parameters so we set it as 1.\n",
"# You can also set argnums as (1, 2) if you want to backpropogate through multiple meta parameters\n",
"\n",
"# Here we pass argnums=1 to the custom_root. That means we want to compute the gradient of\n",
"# optimal_params w.r.t. the 1-indexed argument in inner_solver, i.e., init_params.\n",
"@implicit_diff.custom_root(functorch.grad(imaml_objective, argnums=0), argnums=1)\n",
"def inner_solver(init_params_copy, init_params, data):\n",
" \"\"\"Solve ridge regression by conjugate gradient.\"\"\"\n",
" # inital functional optimizer based on torchopt\n",
" x, y, f = data\n",
" # Initial functional optimizer based on TorchOpt\n",
" x, y, fmodel = data\n",
" params = init_params_copy\n",
" optimizer = sgd(lr=2e-2)\n",
" opt_state = optimizer.init(params)\n",
" with torch.enable_grad():\n",
" # temporarily enable gradient computation for conducting the optimization\n",
" # Temporarily enable gradient computation for conducting the optimization\n",
" for i in range(100):\n",
" pred = f(params, x) \n",
" loss = F.mse_loss(pred, y) # compute loss\n",
" regularisation_loss = 0\n",
" # compute regularisation loss\n",
" pred = fmodel(params, x) \n",
" loss = F.mse_loss(pred, y) # compute loss\n",
" \n",
" # Compute regularization loss\n",
" regularization_loss = 0.0\n",
" for p1, p2 in zip(params, init_params):\n",
" regularisation_loss += 0.5 * torch.sum((p1.view(-1) - p2.view(-1))**2)\n",
" final_loss = loss + regularisation_loss\n",
" grads = torch.autograd.grad(final_loss, params) # compute gradients\n",
" regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n",
" final_loss = loss + regularization_loss\n",
" \n",
" grads = torch.autograd.grad(final_loss, params) # compute gradients\n",
" updates, opt_state = optimizer.update(grads, opt_state) # get updates\n",
" params = TorchOpt.apply_updates(params, updates) \n",
" return params"
" params = torchopt.apply_updates(params, updates)\n",
"\n",
" optimal_params = params\n",
" return optimal_params"
]
},
{
@@ -138,15 +133,16 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 3,
"id": "fb95538b-1fd9-4ec8-9f57-6360bedc05b7",
"metadata": {},
"outputs": [],
"source": [
"torch.manual_seed(0)\n",
"x = torch.randn(20, 4)\n",
"w = torch.randn(4, 1)\n",
"b = torch.randn(1)\n",
"y = x @ w + b + torch.randn(20, 1) * 0.5"
"y = x @ w + b + 0.5 * torch.randn(20, 1)"
]
},
{
@@ -159,7 +155,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 4,
"id": "d50a7bfe-ac69-4089-8cf8-3cbd69d6d4e7",
"metadata": {
"tags": []
@@ -175,20 +171,20 @@
"\n",
" def forward(self, x):\n",
" return self.fc(x)\n",
" \n",
"\n",
"model = Net(4)\n",
"f, p = functorch.make_functional(model)\n",
"data = (x, y, f)\n",
"fmodel, meta_params = functorch.make_functional(model)\n",
"data = (x, y, fmodel)\n",
"\n",
"# clone function for \n",
"def clone(p):\n",
" p_out = []\n",
" for item in p:\n",
"# clone function for parameters\n",
"def clone(params):\n",
" cloned = []\n",
" for item in params:\n",
" if isinstance(item, torch.Tensor):\n",
" p_out.append(item.clone().detach_().requires_grad_(True))\n",
" cloned.append(item.clone().detach_().requires_grad_(True))\n",
" else:\n",
" p_out.append(item)\n",
" return tuple(p_out)"
" cloned.append(item)\n",
" return tuple(cloned)"
]
},
{
@@ -201,16 +197,16 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 5,
"id": "115e79c6-911f-4743-a2ed-e50a71c3a813",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"optimal_params = inner_solver(clone(p), p, data)\n",
"optimal_params = inner_solver(clone(meta_params), meta_params, data)\n",
"\n",
"outer_loss = f(optimal_params, x).mean()"
"outer_loss = fmodel(optimal_params, x).mean()"
]
},
{
@@ -223,31 +219,28 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 6,
"id": "6bdcbe8d-2336-4f80-b124-eb43c5a2fc0a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[-0.0582, -0.0163, 0.0379, -0.0265]]), tensor([0.2984]))"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
"name": "stdout",
"output_type": "stream",
"text": [
"(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n"
]
}
],
"source": [
"torch.autograd.grad(outer_loss, p)"
"torch.autograd.grad(outer_loss, meta_params)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:torchopt] *",
"display_name": "Python 3.8.12 ('torchopt')",
"language": "python",
"name": "conda-env-torchopt-py"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -259,7 +252,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.8.12"
},
"vscode": {
"interpreter": {
"hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da"
}
}
},
"nbformat": 4,
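To summarize the pattern this tutorial update documents, here is a condensed, self-contained sketch that mirrors the notebook code in the diff above: an optimality condition built with `functorch.grad` and a solver decorated with `implicit_diff.custom_root`, so that meta-gradients flow through the inner solution. It only assumes the APIs as they appear in the tutorial (`functorch.make_functional`, `functorch.grad`, `torchopt.sgd`, `torchopt.apply_updates`, `implicit_diff.custom_root`); the names `proximal_objective`, `solve_inner`, and `init_copy` are illustrative, not part of TorchOpt.

```python
import functorch
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchopt
from torchopt import implicit_diff


def proximal_objective(optimal_params, meta_params, data):
    # G(phi', theta) = MSE(fmodel(phi', x), y) + (lambda / 2) * ||phi' - theta||^2, with lambda = 1
    x, y, fmodel = data
    y_pred = fmodel(optimal_params, x)
    regularization = 0.0
    for phi, theta in zip(optimal_params, meta_params):
        regularization = regularization + 0.5 * torch.sum(torch.square(phi - theta))
    return F.mse_loss(y_pred, y) + regularization


# Optimality condition: the gradient of G w.r.t. the inner parameters (argnums=0) vanishes at the
# solution. argnums=1 tells custom_root to differentiate that solution w.r.t. meta_params.
@implicit_diff.custom_root(functorch.grad(proximal_objective, argnums=0), argnums=1)
def solve_inner(inner_params, meta_params, data):
    """Approximate the inner-loop solution with plain gradient descent."""
    optimizer = torchopt.sgd(lr=2e-2)
    opt_state = optimizer.init(inner_params)
    for _ in range(100):
        with torch.enable_grad():  # the decorated forward may run without grad tracking
            loss = proximal_objective(inner_params, meta_params, data)
            grads = torch.autograd.grad(loss, inner_params)
        updates, opt_state = optimizer.update(grads, opt_state)
        inner_params = torchopt.apply_updates(inner_params, updates)
    return inner_params


torch.manual_seed(0)
x = torch.randn(20, 4)
y = x @ torch.randn(4, 1) + torch.randn(1) + 0.5 * torch.randn(20, 1)

fmodel, meta_params = functorch.make_functional(nn.Linear(4, 1))
data = (x, y, fmodel)

# Start the inner solve from a detached copy of the meta-parameters, so the only path from the
# outer loss back to meta_params is the implicit-differentiation path registered by custom_root.
init_copy = tuple(p.clone().detach().requires_grad_(True) for p in meta_params)
optimal_params = solve_inner(init_copy, meta_params, data)

outer_loss = fmodel(optimal_params, x).mean()
meta_grads = torch.autograd.grad(outer_loss, meta_params)
print(meta_grads)
```

The key design choice, as in the tutorial, is that the inner solver never differentiates through its unrolled updates; `custom_root` recovers the meta-gradient from the optimality condition instead.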