diff --git a/CHANGELOG.md b/CHANGELOG.md
index e6c23138..644cd7eb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
+- Add implicit MAML omniglot few-shot classification example with OOP APIs by [@XuehaiPan](https://github.com/XuehaiPan) in [#107](https://github.com/metaopt/torchopt/pull/107).
- Add implicit MAML omniglot few-shot classification example by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#48](https://github.com/metaopt/torchopt/pull/48).
- Add object-oriented modules support for implicit meta-gradient by [@XuehaiPan](https://github.com/XuehaiPan) in [#101](https://github.com/metaopt/torchopt/pull/101).
- Bump PyTorch version to 1.13.0 by [@XuehaiPan](https://github.com/XuehaiPan) in [#104](https://github.com/metaopt/torchopt/pull/104).
diff --git a/README.md b/README.md
index 719f4442..1bcbf548 100644
--- a/README.md
+++ b/README.md
@@ -273,10 +273,10 @@ make install-editable # or run `pip3 install --no-build-isolation --editable .`
## Future Plan
- [X] CPU-accelerated optimizer
-- [X] Support general implicit differentiation with functional programing
- [X] Support more optimizers such as AdamW, RMSProp
-- [ ] Zero order optimization
-- [ ] Distributed optimizers
+- [X] Support general implicit differentiation
+- [X] Zero order optimization
+- [X] Distributed optimization
- [ ] Support `complex` data type
## Changelog
diff --git a/docs/source/examples/MAML.rst b/docs/source/examples/MAML.rst
index 390e45cc..87891c6a 100644
--- a/docs/source/examples/MAML.rst
+++ b/docs/source/examples/MAML.rst
@@ -99,8 +99,7 @@ Define the ``train`` function:
# Sample a batch of support and query images and labels.
x_spt, y_spt, x_qry, y_qry = db.next()
- task_num, setsz, c_, h, w = x_spt.size()
- querysz = x_qry.size(1)
+ task_num = x_spt.size(0)
# TODO: Maybe pull this out into a separate module so it
# doesn't have to be duplicated between `train` and `test`?
@@ -129,9 +128,9 @@ Define the ``train`` function:
# These will be used to update the model's meta-parameters.
qry_logits = net(x_qry[i])
qry_loss = F.cross_entropy(qry_logits, y_qry[i])
- qry_losses.append(qry_loss.detach())
- qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz
- qry_accs.append(qry_acc)
+ qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean()
+ qry_losses.append(qry_loss)
+ qry_accs.append(qry_acc.item())
torchopt.recover_state_dict(net, net_state_dict)
torchopt.recover_state_dict(inner_opt, optim_state_dict)
@@ -139,15 +138,14 @@ Define the ``train`` function:
qry_losses = torch.mean(torch.stack(qry_losses))
qry_losses.backward()
meta_opt.step()
- qry_losses = sum(qry_losses) / task_num
- qry_accs = 100. * sum(qry_accs) / task_num
+ qry_losses = qry_losses.item()
+ qry_accs = 100.0 * np.mean(qry_accs)
i = epoch + float(batch_idx) / n_train_iter
iter_time = time.time() - start_time
print(
f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
)
-
log.append(
{
'epoch': i,
@@ -181,8 +179,7 @@ Define the ``test`` function:
for batch_idx in range(n_test_iter):
x_spt, y_spt, x_qry, y_qry = db.next('test')
- task_num, setsz, c_, h, w = x_spt.size()
- querysz = x_qry.size(1)
+ task_num = x_spt.size(0)
# TODO: Maybe pull this out into a separate module so it
# doesn't have to be duplicated between `train` and `test`?
@@ -201,15 +198,17 @@ Define the ``test`` function:
# The query loss and acc induced by these parameters.
qry_logits = net(x_qry[i]).detach()
- qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none')
- qry_losses.append(qry_loss.detach())
- qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())
+ qry_loss = F.cross_entropy(qry_logits, y_qry[i])
+ qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean()
+ qry_losses.append(qry_loss.item())
+ qry_accs.append(qry_acc.item())
torchopt.recover_state_dict(net, net_state_dict)
torchopt.recover_state_dict(inner_opt, optim_state_dict)
- qry_losses = torch.mean(torch.stack(qry_losses)).item()
- qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
+ qry_losses = np.mean(qry_losses)
+ qry_accs = 100.0 * np.mean(qry_accs)
+
print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
log.append(
{
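The hunks in this patch repeatedly replace the manual accuracy computation `(pred == y).sum().item() / querysz` with `(pred == y).float().mean()`. Below is a short, self-contained sanity check that the two forms agree; the tensors are made up for illustration and do not come from the examples.

```python
import torch

# Hypothetical query batch: 8 examples, 5-way classification.
qry_logits = torch.randn(8, 5)
y_qry = torch.randint(0, 5, (8,))
querysz = y_qry.size(0)

# Old form: count correct predictions and divide by the query-set size.
old_acc = (qry_logits.argmax(dim=1) == y_qry).sum().item() / querysz
# New form: cast the boolean match tensor to float and take its mean.
new_acc = (qry_logits.argmax(dim=1) == y_qry).float().mean().item()

assert abs(old_acc - new_acc) < 1e-6
```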
diff --git a/examples/FuncTorch/maml_omniglot_vmap.py b/examples/FuncTorch/maml_omniglot_vmap.py
index 9bbb30ce..9f3c1d92 100644
--- a/examples/FuncTorch/maml_omniglot_vmap.py
+++ b/examples/FuncTorch/maml_omniglot_vmap.py
@@ -39,16 +39,10 @@
https://github.com/bamos/HowToTrainYourMAMLPytorch
"""
-
-import os
-import sys
-
-
-cur = os.path.abspath(os.path.dirname(__file__))
-root = os.path.split(cur)[0]
-sys.path.append(root + '/few-shot')
import argparse
import functools
+import pathlib
+import sys
import time
import functorch
@@ -59,12 +53,17 @@
import torch
import torch.nn.functional as F
import torch.optim as optim
-from support.omniglot_loaders import OmniglotNShot
from torch import nn
import torchopt
+CWD = pathlib.Path(__file__).absolute().parent
+sys.path.append(str(CWD.parent / 'few-shot'))
+
+from helpers.omniglot_loaders import OmniglotNShot
+
+
mpl.use('Agg')
plt.style.use('bmh')
@@ -148,8 +147,6 @@ def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry):
opt = torchopt.sgd(lr=1e-1)
opt_state = opt.init(params)
- querysz = x_qry.size(0)
-
def compute_loss(new_params, buffers, x, y):
logits = fnet(new_params, buffers, x)
loss = F.cross_entropy(logits, y)
@@ -167,7 +164,7 @@ def compute_loss(new_params, buffers, x, y):
# These will be used to update the model's meta-parameters.
qry_logits = fnet(new_params, buffers, x_qry)
qry_loss = F.cross_entropy(qry_logits, y_qry)
- qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum() / querysz
+ qry_acc = (qry_logits.argmax(dim=1) == y_qry).float().mean()
return qry_loss, qry_acc
@@ -192,18 +189,19 @@ def train(db, net, device, meta_opt, epoch, log):
qry_losses, qry_accs = functorch.vmap(compute_loss_for_task)(x_spt, y_spt, x_qry, y_qry)
# Compute the maml loss by summing together the returned losses.
- qry_losses.sum().backward()
-
+ qry_losses = qry_losses.mean()
+ qry_losses.backward()
meta_opt.step()
- qry_losses = qry_losses.detach().sum() / task_num
- qry_accs = 100.0 * qry_accs.sum() / task_num
+ qry_losses = qry_losses.item()
+ qry_accs = 100.0 * qry_accs.mean().item()
i = epoch + float(batch_idx) / n_train_iter
iter_time = time.time() - start_time
+ torch.cuda.empty_cache()
+
if batch_idx % 4 == 0:
print(
f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
)
-
log.append(
{
'epoch': i,
@@ -249,8 +247,10 @@ def test(db, net, device, epoch, log):
qry_losses.append(qry_loss.detach())
qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())
- qry_losses = torch.cat(qry_losses).mean().item()
- qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
+ qry_losses = torch.mean(torch.stack(qry_losses)).item()
+ qry_accs = 100.0 * torch.stack(qry_accs).float().mean().item()
+ torch.cuda.empty_cache()
+
print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
log.append(
{
diff --git a/examples/L2R/helper/argument.py b/examples/L2R/helpers/argument.py
similarity index 100%
rename from examples/L2R/helper/argument.py
rename to examples/L2R/helpers/argument.py
diff --git a/examples/L2R/helper/model.py b/examples/L2R/helpers/model.py
similarity index 100%
rename from examples/L2R/helper/model.py
rename to examples/L2R/helpers/model.py
diff --git a/examples/L2R/helper/utils.py b/examples/L2R/helpers/utils.py
similarity index 100%
rename from examples/L2R/helper/utils.py
rename to examples/L2R/helpers/utils.py
diff --git a/examples/L2R/l2r.py b/examples/L2R/l2r.py
index 989b71cc..e77faa14 100644
--- a/examples/L2R/l2r.py
+++ b/examples/L2R/l2r.py
@@ -39,9 +39,9 @@
# isort: off
-from helper.argument import parse_args
-from helper.model import LeNet5
-from helper.utils import get_imbalance_dataset, plot, set_seed
+from helpers.argument import parse_args
+from helpers.model import LeNet5
+from helpers.utils import get_imbalance_dataset, plot, set_seed
def run_baseline(args, mnist_train, mnist_test):
diff --git a/examples/LOLA/helper/agent.py b/examples/LOLA/helpers/agent.py
similarity index 100%
rename from examples/LOLA/helper/agent.py
rename to examples/LOLA/helpers/agent.py
diff --git a/examples/LOLA/helper/argument.py b/examples/LOLA/helpers/argument.py
similarity index 100%
rename from examples/LOLA/helper/argument.py
rename to examples/LOLA/helpers/argument.py
diff --git a/examples/LOLA/helper/env.py b/examples/LOLA/helpers/env.py
similarity index 100%
rename from examples/LOLA/helper/env.py
rename to examples/LOLA/helpers/env.py
diff --git a/examples/LOLA/helper/utils.py b/examples/LOLA/helpers/utils.py
similarity index 100%
rename from examples/LOLA/helper/utils.py
rename to examples/LOLA/helpers/utils.py
diff --git a/examples/LOLA/lola_dice.py b/examples/LOLA/lola_dice.py
index 7b1417a4..4b6b2567 100644
--- a/examples/LOLA/lola_dice.py
+++ b/examples/LOLA/lola_dice.py
@@ -21,10 +21,10 @@
# isort: off
-from helper.agent import Agent
-from helper.argument import parse_args
-from helper.env import IPD
-from helper.utils import sample, step
+from helpers.agent import Agent
+from helpers.argument import parse_args
+from helpers.env import IPD
+from helpers.utils import sample, step
def main(args):
diff --git a/examples/distributed/few-shot/support/omniglot_loaders.py b/examples/distributed/few-shot/helpers/omniglot_loaders.py
similarity index 100%
rename from examples/distributed/few-shot/support/omniglot_loaders.py
rename to examples/distributed/few-shot/helpers/omniglot_loaders.py
diff --git a/examples/distributed/few-shot/maml_omniglot.py b/examples/distributed/few-shot/maml_omniglot.py
index 22c8bda1..78a85d71 100644
--- a/examples/distributed/few-shot/maml_omniglot.py
+++ b/examples/distributed/few-shot/maml_omniglot.py
@@ -58,7 +58,7 @@
import torchopt.distributed as todist
-from support.omniglot_loaders import OmniglotNShot # isort: skip
+from helpers.omniglot_loaders import OmniglotNShot # isort: skip
mpl.use('Agg')
@@ -187,7 +187,6 @@ def inner_loop(net_rref, x_spt, y_spt, x_qry, y_qry, n_inner_iter):
x_qry = x_qry.to(device)
y_qry = y_qry.to(device)
- querysz = x_qry.size(0)
inner_opt = torchopt.MetaSGD(net, lr=1e-1)
for _ in range(n_inner_iter):
@@ -197,7 +196,7 @@ def inner_loop(net_rref, x_spt, y_spt, x_qry, y_qry, n_inner_iter):
qry_logits = net(x_qry)
qry_loss = F.cross_entropy(qry_logits, y_qry).cpu()
- qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum().cpu().item() / querysz
+ qry_acc = (qry_logits.argmax(dim=1) == y_qry).float().mean().cpu().item()
return qry_loss, qry_acc
@@ -232,11 +231,11 @@ def train(db: OmniglotNShot, net: nn.Module, meta_opt: optim.Adam, epoch: int, l
qry_acc = 100.0 * qry_acc
i = epoch + float(batch_idx) / n_train_iter
iter_time = time.time() - start_time
+ torch.cuda.empty_cache()
print(
f'[Epoch {i:.2f}] Train Loss: {qry_loss:.2f} | Acc: {qry_acc:.2f} | Time: {iter_time:.2f}'
)
-
log.append(
{
'epoch': i,
@@ -275,6 +274,8 @@ def test(db, net, epoch, log):
qry_losses = np.mean(qry_losses)
qry_accs = 100.0 * np.mean(qry_accs)
+ torch.cuda.empty_cache()
+
print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
log.append(
{
diff --git a/examples/distributed/few-shot/maml_omniglot_local_loader.py b/examples/distributed/few-shot/maml_omniglot_local_loader.py
index 9c36c6ad..d41a7dae 100644
--- a/examples/distributed/few-shot/maml_omniglot_local_loader.py
+++ b/examples/distributed/few-shot/maml_omniglot_local_loader.py
@@ -60,7 +60,7 @@
import torchopt.distributed as todist
-from support.omniglot_loaders import OmniglotNShot # isort: skip
+from helpers.omniglot_loaders import OmniglotNShot # isort: skip
mpl.use('Agg')
@@ -228,7 +228,6 @@ def inner_loop(net_rref, n_inner_iter, task_id, task_num, mode):
x_qry = x_qry.to(device)
y_qry = y_qry.to(device)
- querysz = x_qry.size(0)
inner_opt = torchopt.MetaSGD(net, lr=1e-1)
for _ in range(n_inner_iter):
@@ -238,7 +237,7 @@ def inner_loop(net_rref, n_inner_iter, task_id, task_num, mode):
qry_logits = net(x_qry)
qry_loss = F.cross_entropy(qry_logits, y_qry).cpu()
- qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum().cpu().item() / querysz
+ qry_acc = (qry_logits.argmax(dim=1) == y_qry).float().mean().cpu().item()
return qry_loss, qry_acc
@@ -275,11 +274,11 @@ def train(net: nn.Module, meta_opt: optim.Adam, epoch: int, log: list):
qry_acc = 100.0 * qry_acc
i = epoch + float(batch_idx) / n_train_iter
iter_time = time.time() - start_time
+ torch.cuda.empty_cache()
print(
f'[Epoch {i:.2f}] Train Loss: {qry_loss:.2f} | Acc: {qry_acc:.2f} | Time: {iter_time:.2f}'
)
-
log.append(
{
'epoch': i,
@@ -319,6 +318,8 @@ def test(net, epoch, log):
qry_losses = np.mean(qry_losses)
qry_accs = 100.0 * np.mean(qry_accs)
+ torch.cuda.empty_cache()
+
print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
log.append(
{
diff --git a/examples/few-shot/support/omniglot_loaders.py b/examples/few-shot/helpers/omniglot_loaders.py
similarity index 100%
rename from examples/few-shot/support/omniglot_loaders.py
rename to examples/few-shot/helpers/omniglot_loaders.py
diff --git a/examples/few-shot/maml_omniglot.py b/examples/few-shot/maml_omniglot.py
index 766391db..879a235a 100644
--- a/examples/few-shot/maml_omniglot.py
+++ b/examples/few-shot/maml_omniglot.py
@@ -54,7 +54,7 @@
import torchopt
-from support.omniglot_loaders import OmniglotNShot # isort: skip
+from helpers.omniglot_loaders import OmniglotNShot # isort: skip
mpl.use('Agg')
@@ -133,8 +133,7 @@ def train(db, net, meta_opt, epoch, log):
# Sample a batch of support and query images and labels.
x_spt, y_spt, x_qry, y_qry = db.next()
- task_num, setsz, c_, h, w = x_spt.size()
- querysz = x_qry.size(1)
+ task_num = x_spt.size(0)
# TODO: Maybe pull this out into a separate module so it
# doesn't have to be duplicated between `train` and `test`?
@@ -165,9 +164,9 @@ def train(db, net, meta_opt, epoch, log):
# These will be used to update the model's meta-parameters.
qry_logits = net(x_qry[i])
qry_loss = F.cross_entropy(qry_logits, y_qry[i])
+ qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean()
qry_losses.append(qry_loss)
- qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz
- qry_accs.append(qry_acc)
+ qry_accs.append(qry_acc.item())
torchopt.recover_state_dict(net, net_state_dict)
torchopt.recover_state_dict(inner_opt, optim_state_dict)
@@ -176,14 +175,14 @@ def train(db, net, meta_opt, epoch, log):
qry_losses.backward()
meta_opt.step()
qry_losses = qry_losses.item()
- qry_accs = 100.0 * sum(qry_accs) / task_num
+ qry_accs = 100.0 * np.mean(qry_accs)
i = epoch + float(batch_idx) / n_train_iter
iter_time = time.time() - start_time
+ torch.cuda.empty_cache()
print(
f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
)
-
log.append(
{
'epoch': i,
@@ -211,8 +210,7 @@ def test(db, net, epoch, log):
for batch_idx in range(n_test_iter):
x_spt, y_spt, x_qry, y_qry = db.next('test')
- task_num, setsz, c_, h, w = x_spt.size()
- querysz = x_qry.size(1)
+ task_num = x_spt.size(0)
# TODO: Maybe pull this out into a separate module so it
# doesn't have to be duplicated between `train` and `test`?
@@ -231,15 +229,18 @@ def test(db, net, epoch, log):
# The query loss and acc induced by these parameters.
qry_logits = net(x_qry[i]).detach()
- qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none')
- qry_losses.append(qry_loss.detach())
- qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())
+ qry_loss = F.cross_entropy(qry_logits, y_qry[i])
+ qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean()
+ qry_losses.append(qry_loss.item())
+ qry_accs.append(qry_acc.item())
torchopt.recover_state_dict(net, net_state_dict)
torchopt.recover_state_dict(inner_opt, optim_state_dict)
- qry_losses = torch.mean(torch.stack(qry_losses)).item()
- qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
+ qry_losses = np.mean(qry_losses)
+ qry_accs = 100.0 * np.mean(qry_accs)
+ torch.cuda.empty_cache()
+
print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
log.append(
{
diff --git a/examples/iMAML/README.md b/examples/iMAML/README.md
index 91f95f69..6208bc81 100644
--- a/examples/iMAML/README.md
+++ b/examples/iMAML/README.md
@@ -6,7 +6,8 @@ Code on implicit MAML few-shot Omniglot classification in paper [Meta-Learning w
```bash
### Run
-python3 imaml_omniglot.py --inner_steps 5
+python3 imaml_omniglot.py --inner_steps 5 # use OOP APIs
+python3 imaml_omniglot_functional.py --inner_steps 5 # use functional APIs
```
## Results
@@ -16,3 +17,7 @@ The figure illustrate the experimental result.
+
+
+

+
diff --git a/examples/iMAML/support/omniglot_loaders.py b/examples/iMAML/helpers/omniglot_loaders.py
similarity index 100%
rename from examples/iMAML/support/omniglot_loaders.py
rename to examples/iMAML/helpers/omniglot_loaders.py
diff --git a/examples/iMAML/imaml-accs-functional.png b/examples/iMAML/imaml-accs-functional.png
new file mode 100644
index 00000000..a23132cf
Binary files /dev/null and b/examples/iMAML/imaml-accs-functional.png differ
diff --git a/examples/iMAML/imaml-accs.png b/examples/iMAML/imaml-accs.png
index 1d5134a4..c0296be8 100644
Binary files a/examples/iMAML/imaml-accs.png and b/examples/iMAML/imaml-accs.png differ
diff --git a/examples/iMAML/imaml_omniglot.py b/examples/iMAML/imaml_omniglot.py
index 7b165ac0..2b0c9738 100644
--- a/examples/iMAML/imaml_omniglot.py
+++ b/examples/iMAML/imaml_omniglot.py
@@ -24,7 +24,6 @@
import argparse
import time
-import functorch
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
@@ -32,19 +31,53 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-import torch.optim as optim
import torchopt
-from torchopt import pytree
+from torchopt.diff.implicit import ImplicitMetaGradientModule
-from support.omniglot_loaders import OmniglotNShot # isort: skip
+from helpers.omniglot_loaders import OmniglotNShot # isort: skip
mpl.use('Agg')
plt.style.use('bmh')
+class InnerNet(
+ ImplicitMetaGradientModule,
+ linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),
+):
+ def __init__(self, meta_net, n_inner_iter, reg_param):
+ super().__init__()
+ self.meta_net = meta_net
+ self.net = torchopt.module_clone(meta_net, by='deepcopy', detach_buffers=True)
+ self.n_inner_iter = n_inner_iter
+ self.reg_param = reg_param
+
+ def forward(self, x):
+ return self.net(x)
+
+ def objective(self, x, y):
+ y_pred = self(x)
+ loss = F.cross_entropy(y_pred, y)
+ regularization_loss = 0
+ for p1, p2 in zip(self.parameters(), self.meta_parameters()):
+ regularization_loss += 0.5 * self.reg_param * torch.sum(torch.square(p1 - p2))
+ return loss + regularization_loss
+
+ def solve(self, x, y):
+ params = tuple(self.parameters())
+ inner_optim = torchopt.SGD(params, lr=1e-1)
+ with torch.enable_grad():
+ # Temporarily enable gradient computation for conducting the optimization
+ for _ in range(self.n_inner_iter):
+ loss = self.objective(x, y)
+ inner_optim.zero_grad()
+ loss.backward(inputs=params)
+ inner_optim.step()
+ return self
+
+
def main():
argparser = argparse.ArgumentParser()
argparser.add_argument('--n_way', type=int, help='n way', default=5)
@@ -102,24 +135,18 @@ def main():
# We will use Adam to (meta-)optimize the initial parameters
# to be adapted.
net.train()
- fnet, params = functorch.make_functional(net)
- meta_opt = torchopt.adam(lr=1e-3)
- meta_opt_state = meta_opt.init(params)
+ meta_opt = torchopt.Adam(net.parameters(), lr=1e-3)
log = []
- test(db, [params, fnet], epoch=-1, log=log, args=args)
+ test(db, net, epoch=-1, log=log, args=args)
for epoch in range(10):
- meta_opt, meta_opt_state = train(
- db, [params, fnet], (meta_opt, meta_opt_state), epoch, log, args
- )
- test(db, [params, fnet], epoch, log, args)
+ train(db, net, meta_opt, epoch, log, args)
+ test(db, net, epoch, log, args)
plot(log)
-def train(db, net, meta_opt_and_state, epoch, log, args):
+def train(db, net, meta_opt, epoch, log, args):
n_train_iter = db.x_train.shape[0] // db.batchsz
- params, fnet = net
- meta_opt, meta_opt_state = meta_opt_and_state
# Given this module we've created, rip out the parameters and buffers
# and return a functional version of the module. `fnet` is stateless
# and can be called with `fnet(params, buffers, args, kwargs)`
@@ -130,54 +157,44 @@ def train(db, net, meta_opt_and_state, epoch, log, args):
# Sample a batch of support and query images and labels.
x_spt, y_spt, x_qry, y_qry = db.next()
- task_num, setsz, c_, h, w = x_spt.size()
- querysz = x_qry.size(1)
+ task_num = x_spt.size(0)
n_inner_iter = args.inner_steps
reg_param = args.reg_params
+
qry_losses = []
qry_accs = []
-
- init_params_copy = pytree.tree_map(
- lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), params
- )
+ meta_opt.zero_grad()
for i in range(task_num):
# Optimize the likelihood of the support set by taking
# gradient steps w.r.t. the model's parameters.
# This adapts the model's meta-parameters to the task.
- optimal_params = train_imaml_inner_solver(
- init_params_copy,
- params,
- (x_spt[i], y_spt[i]),
- (fnet, n_inner_iter, reg_param),
- )
+ inner_net = InnerNet(net, n_inner_iter, reg_param)
+ optimal_inner_net = inner_net.solve(x_spt[i], y_spt[i])
+
# The final set of adapted parameters will induce some
# final loss and accuracy on the query dataset.
# These will be used to update the model's meta-parameters.
- qry_logits = fnet(optimal_params, x_qry[i])
+ qry_logits = optimal_inner_net(x_qry[i])
qry_loss = F.cross_entropy(qry_logits, y_qry[i])
- # Update the model's meta-parameters to optimize the query
- # losses across all of the tasks sampled in this batch.
- # qry_loss = qry_loss / task_num # scale gradients
- meta_grads = torch.autograd.grad(qry_loss / task_num, params)
- meta_updates, meta_opt_state = meta_opt.update(meta_grads, meta_opt_state)
- params = torchopt.apply_updates(params, meta_updates)
- qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz
-
- qry_losses.append(qry_loss.detach())
- qry_accs.append(qry_acc)
-
- qry_losses = sum(qry_losses) / task_num
- qry_accs = 100.0 * sum(qry_accs) / task_num
+ qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean()
+ qry_losses.append(qry_loss)
+ qry_accs.append(qry_acc.item())
+
+ qry_losses = torch.mean(torch.stack(qry_losses))
+ qry_losses.backward()
+ meta_opt.step()
+ qry_losses = qry_losses.item()
+ qry_accs = 100.0 * np.mean(qry_accs)
i = epoch + float(batch_idx) / n_train_iter
iter_time = time.time() - start_time
+ torch.cuda.empty_cache()
print(
f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
)
-
log.append(
{
'epoch': i,
@@ -188,8 +205,6 @@ def train(db, net, meta_opt_and_state, epoch, log, args):
}
)
- return (meta_opt, meta_opt_state)
-
def test(db, net, epoch, log, args):
# Crucially in our testing procedure here, we do *not* fine-tune
@@ -197,8 +212,6 @@ def test(db, net, epoch, log, args):
# Most research papers using MAML for this task do an extra
# stage of fine-tuning here that should be added if you are
# adapting this code for research.
- params, fnet = net
- # fnet, params, buffers = functorch.make_functional_with_buffers(net)
n_test_iter = db.x_test.shape[0] // db.batchsz
qry_losses = []
@@ -208,36 +221,32 @@ def test(db, net, epoch, log, args):
# doesn't have to be duplicated between `train` and `test`?
n_inner_iter = args.inner_steps
reg_param = args.reg_params
- init_params_copy = pytree.tree_map(
- lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), params
- )
for batch_idx in range(n_test_iter):
x_spt, y_spt, x_qry, y_qry = db.next('test')
- task_num, setsz, c_, h, w = x_spt.size()
- querysz = x_qry.size(1)
+ task_num = x_spt.size(0)
for i in range(task_num):
# Optimize the likelihood of the support set by taking
# gradient steps w.r.t. the model's parameters.
# This adapts the model's meta-parameters to the task.
- optimal_params = test_imaml_inner_solver(
- init_params_copy,
- params,
- (x_spt[i], y_spt[i]),
- (fnet, n_inner_iter, reg_param),
- )
+ inner_net = InnerNet(net, n_inner_iter, reg_param)
+ with torch.no_grad():
+ optimal_inner_net = inner_net.solve(x_spt[i], y_spt[i])
# The query loss and acc induced by these parameters.
- qry_logits = fnet(optimal_params, x_qry[i])
- qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none')
- qry_losses.append(qry_loss.detach())
- qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())
+ qry_logits = optimal_inner_net(x_qry[i])
+ qry_loss = F.cross_entropy(qry_logits, y_qry[i])
+ qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean()
+ qry_losses.append(qry_loss.item())
+ qry_accs.append(qry_acc.item())
+
+ qry_losses = np.mean(qry_losses)
+ qry_accs = 100.0 * np.mean(qry_accs)
+ torch.cuda.empty_cache()
- qry_losses = torch.cat(qry_losses).mean().item()
- qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
log.append(
{
@@ -250,71 +259,6 @@ def test(db, net, epoch, log, args):
)
-def imaml_objective(optimal_params, init_params, data, aux):
- x_spt, y_spt = data
- fnet, n_inner_iter, reg_param = aux
- fnet.eval()
- y_pred = fnet(optimal_params, x_spt)
- fnet.train()
- regularization_loss = 0
- for p1, p2 in zip(optimal_params, init_params):
- regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2))
- loss = F.cross_entropy(y_pred, y_spt) + regularization_loss
- return loss
-
-
-@torchopt.diff.implicit.custom_root(
- functorch.grad(imaml_objective, argnums=0),
- argnums=1,
- has_aux=False,
- solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),
-)
-def train_imaml_inner_solver(init_params_copy, init_params, data, aux):
- x_spt, y_spt = data
- fnet, n_inner_iter, reg_param = aux
- # Initial functional optimizer based on TorchOpt
- params = init_params_copy
- inner_opt = torchopt.sgd(lr=1e-1)
- inner_opt_state = inner_opt.init(params)
- with torch.enable_grad():
- # Temporarily enable gradient computation for conducting the optimization
- for _ in range(n_inner_iter):
- pred = fnet(params, x_spt)
- loss = F.cross_entropy(pred, y_spt) # compute loss
- # Compute regularization loss
- regularization_loss = 0
- for p1, p2 in zip(params, init_params):
- regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2))
- final_loss = loss + regularization_loss
- grads = torch.autograd.grad(final_loss, params) # compute gradients
- updates, inner_opt_state = inner_opt.update(grads, inner_opt_state) # get updates
- params = torchopt.apply_updates(params, updates)
- return params
-
-
-def test_imaml_inner_solver(init_params_copy, init_params, data, aux):
- x_spt, y_spt = data
- fnet, n_inner_iter, reg_param = aux
- # Initial functional optimizer based on TorchOpt
- params = init_params_copy
- inner_opt = torchopt.sgd(lr=1e-1)
- inner_opt_state = inner_opt.init(params)
- with torch.enable_grad():
- # Temporarily enable gradient computation for conducting the optimization
- for _ in range(n_inner_iter):
- pred = fnet(params, x_spt)
- loss = F.cross_entropy(pred, y_spt) # compute loss
- # Compute regularization loss
- regularization_loss = 0
- for p1, p2 in zip(params, init_params):
- regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2))
- final_loss = loss + regularization_loss
- grads = torch.autograd.grad(final_loss, params) # compute gradients
- updates, inner_opt_state = inner_opt.update(grads, inner_opt_state) # get updates
- params = torchopt.apply_updates(params, updates)
- return params
-
-
def plot(log):
# Generally you should pull your plotting code out of your training
# script but we are doing it here for brevity.
diff --git a/examples/iMAML/imaml_omniglot_functional.py b/examples/iMAML/imaml_omniglot_functional.py
new file mode 100644
index 00000000..080541c6
--- /dev/null
+++ b/examples/iMAML/imaml_omniglot_functional.py
@@ -0,0 +1,338 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+This example shows how to use TorchOpt to do iMAML-GD (see [1] for more details)
+for few-shot Omniglot classification.
+
+[1] Rajeswaran, A., Finn, C., Kakade, S. M., & Levine, S. (2019).
+ Meta-learning with implicit gradients. In Advances in Neural Information Processing Systems (pp. 113-124).
+ https://arxiv.org/abs/1909.04630
+"""
+
+import argparse
+import time
+
+import functorch
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import torchopt
+from torchopt import pytree
+
+
+from helpers.omniglot_loaders import OmniglotNShot # isort: skip
+
+
+mpl.use('Agg')
+plt.style.use('bmh')
+
+
+def main():
+ argparser = argparse.ArgumentParser()
+ argparser.add_argument('--n_way', type=int, help='n way', default=5)
+ argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=5)
+ argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=5)
+ argparser.add_argument('--inner_steps', type=int, help='number of inner steps', default=5)
+ argparser.add_argument(
+ '--reg_params', type=float, help='regularization parameters', default=2.0
+ )
+ argparser.add_argument(
+ '--task_num', type=int, help='meta batch size, namely task num', default=16
+ )
+ argparser.add_argument('--seed', type=int, help='random seed', default=1)
+ args = argparser.parse_args()
+
+ torch.manual_seed(args.seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(args.seed)
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cudnn.deterministic = True
+ np.random.seed(args.seed)
+ rng = np.random.default_rng(args.seed)
+
+ # Set up the Omniglot loader.
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+ db = OmniglotNShot(
+ '/tmp/omniglot-data',
+ batchsz=args.task_num,
+ n_way=args.n_way,
+ k_shot=args.k_spt,
+ k_query=args.k_qry,
+ imgsz=28,
+ rng=rng,
+ device=device,
+ )
+
+ # Create a vanilla PyTorch neural network.
+ net = nn.Sequential(
+ nn.Conv2d(1, 64, 3),
+ nn.BatchNorm2d(64, momentum=1.0, affine=True, track_running_stats=False),
+ nn.ReLU(inplace=False),
+ nn.MaxPool2d(2, 2),
+ nn.Conv2d(64, 64, 3),
+ nn.BatchNorm2d(64, momentum=1.0, affine=True, track_running_stats=False),
+ nn.ReLU(inplace=False),
+ nn.MaxPool2d(2, 2),
+ nn.Conv2d(64, 64, 3),
+ nn.BatchNorm2d(64, momentum=1.0, affine=True, track_running_stats=False),
+ nn.ReLU(inplace=False),
+ nn.MaxPool2d(2, 2),
+ nn.Flatten(),
+ nn.Linear(64, args.n_way),
+ ).to(device)
+
+ # We will use Adam to (meta-)optimize the initial parameters
+ # to be adapted.
+ net.train()
+ fnet, params = functorch.make_functional(net)
+ meta_opt = torchopt.adam(lr=1e-3)
+ meta_opt_state = meta_opt.init(params)
+
+ log = []
+ test(db, [params, fnet], epoch=-1, log=log, args=args)
+ for epoch in range(10):
+ meta_opt, meta_opt_state = train(
+ db, [params, fnet], (meta_opt, meta_opt_state), epoch, log, args
+ )
+ test(db, [params, fnet], epoch, log, args)
+ plot(log)
+
+
+def train(db, net, meta_opt_and_state, epoch, log, args):
+ n_train_iter = db.x_train.shape[0] // db.batchsz
+ params, fnet = net
+ meta_opt, meta_opt_state = meta_opt_and_state
+ # Given this module we've created, rip out the parameters and buffers
+ # and return a functional version of the module. `fnet` is stateless
+ # and can be called with `fnet(params, buffers, args, kwargs)`
+ # fnet, params, buffers = functorch.make_functional_with_buffers(net)
+
+ for batch_idx in range(n_train_iter):
+ start_time = time.time()
+ # Sample a batch of support and query images and labels.
+ x_spt, y_spt, x_qry, y_qry = db.next()
+
+ task_num = x_spt.size(0)
+
+ n_inner_iter = args.inner_steps
+ reg_param = args.reg_params
+ qry_losses = []
+ qry_accs = []
+
+ init_params_copy = pytree.tree_map(
+ lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), params
+ )
+
+ for i in range(task_num):
+ # Optimize the likelihood of the support set by taking
+ # gradient steps w.r.t. the model's parameters.
+ # This adapts the model's meta-parameters to the task.
+
+ optimal_params = train_imaml_inner_solver(
+ init_params_copy,
+ params,
+ (x_spt[i], y_spt[i]),
+ (fnet, n_inner_iter, reg_param),
+ )
+ # The final set of adapted parameters will induce some
+ # final loss and accuracy on the query dataset.
+ # These will be used to update the model's meta-parameters.
+ qry_logits = fnet(optimal_params, x_qry[i])
+ qry_loss = F.cross_entropy(qry_logits, y_qry[i])
+ # Update the model's meta-parameters to optimize the query
+ # losses across all of the tasks sampled in this batch.
+ # qry_loss = qry_loss / task_num # scale gradients
+ meta_grads = torch.autograd.grad(qry_loss / task_num, params)
+ meta_updates, meta_opt_state = meta_opt.update(meta_grads, meta_opt_state)
+ params = torchopt.apply_updates(params, meta_updates)
+ qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean()
+ qry_losses.append(qry_loss.item())
+ qry_accs.append(qry_acc.item())
+
+ qry_losses = np.mean(qry_losses)
+ qry_accs = 100.0 * np.mean(qry_accs)
+ i = epoch + float(batch_idx) / n_train_iter
+ iter_time = time.time() - start_time
+ torch.cuda.empty_cache()
+
+ print(
+ f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
+ )
+ log.append(
+ {
+ 'epoch': i,
+ 'loss': qry_losses,
+ 'acc': qry_accs,
+ 'mode': 'train',
+ 'time': time.time(),
+ }
+ )
+
+ return (meta_opt, meta_opt_state)
+
+
+def test(db, net, epoch, log, args):
+ # Crucially in our testing procedure here, we do *not* fine-tune
+ # the model during testing for simplicity.
+ # Most research papers using MAML for this task do an extra
+ # stage of fine-tuning here that should be added if you are
+ # adapting this code for research.
+ params, fnet = net
+ # fnet, params, buffers = functorch.make_functional_with_buffers(net)
+ n_test_iter = db.x_test.shape[0] // db.batchsz
+
+ qry_losses = []
+ qry_accs = []
+
+ # TODO: Maybe pull this out into a separate module so it
+ # doesn't have to be duplicated between `train` and `test`?
+ n_inner_iter = args.inner_steps
+ reg_param = args.reg_params
+ init_params_copy = pytree.tree_map(
+ lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), params
+ )
+
+ for batch_idx in range(n_test_iter):
+ x_spt, y_spt, x_qry, y_qry = db.next('test')
+
+ task_num = x_spt.size(0)
+
+ for i in range(task_num):
+ # Optimize the likelihood of the support set by taking
+ # gradient steps w.r.t. the model's parameters.
+ # This adapts the model's meta-parameters to the task.
+
+ optimal_params = test_imaml_inner_solver(
+ init_params_copy,
+ params,
+ (x_spt[i], y_spt[i]),
+ (fnet, n_inner_iter, reg_param),
+ )
+
+ # The query loss and acc induced by these parameters.
+ qry_logits = fnet(optimal_params, x_qry[i])
+ qry_loss = F.cross_entropy(qry_logits, y_qry[i])
+ qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean()
+ qry_losses.append(qry_loss.item())
+ qry_accs.append(qry_acc.item())
+
+ qry_losses = np.mean(qry_losses)
+ qry_accs = 100.0 * np.mean(qry_accs)
+ torch.cuda.empty_cache()
+
+ print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
+ log.append(
+ {
+ 'epoch': epoch + 1,
+ 'loss': qry_losses,
+ 'acc': qry_accs,
+ 'mode': 'test',
+ 'time': time.time(),
+ }
+ )
+
+
+def imaml_objective(optimal_params, init_params, data, aux):
+ x_spt, y_spt = data
+ fnet, n_inner_iter, reg_param = aux
+ y_pred = fnet(optimal_params, x_spt)
+ regularization_loss = 0
+ for p1, p2 in zip(optimal_params, init_params):
+ regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2))
+ loss = F.cross_entropy(y_pred, y_spt) + regularization_loss
+ return loss
+
+
+@torchopt.diff.implicit.custom_root(
+ functorch.grad(imaml_objective, argnums=0),
+ argnums=1,
+ has_aux=False,
+ solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),
+)
+def train_imaml_inner_solver(init_params_copy, init_params, data, aux):
+ x_spt, y_spt = data
+ fnet, n_inner_iter, reg_param = aux
+ # Initial functional optimizer based on TorchOpt
+ params = init_params_copy
+ inner_opt = torchopt.sgd(lr=1e-1)
+ inner_opt_state = inner_opt.init(params)
+ with torch.enable_grad():
+ # Temporarily enable gradient computation for conducting the optimization
+ for _ in range(n_inner_iter):
+ pred = fnet(params, x_spt)
+ loss = F.cross_entropy(pred, y_spt) # compute loss
+ # Compute regularization loss
+ regularization_loss = 0
+ for p1, p2 in zip(params, init_params):
+ regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2))
+ final_loss = loss + regularization_loss
+ grads = torch.autograd.grad(final_loss, params) # compute gradients
+ updates, inner_opt_state = inner_opt.update(grads, inner_opt_state) # get updates
+ params = torchopt.apply_updates(params, updates)
+ return params
+
+
+def test_imaml_inner_solver(init_params_copy, init_params, data, aux):
+ x_spt, y_spt = data
+ fnet, n_inner_iter, reg_param = aux
+ # Initial functional optimizer based on TorchOpt
+ params = init_params_copy
+ inner_opt = torchopt.sgd(lr=1e-1)
+ inner_opt_state = inner_opt.init(params)
+ with torch.enable_grad():
+ # Temporarily enable gradient computation for conducting the optimization
+ for _ in range(n_inner_iter):
+ pred = fnet(params, x_spt)
+ loss = F.cross_entropy(pred, y_spt) # compute loss
+ # Compute regularization loss
+ regularization_loss = 0
+ for p1, p2 in zip(params, init_params):
+ regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2))
+ final_loss = loss + regularization_loss
+ grads = torch.autograd.grad(final_loss, params) # compute gradients
+ updates, inner_opt_state = inner_opt.update(grads, inner_opt_state) # get updates
+ params = torchopt.apply_updates(params, updates)
+ return params
+
+
+def plot(log):
+ # Generally you should pull your plotting code out of your training
+ # script but we are doing it here for brevity.
+ df = pd.DataFrame(log)
+
+ fig, ax = plt.subplots(figsize=(8, 4), dpi=250)
+ train_df = df[df['mode'] == 'train']
+ test_df = df[df['mode'] == 'test']
+ ax.plot(train_df['epoch'], train_df['acc'], label='Train')
+ ax.plot(test_df['epoch'], test_df['acc'], label='Test')
+ ax.set_xlabel('Epoch')
+ ax.set_ylabel('Accuracy')
+ ax.set_ylim(80, 100)
+ ax.set_title('iMAML Omniglot (Functional)')
+ ax.legend(ncol=2, loc='lower right')
+ fig.tight_layout()
+ fname = 'imaml-accs-functional.png'
+ print(f'--- Plotting accuracy to {fname}')
+ fig.savefig(fname)
+ plt.close(fig)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/setup.py b/setup.py
index f457b32e..966bdb32 100644
--- a/setup.py
+++ b/setup.py
@@ -1,5 +1,6 @@
import os
import pathlib
+import re
import shutil
import sys
import sysconfig
@@ -84,8 +85,6 @@ def build_extension(self, ext):
VERSION_CONTENT = None
if not version.__release__:
- import re
-
VERSION_CONTENT = VERSION_FILE.read_text(encoding='UTF-8')
VERSION_FILE.write_text(
data=re.sub(
diff --git a/tests/helpers.py b/tests/helpers.py
index a84d78bc..6c7c4f01 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -23,6 +23,7 @@
import pytest
import torch
import torch.nn as nn
+import torch.types
from torch.utils import data
@@ -100,7 +101,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
@torch.no_grad()
def get_models(
- device: Optional[Union[str, torch.device]] = None, dtype: torch.dtype = torch.float32
+ device: torch.types.Device = None, dtype: torch.dtype = torch.float32
) -> Tuple[nn.Module, nn.Module, nn.Module, data.DataLoader]:
seed_everything(seed=42)
diff --git a/tests/test_implicit.py b/tests/test_implicit.py
index 399ee025..06d180d4 100644
--- a/tests/test_implicit.py
+++ b/tests/test_implicit.py
@@ -16,7 +16,7 @@
import copy
from collections import OrderedDict
from types import FunctionType
-from typing import Optional, Tuple, Union
+from typing import Tuple
import functorch
import jax
@@ -27,6 +27,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
+import torch.types
from torch.utils import data
import helpers
@@ -70,7 +71,7 @@ def func(params, x):
@torch.no_grad()
def get_model_torch(
- device: Optional[Union[str, torch.device]] = None, dtype: torch.dtype = torch.float32
+ device: torch.types.Device = None, dtype: torch.dtype = torch.float32
) -> Tuple[nn.Module, data.DataLoader]:
helpers.seed_everything(seed=42)
@@ -242,7 +243,7 @@ class InnerNet(ImplicitMetaGradientModule, has_aux=True):
def __init__(self, meta_model):
super().__init__()
self.meta_model = meta_model
- self.model = copy.deepcopy(meta_model)
+ self.model = torchopt.module_clone(meta_model, by='deepcopy', detach_buffers=True)
def forward(self, x):
return self.model(x)
diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py
index 3ddaf296..4ad48d32 100644
--- a/torchopt/diff/implicit/nn/module.py
+++ b/torchopt/diff/implicit/nn/module.py
@@ -24,7 +24,7 @@
import torchopt.nn
from torchopt import pytree
from torchopt.diff.implicit.decorator import custom_root
-from torchopt.typing import TensorTree, TupleOfTensors
+from torchopt.typing import LinearSolver, TensorTree, TupleOfTensors
from torchopt.utils import extract_module_containers
@@ -76,12 +76,16 @@ def enable_implicit_gradients(
cls: Type['ImplicitMetaGradientModule'],
) -> Type['ImplicitMetaGradientModule']:
"""Enable implicit gradients for the :func:`solve` function."""
- solve = cls.solve
- has_aux = cls.has_aux
- if getattr(solve, '__implicit_gradients_enabled__', False):
+ cls_solve = cls.solve
+ if getattr(cls_solve, '__implicit_gradients_enabled__', False):
raise ValueError('Implicit gradients are already enabled for the solve function.')
- @functools.wraps(solve)
+ cls_has_aux = cls.has_aux
+ custom_root_kwargs = dict(has_aux=cls_has_aux, solve=cls.linear_solve)
+ if cls.linear_solve is None:
+ custom_root_kwargs.pop('solve')
+
+ @functools.wraps(cls_solve)
def wrapped( # pylint: disable=too-many-locals
self: 'ImplicitMetaGradientModule', *input, **kwargs # pylint: disable=redefined-builtin
) -> Union['ImplicitMetaGradientModule', Tuple['ImplicitMetaGradientModule', Any]]:
@@ -141,15 +145,15 @@ def optimality_fn(
):
container.update(container_backup)
- @custom_root(optimality_fn, argnums=1, has_aux=has_aux)
+ @custom_root(optimality_fn, argnums=1, **custom_root_kwargs) # type: ignore[arg-type]
def solver_fn(
flat_params: TupleOfTensors, # pylint: disable=unused-argument
flat_meta_params: TupleOfTensors, # pylint: disable=unused-argument
*input, # pylint: disable=redefined-builtin
**kwargs,
) -> Union[TupleOfTensors, Tuple[TupleOfTensors, Any]]:
- output = solve(self, *input, **kwargs)
- if has_aux:
+ output = cls_solve(self, *input, **kwargs)
+ if cls_has_aux:
if not (isinstance(output, tuple) and len(output) == 2):
raise RuntimeError(
f'Output of method ImplicitMetaGradientModule.solve should be a '
@@ -163,12 +167,12 @@ def solver_fn(
)
flat_optimal_params: TupleOfTensors = tuple(pytree.tree_leaves(params_containers)) # type: ignore[arg-type]
- if has_aux:
+ if cls_has_aux:
return flat_optimal_params, aux
return flat_optimal_params
output = solver_fn(flat_params, flat_meta_params, *input, **kwargs)
- if has_aux:
+ if cls_has_aux:
_, aux = output
return self, aux
return self
@@ -184,11 +188,15 @@ class ImplicitMetaGradientModule(torchopt.nn.MetaGradientModule):
_custom_optimality: bool
_custom_objective: bool
has_aux: bool
+ linear_solve: Optional[LinearSolver]
- def __init_subclass__(cls, has_aux=False) -> None:
+ def __init_subclass__(
+ cls, has_aux: bool = False, linear_solve: Optional[LinearSolver] = None
+ ) -> None:
"""Initialize the subclass."""
super().__init_subclass__()
cls.has_aux = has_aux
+ cls.linear_solve = linear_solve
optimality = getattr(cls, 'optimality', ImplicitMetaGradientModule.optimality)
objective = getattr(cls, 'objective', ImplicitMetaGradientModule.objective)
diff --git a/torchopt/typing.py b/torchopt/typing.py
index f1dcd1dd..e6e01690 100644
--- a/torchopt/typing.py
+++ b/torchopt/typing.py
@@ -21,6 +21,7 @@
from optree.typing import PyTree, PyTreeTypeVar
from torch import Tensor
from torch.futures import Future
+from torch.types import Device
from torchopt.base import ChainedGradientTransformation, EmptyState, GradientTransformation
@@ -50,6 +51,8 @@
'OptionalTensorOrOptionalTensors',
'OptionalTensorTree',
'Future',
+ 'LinearSolver',
+ 'Device',
]
T = TypeVar('T')
@@ -85,3 +88,6 @@
__all__.extend(['RRef'])
else:
RRef = None # type: ignore[misc,assignment] # pylint: disable=invalid-name
+
+# solver(matvec, b) -> solution
+LinearSolver: TypeAlias = Callable[[Callable[[TensorTree], TensorTree], TensorTree], TensorTree]
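The new `LinearSolver` alias documents the `solver(matvec, b) -> solution` contract shared by `torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0)` (used as the `linear_solve` class keyword of `InnerNet` earlier in this patch) and any user-supplied solver. As an illustration of that interface only, here is a minimal sketch of a custom callable; the Richardson-style iteration, its name, and its hyperparameters are invented for this sketch and are not part of TorchOpt.

```python
import torch

from torchopt import pytree


def naive_richardson_solver(matvec, b, num_iters=20, step_size=0.1):
    """Illustrative ``LinearSolver``: approximately solve ``A x = b`` given only ``matvec(x) = A x``.

    Operates on tensor pytrees (tuples/lists/dicts of tensors), mirroring the
    ``TensorTree`` types used in ``torchopt/typing.py``.
    """
    # Start from a zero vector with the same tree structure as ``b``.
    x = pytree.tree_map(torch.zeros_like, b)
    for _ in range(num_iters):
        # Residual: b - A x, computed leaf-by-leaf.
        residual = pytree.tree_map(lambda bi, axi: bi - axi, b, matvec(x))
        # Richardson update: x <- x + step_size * (b - A x).
        x = pytree.tree_map(lambda xi, ri: xi + step_size * ri, x, residual)
    return x
```

A subclass could then be declared with `linear_solve=naive_richardson_solver` in place of `solve_normal_cg(...)`; whether such a crude fixed-point iteration converges depends on the conditioning of the implicit system, so treat this as a sketch of the interface rather than a recommended solver.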
diff --git a/torchopt/utils.py b/torchopt/utils.py
index 3301f92c..cfe25e32 100644
--- a/torchopt/utils.py
+++ b/torchopt/utils.py
@@ -35,7 +35,7 @@
import torch.nn as nn
from torchopt import pytree
-from torchopt.typing import OptState, TensorTree
+from torchopt.typing import Device, OptState, TensorTree
if TYPE_CHECKING:
@@ -106,7 +106,7 @@ def extract_state_dict(
target: nn.Module,
*,
by: CopyMode = 'reference',
- device: Optional[Union[int, str, torch.device]] = None,
+ device: Device = None,
with_buffers: bool = True,
enable_visual: bool = False,
visual_prefix: str = '',
@@ -119,7 +119,7 @@ def extract_state_dict(
target: 'MetaOptimizer',
*,
by: CopyMode = 'reference',
- device: Optional[Union[int, str, torch.device]] = None,
+ device: Device = None,
with_buffers: bool = True,
enable_visual: bool = False,
visual_prefix: str = '',
@@ -132,7 +132,7 @@ def extract_state_dict(
target: Union[nn.Module, 'MetaOptimizer'],
*,
by: CopyMode = 'reference',
- device: Optional[Union[int, str, torch.device]] = None,
+ device: Device = None,
with_buffers: bool = True,
detach_buffers: bool = False,
enable_visual: bool = False,
@@ -360,7 +360,7 @@ def module_clone(
*,
by: CopyMode = 'reference',
detach_buffers: bool = False,
- device: Optional[Union[int, str, torch.device]] = None,
+ device: Device = None,
) -> nn.Module:
...
@@ -371,7 +371,7 @@ def module_clone(
*,
by: CopyMode = 'reference',
detach_buffers: bool = False,
- device: Optional[Union[int, str, torch.device]] = None,
+ device: Device = None,
) -> 'MetaOptimizer':
...
@@ -382,7 +382,7 @@ def module_clone(
*,
by: CopyMode = 'reference',
detach_buffers: bool = False,
- device: Optional[Union[int, str, torch.device]] = None,
+ device: Device = None,
) -> TensorTree:
...
@@ -393,7 +393,7 @@ def module_clone(
*,
by: CopyMode = 'reference',
detach_buffers: bool = False,
- device: Optional[Union[int, str, torch.device]] = None,
+ device: Device = None,
) -> Union[nn.Module, 'MetaOptimizer', TensorTree]:
"""Clone a module.