Examples of maml-omniglot using functorch.vmap with torchopt #39


Merged: 25 commits merged into main from feature/maml-omniglot_vmap on Aug 12, 2022.

Changes shown below are from 1 commit of the 25.

Commits:
14f0652
feat(examples): add functorch vmap support examples
Benjamin-eecs Jul 19, 2022
4ebcf38
feat(examples): add functorch vmap support examples
Benjamin-eecs Jul 19, 2022
e959bc0
feat: working functorch + torchopt parallel training example
Benjamin-eecs Jul 21, 2022
c9ecd25
fix: pass lint
Benjamin-eecs Jul 21, 2022
d310aa1
feat(examples): add maml omniglot with functorch.vmap + torchopt example
Benjamin-eecs Jul 21, 2022
220c5f8
Merge remote-tracking branch 'upstream/main' into feature/functorch_v…
Benjamin-eecs Jul 22, 2022
a0d8466
merge: sync with main
Benjamin-eecs Jul 22, 2022
fcc9f3e
Merge remote-tracking branch 'upstream/main' into feature/functorch_v…
Benjamin-eecs Aug 6, 2022
05960bc
fix: pass lint
Benjamin-eecs Aug 6, 2022
045bff4
Merge branch 'main' into feature/maml-omniglot_vmap
XuehaiPan Aug 7, 2022
2b7ee7e
Merge branch 'main' into feature/maml-omniglot_vmap
XuehaiPan Aug 7, 2022
8a3ea6a
merge: resolve conflicts
Benjamin-eecs Aug 10, 2022
acba2e1
fix: update comment
Benjamin-eecs Aug 10, 2022
78556b2
Merge branch 'feature/functorch_vmap_support' into feature/maml-omnig…
Benjamin-eecs Aug 10, 2022
95c33ad
Merge branch 'feature/maml-omniglot_vmap' of https://github.com/Benja…
Benjamin-eecs Aug 10, 2022
ba4ddc1
fix: pass lint
Benjamin-eecs Aug 11, 2022
fc679ad
fix: resolve comments
Benjamin-eecs Aug 11, 2022
c3babe5
Merge branch 'feature/functorch_vmap_support' into feature/maml-omnig…
Benjamin-eecs Aug 11, 2022
05cc649
fix: resolve comments
Benjamin-eecs Aug 11, 2022
6d9109d
fix: resolve conflicts
Benjamin-eecs Aug 11, 2022
4b60252
fix: correct wrong comments
Benjamin-eecs Aug 11, 2022
48a553a
Merge branch 'metaopt:main' into feature/maml-omniglot_vmap
Benjamin-eecs Aug 12, 2022
28848bb
chore: minor fix
Benjamin-eecs Aug 12, 2022
62ed013
chore: minor fix
Benjamin-eecs Aug 12, 2022
0defbd6
chore: update CHANGELOG
Benjamin-eecs Aug 12, 2022
feat(examples): add maml omniglot with functorch.vmap + torchopt example
Benjamin-eecs committed Jul 21, 2022
commit d310aa13d54d75ab449604c116840e68714f1a83
3 changes: 1 addition & 2 deletions examples/few-shot/maml_omniglot.py
@@ -50,11 +50,10 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
+from support.omniglot_loaders import OmniglotNShot
 
 import torchopt
 
-from .support.omniglot_loaders import OmniglotNShot
-
 
 mpl.use('Agg')
 plt.style.use('bmh')
289 changes: 289 additions & 0 deletions examples/functorch/maml_omniglot_vmap.py
@@ -0,0 +1,289 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# This file is modified from:
# https://github.com/facebookresearch/higher/blob/main/examples/maml-omniglot.py
# ==============================================================================
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This example shows how to use higher to do Model Agnostic Meta Learning (MAML)
for few-shot Omniglot classification.
For more details see the original MAML paper:
https://arxiv.org/abs/1703.03400
This code has been modified from Jackie Loong's PyTorch MAML implementation:
https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py
Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch
"""


import os
import sys


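# Make the shared Omniglot loader in `examples/few-shot/support` importable.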
cur = os.path.abspath(os.path.dirname(__file__))
root = os.path.split(cur)[0]
sys.path.append(root + '/few-shot')
import argparse
import functools
import time

import functorch
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.optim as optim
from functorch import grad, make_functional_with_buffers, vmap
from support.omniglot_loaders import OmniglotNShot
from torch import nn

import torchopt


mpl.use('Agg')
plt.style.use('bmh')


# Squash the warning spam
functorch._C._set_vmap_fallback_warning_enabled(False)


def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--n_way', type=int, help='n way', default=5)
    argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=5)
    argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15)
    argparser.add_argument('--device', type=str, help='device', default='cuda')
    argparser.add_argument(
        '--task_num', type=int, help='meta batch size, namely task num', default=32
    )
    argparser.add_argument('--seed', type=int, help='random seed', default=1)
    args = argparser.parse_args()

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    rng = np.random.default_rng(args.seed)

    # Set up the Omniglot loader.
    device = args.device
    db = OmniglotNShot(
        '/tmp/omniglot-data',
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=28,
        rng=rng,
        device=device,
    )

    # Create a vanilla PyTorch neural network.
    inplace_relu = True
    net = nn.Sequential(
        nn.Conv2d(1, 64, 3),
        nn.BatchNorm2d(64, affine=True, track_running_stats=False),
        nn.ReLU(inplace=inplace_relu),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, affine=True, track_running_stats=False),
        nn.ReLU(inplace=inplace_relu),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, affine=True, track_running_stats=False),
        nn.ReLU(inplace=inplace_relu),
        nn.MaxPool2d(2, 2),
        nn.Flatten(),
        nn.Linear(64, args.n_way),
    ).to(device)

    net.train()

    # Given this module we've created, rip out the parameters and buffers
    # and return a functional version of the module. `fnet` is stateless
    # and can be called with `fnet(params, buffers, args, kwargs)`.
    fnet, params, buffers = make_functional_with_buffers(net)
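    # For example (a hypothetical smoke test, not part of the original script),
    # `fnet` could be called on a dummy batch exactly like `net` would be:
    #   dummy = torch.zeros(1, 1, 28, 28, device=device)
    #   logits = fnet(params, buffers, dummy)  # same result as net(dummy)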

    # We will use Adam to (meta-)optimize the initial parameters
    # to be adapted.
    meta_opt = optim.Adam(params, lr=1e-3)

    log = []
    for epoch in range(100):
        train(db, [params, buffers, fnet], device, meta_opt, epoch, log)
        test(db, [params, buffers, fnet], device, epoch, log)
        plot(log)


# Trains a model for n_inner_iter steps on the support set and returns the
# loss on the query set.
def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry):
    params, buffers, fnet = net
    opt = torchopt.sgd(lr=1e-1)
    opt_state = opt.init(params)

    querysz = x_qry.size(0)

    def compute_loss(new_params, buffers, x, y):
        logits = fnet(new_params, buffers, x)
        loss = F.cross_entropy(logits, y)
        return loss

    new_params = params
    # The inner-loop SGD is implemented with TorchOpt's functional optimizer.
    for _ in range(n_inner_iter):
        grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt)
        updates, opt_state = opt.update(grads, opt_state, inplace=False)
        new_params = torchopt.apply_updates(new_params, updates, inplace=False)
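    # Note: `inplace=False` keeps both the optimizer-state and the parameter
    # updates out-of-place, so the inner loop stays a pure, differentiable
    # function, as the `grad` and `vmap` transforms require.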

    # The final set of adapted parameters will induce some
    # final loss and accuracy on the query dataset.
    # These will be used to update the model's meta-parameters.
    qry_logits = fnet(new_params, buffers, x_qry)
    qry_loss = F.cross_entropy(qry_logits, y_qry)
    qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum() / querysz

    return qry_loss, qry_acc


def train(db, net, device, meta_opt, epoch, log):
    params, buffers, fnet = net
    n_train_iter = db.x_train.shape[0] // db.batchsz

    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()

        task_num, setsz, c_, h, w = x_spt.size()

        n_inner_iter = 5
        meta_opt.zero_grad()

        # In parallel, trains one model per task. There is a support (x, y)
        # for each task and a query (x, y) for each task.
        compute_loss_for_task = functools.partial(loss_for_task, net, n_inner_iter)
        qry_losses, qry_accs = vmap(compute_loss_for_task)(x_spt, y_spt, x_qry, y_qry)
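        # `vmap` maps `compute_loss_for_task` over the leading (task) dimension
        # of the four batched tensors, so all `task_num` inner loops run as a
        # single batched computation; `qry_losses` and `qry_accs` come back
        # with shape `[task_num]`.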

        # Compute the MAML loss by summing together the returned losses.
        qry_losses.sum().backward()

        meta_opt.step()
        qry_losses = qry_losses.detach().sum() / task_num
        qry_accs = 100.0 * qry_accs.sum() / task_num
        i = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time
        if batch_idx % 4 == 0:
            print(
                f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
            )

        log.append(
            {
                'epoch': i,
                'loss': qry_losses,
                'acc': qry_accs,
                'mode': 'train',
                'time': time.time(),
            }
        )


def test(db, net, device, epoch, log):
    # Crucially in our testing procedure here, we do *not* fine-tune
    # the model during testing for simplicity.
    # Most research papers using MAML for this task do an extra
    # stage of fine-tuning here that should be added if you are
    # adapting this code for research.
    [params, buffers, fnet] = net
    n_test_iter = db.x_test.shape[0] // db.batchsz

    qry_losses = []
    qry_accs = []

    for batch_idx in range(n_test_iter):
        x_spt, y_spt, x_qry, y_qry = db.next('test')
        task_num, setsz, c_, h, w = x_spt.size()

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?
        n_inner_iter = 5

        for i in range(task_num):
            new_params = params
            for _ in range(n_inner_iter):
                spt_logits = fnet(new_params, buffers, x_spt[i])
                spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                grads = torch.autograd.grad(spt_loss, new_params)
                new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]
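                # This manual update is plain SGD with the same learning rate
                # (1e-1) as the torchopt.sgd optimizer in the training inner loop.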

            # The query loss and acc induced by these parameters.
            qry_logits = fnet(new_params, buffers, x_qry[i]).detach()
            qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none')
            qry_losses.append(qry_loss.detach())
            qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())

    qry_losses = torch.cat(qry_losses).mean().item()
    qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
    print(f'[Epoch {epoch + 1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
    log.append(
        {
            'epoch': epoch + 1,
            'loss': qry_losses,
            'acc': qry_accs,
            'mode': 'test',
            'time': time.time(),
        }
    )


def plot(log):
    # Generally you should pull your plotting code out of your training
    # script, but we are doing it here for brevity.
    df = pd.DataFrame(log)

    fig, ax = plt.subplots(figsize=(6, 4))
    train_df = df[df['mode'] == 'train']
    test_df = df[df['mode'] == 'test']
    ax.plot(train_df['epoch'], train_df['acc'], label='Train')
    ax.plot(test_df['epoch'], test_df['acc'], label='Test')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_ylim(70, 100)
    fig.legend(ncol=2, loc='lower right')
    fig.tight_layout()
    fname = 'maml-accs.png'
    print(f'--- Plotting accuracy to {fname}')
    fig.savefig(fname)
    plt.close(fig)


if __name__ == '__main__':
    main()
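As a usage sketch (assuming the repository layout shown above, so that the
`sys.path` tweak can find `examples/few-shot/support/omniglot_loaders.py`),
the example can be launched with its defaults, e.g.:

    python examples/functorch/maml_omniglot_vmap.py --device cuda --task_num 32

All flags (--n_way, --k_spt, --k_qry, --device, --task_num, --seed) fall back
to the defaults defined in main().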