
Commit 111ddb3 (parent: af6d24c)

feat(distributed): add RPC-based distributed training support and a distributed MAML example

File tree: 19 files changed, +1656 −430 lines

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- Add RPC-based distributed training support and add distributed MAML example by [@XuehaiPan](https://github.com/XuehaiPan) in [#83](https://github.com/metaopt/torchopt/pull/83).
 - Add full type hints by [@XuehaiPan](https://github.com/XuehaiPan) in [#92](https://github.com/metaopt/torchopt/pull/92).
 - Add API documentation and tutorial for implicit gradients by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@JieRen98](https://github.com/JieRen98) and [@XuehaiPan](https://github.com/XuehaiPan) in [#73](https://github.com/metaopt/torchopt/pull/73).
 - Add wrapper class for functional optimizers and examples of `functorch` integration by [@vmoens](https://github.com/vmoens) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#6](https://github.com/metaopt/torchopt/pull/6).
```

conda-recipe.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -51,6 +51,7 @@ dependencies:
   - seaborn
   - python-graphviz
   - pillow
+  - setproctitle
 
   # Documentation
   - sphinx
```
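The new `setproctitle` dependency backs the worker renaming done in the example's `worker_init` hook later in this commit. A minimal sketch of that idiom (the worker name shown here is illustrative):

```python
from setproctitle import getproctitle, setproctitle

# Prefix the OS-visible process title with the RPC worker's name, so each
# torchrun worker is identifiable in `ps`/`top` output.
# (The 'worker0' label is illustrative; the example derives it from the
# worker's world info at runtime.)
setproctitle(f'worker0: {getproctitle().strip()}')
```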

docs/source/spelling_wordlist.txt

Lines changed: 3 additions & 0 deletions
```diff
@@ -56,6 +56,7 @@ iterable
 nan
 param
 Graphviz
+Autograd
 autograd
 attrs
 GradientTransformations
@@ -84,3 +85,5 @@ argnums
 matvec
 Hermitian
 deepcopy
+RRef
+rref
```

README.md

Lines changed: 18 additions & 0 deletions

# MAML few-shot Omniglot classification examples

Code for MAML few-shot Omniglot classification from the paper [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks](https://arxiv.org/abs/1703.03400), implemented with TorchOpt. We use `MetaSGD` as the inner-loop optimizer.

## Usage

```bash
### Run
torchrun --nnode 1 --nproc_per_node 8 maml_omniglot.py
```

## Results

The figure below illustrates the experimental results.

<div align=center>
  <img src="./maml-accs.png" width="800" />
</div>
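
Before reading the full script later in this commit, it may help to see the `MetaSGD` inner-loop idiom the README refers to in isolation. A minimal, self-contained sketch (the linear model and random batch are stand-ins; the real example trains a CNN on Omniglot):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchopt

# Stand-ins for the model and a support batch.
net = nn.Linear(8, 5)
x_spt, y_spt = torch.randn(32, 8), torch.randint(5, (32,))

# MetaSGD performs differentiable in-place updates: each step stays on the
# autograd graph, so an outer meta-optimizer can backpropagate through the
# whole adaptation trajectory.
inner_opt = torchopt.MetaSGD(net, lr=1e-1)
for _ in range(5):
    spt_loss = F.cross_entropy(net(x_spt), y_spt)
    inner_opt.step(spt_loss)
```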

maml-accs.png

Binary image added (154 KB): the accuracy plot referenced by the README above.

maml_omniglot.py

Lines changed: 310 additions & 0 deletions

```python
# Copyright 2022 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# This file is modified from:
# https://github.com/facebookresearch/higher/blob/main/examples/maml-omniglot.py
# ==============================================================================
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This example shows how to use TorchOpt to do Model Agnostic Meta Learning (MAML)
for few-shot Omniglot classification.
For more details see the original MAML paper:
https://arxiv.org/abs/1703.03400
This code has been modified from Jackie Loong's PyTorch MAML implementation:
https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py
Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch
"""

import argparse
import os
import random
import time

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from setproctitle import getproctitle, setproctitle

import torchopt
import torchopt.distributed as todist


from support.omniglot_loaders import OmniglotNShot  # isort: skip


mpl.use('Agg')
plt.style.use('bmh')


def worker_init():
    world_info = todist.get_world_info()

    proctitle = f'{world_info.worker_name}: {getproctitle().strip()}'
    print(f'worker_init => {proctitle}')
    setproctitle(proctitle)

    seed = world_info.local_rank

    os.environ['PYTHONHASHSEED'] = str(seed)

    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if world_info.local_rank < torch.cuda.device_count():
        torch.cuda.set_device(world_info.local_rank)


def build_model(args, device):
    return nn.Sequential(
        nn.Conv2d(1, 64, 3),
        nn.BatchNorm2d(64, momentum=1.0, affine=True),
        nn.ReLU(inplace=False),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1.0, affine=True),
        nn.ReLU(inplace=False),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1.0, affine=True),
        nn.ReLU(inplace=False),
        nn.MaxPool2d(2, 2),
        nn.Flatten(),
        nn.Linear(64, args.n_way),
    ).to(device)


@todist.rank_zero_only
def get_data_loader(args, device):
    rng = np.random.default_rng(args.seed)

    return OmniglotNShot(
        '/tmp/omniglot-data',
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=28,
        rng=rng,
        device=device,
    )


@todist.auto_init_rpc(worker_init)
def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--n_way', type=int, help='n way', default=5)
    argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=5)
    argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15)
    argparser.add_argument(
        '--task_num', type=int, help='meta batch size, namely task num', default=32
    )
    argparser.add_argument('--seed', type=int, help='random seed', default=1)
    args = argparser.parse_args()

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    np.random.seed(args.seed)

    # Set up the Omniglot loader.
    db = get_data_loader(args, device=torch.device('cpu'))

    # Create a vanilla PyTorch neural network.
    net = build_model(args, device=torch.device('cpu'))

    # We will use Adam to (meta-)optimize the initial parameters
    # to be adapted.
    meta_opt = optim.Adam(net.parameters(), lr=1e-3)

    log = []
    test(db, net, epoch=-1, log=log)
    for epoch in range(10):
        train(db, net, meta_opt, epoch=epoch, log=log)
        test(db, net, epoch=epoch, log=log)
        plot(log)


def transpose_mean_reducer(results):
    qry_losses, qry_accs = tuple(zip(*results))
    qry_loss = torch.mean(torch.stack(qry_losses))
    qry_acc = np.mean(qry_accs)
    return qry_loss, qry_acc


@todist.parallelize(
    partitioner=todist.dim_partitioner(dim=0, exclusive=True, keepdim=False),
    reducer=transpose_mean_reducer,
)
def inner_loop(net_rref, x_spt, y_spt, x_qry, y_qry, n_inner_iter):
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{todist.get_local_rank() % torch.cuda.device_count()}')
        torch.cuda.set_device(device)
    else:
        device = None

    original_net = net_rref.to_here()
    net = torchopt.module_clone(original_net, by='reference', device=device)
    if device is not None:
        x_spt = x_spt.to(device)
        y_spt = y_spt.to(device)
        x_qry = x_qry.to(device)
        y_qry = y_qry.to(device)

    querysz = x_qry.size(0)
    inner_opt = torchopt.MetaSGD(net, lr=1e-1)

    for _ in range(n_inner_iter):
        spt_logits = net(x_spt)
        spt_loss = F.cross_entropy(spt_logits, y_spt)
        inner_opt.step(spt_loss)

    qry_logits = net(x_qry)
    qry_loss = F.cross_entropy(qry_logits, y_qry).cpu()
    qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum().cpu().item() / querysz

    return qry_loss, qry_acc


@todist.rank_zero_only
def train(db: OmniglotNShot, net: nn.Module, meta_opt: optim.Adam, epoch: int, log: list):
    net.train()
    n_train_iter = db.x_train.shape[0] // db.batchsz

    net_rref = todist.rpc.RRef(net)
    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?

        # Initialize the inner optimizer to adapt the parameters to
        # the support set.
        n_inner_iter = 5

        meta_opt.zero_grad()
        with todist.autograd.context() as context_id:
            qry_loss, qry_acc = inner_loop(net_rref, x_spt, y_spt, x_qry, y_qry, n_inner_iter)
            todist.autograd.backward(context_id, qry_loss)
            meta_opt.step()

        qry_loss = qry_loss.item()
        qry_acc = 100.0 * qry_acc
        i = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time

        print(
            f'[Epoch {i:.2f}] Train Loss: {qry_loss:.2f} | Acc: {qry_acc:.2f} | Time: {iter_time:.2f}'
        )

        log.append(
            {
                'epoch': i,
                'loss': qry_loss,
                'acc': qry_acc,
                'mode': 'train',
                'time': time.time(),
            }
        )


@todist.rank_zero_only
def test(db, net, epoch, log):
    # Crucially in our testing procedure here, we do *not* fine-tune
    # the model during testing for simplicity.
    # Most research papers using MAML for this task do an extra
    # stage of fine-tuning here that should be added if you are
    # adapting this code for research.
    net.train()
    n_test_iter = db.x_test.shape[0] // db.batchsz

    qry_losses = []
    qry_accs = []

    net_rref = todist.rpc.RRef(net)
    for _ in range(n_test_iter):
        x_spt, y_spt, x_qry, y_qry = db.next('test')

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?
        n_inner_iter = 5

        qry_loss, qry_acc = inner_loop(net_rref, x_spt, y_spt, x_qry, y_qry, n_inner_iter)
        qry_losses.append(qry_loss.item())
        qry_accs.append(qry_acc)

    qry_losses = np.mean(qry_losses)
    qry_accs = 100.0 * np.mean(qry_accs)
    print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
    log.append(
        {
            'epoch': epoch + 1,
            'loss': qry_losses,
            'acc': qry_accs,
            'mode': 'test',
            'time': time.time(),
        }
    )


@todist.rank_zero_only
def plot(log):
    # Generally you should pull your plotting code out of your training
    # script but we are doing it here for brevity.
    df = pd.DataFrame(log)

    fig, ax = plt.subplots(figsize=(8, 4), dpi=250)
    train_df = df[df['mode'] == 'train']
    test_df = df[df['mode'] == 'test']
    ax.plot(train_df['epoch'], train_df['acc'], label='Train')
    ax.plot(test_df['epoch'], test_df['acc'], label='Test')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_ylim(85, 100)
    ax.set_title('Distributed MAML Omniglot')
    ax.legend(ncol=2, loc='lower right')
    fig.tight_layout()
    fname = 'maml-accs.png'
    print(f'--- Plotting accuracy to {fname}')
    fig.savefig(fname)
    plt.close(fig)


if __name__ == '__main__':
    main()
```
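
To summarize the distributed pattern above for readers adapting it: `main` is wrapped with `@todist.auto_init_rpc(worker_init)` so every torchrun worker joins the RPC group, the per-task computation is wrapped with `@todist.parallelize` so rank 0 scatters shards of the task batch to workers and reduces the returned values, and `todist.autograd` carries gradients back across workers. A condensed sketch of just the `parallelize` idiom (the `todist` calls are the ones used in this commit; the loss function is illustrative):

```python
import torch
import torch.nn.functional as F

import torchopt.distributed as todist


def mean_reducer(results):
    # Each worker returns a scalar loss; average them into one meta-loss.
    return torch.mean(torch.stack(results))


@todist.parallelize(
    # Shard dim 0 (the task dimension) exclusively across RPC workers.
    partitioner=todist.dim_partitioner(dim=0, exclusive=True, keepdim=False),
    reducer=mean_reducer,
)
def sharded_loss(net_rref, inputs, targets):
    # Illustrative body: materialize the module behind the RRef locally,
    # then compute this worker's shard of the loss.
    net = net_rref.to_here()
    return F.cross_entropy(net(inputs), targets)
```

As in the example, such a function must be called from within an RPC-initialized program (e.g. a `@todist.auto_init_rpc`-wrapped `main` launched via `torchrun`), with gradients flowing back through a `todist.autograd.context()`.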
