
Examples of maml-omniglot using functorch.vmap with torchopt #39


Merged (25 commits, Aug 12, 2022)
Commits (the diff below shows the changes from a single commit, e959bc0):
14f0652
feat(examples): add functorch vmap support examples
Benjamin-eecs Jul 19, 2022
4ebcf38
feat(examples): add functorch vmap support examples
Benjamin-eecs Jul 19, 2022
e959bc0
feat: working functorch + torchopt parallel training example
Benjamin-eecs Jul 21, 2022
c9ecd25
fix: pass lint
Benjamin-eecs Jul 21, 2022
d310aa1
feat(examples): add maml omniglot with functorch.vmap + torchopt example
Benjamin-eecs Jul 21, 2022
220c5f8
Merge remote-tracking branch 'upstream/main' into feature/functorch_v…
Benjamin-eecs Jul 22, 2022
a0d8466
merge: sync with main
Benjamin-eecs Jul 22, 2022
fcc9f3e
Merge remote-tracking branch 'upstream/main' into feature/functorch_v…
Benjamin-eecs Aug 6, 2022
05960bc
fix: pass lint
Benjamin-eecs Aug 6, 2022
045bff4
Merge branch 'main' into feature/maml-omniglot_vmap
XuehaiPan Aug 7, 2022
2b7ee7e
Merge branch 'main' into feature/maml-omniglot_vmap
XuehaiPan Aug 7, 2022
8a3ea6a
merge: resolve conflicts
Benjamin-eecs Aug 10, 2022
acba2e1
fix: update comment
Benjamin-eecs Aug 10, 2022
78556b2
Merge branch 'feature/functorch_vmap_support' into feature/maml-omnig…
Benjamin-eecs Aug 10, 2022
95c33ad
Merge branch 'feature/maml-omniglot_vmap' of https://github.com/Benja…
Benjamin-eecs Aug 10, 2022
ba4ddc1
fix: pass lint
Benjamin-eecs Aug 11, 2022
fc679ad
fix: resolve comments
Benjamin-eecs Aug 11, 2022
c3babe5
Merge branch 'feature/functorch_vmap_support' into feature/maml-omnig…
Benjamin-eecs Aug 11, 2022
05cc649
fix: resolve comments
Benjamin-eecs Aug 11, 2022
6d9109d
fix: resolve conflicts
Benjamin-eecs Aug 11, 2022
4b60252
fix: correct wrong comments
Benjamin-eecs Aug 11, 2022
48a553a
Merge branch 'metaopt:main' into feature/maml-omniglot_vmap
Benjamin-eecs Aug 12, 2022
28848bb
chore: minor fix
Benjamin-eecs Aug 12, 2022
62ed013
chore: minor fix
Benjamin-eecs Aug 12, 2022
0defbd6
chore: update CHANGELOG
Benjamin-eecs Aug 12, 2022
feat: working functorch + torchopt parallel training example
Benjamin-eecs committed Jul 21, 2022
commit e959bc0fb7e23015dfc3e1f7708c8734979dbb8a
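Before the file-level changes, here is a minimal, self-contained sketch of the pattern this commit is built around: make a model functional with functorch, keep the optimizer state outside the model with torchopt's functional API, and step both explicitly. This is an illustrative reconstruction, not code from the PR; `TinyNet`, the random data, and the hyper-parameters are hypothetical, and it assumes a functorch build with `make_functional`/`grad_and_value` and a torchopt version exposing `adam`, `update(..., inplace=False)`, and `apply_updates`.

import functorch
import torch
import torch.nn as nn
import torchopt


class TinyNet(nn.Module):  # hypothetical stand-in for the example's MLPClassifier
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 2)

    def forward(self, x):
        return torch.log_softmax(self.fc(x), dim=-1)


# Split the module into a pure function plus a tuple of parameters.
func_model, params = functorch.make_functional(TinyNet())
loss_fn = nn.NLLLoss()

# torchopt's functional optimizer: its state is created and threaded explicitly.
optimizer = torchopt.adam(lr=0.2)
opt_state = optimizer.init(params)

points = torch.randn(64, 2)
labels = torch.randint(0, 2, (64,))


def compute_loss(params, batch, targets):
    return loss_fn(func_model(params, batch), targets)


for step in range(100):
    # grad_and_value returns the gradients w.r.t. `params` and the loss value.
    grads, loss = functorch.grad_and_value(compute_loss)(params, points, labels)
    updates, opt_state = optimizer.update(grads, opt_state, inplace=False)
    params = torchopt.apply_updates(params, updates, inplace=False)
    if step % 20 == 0:
        print(loss)

Because every piece of state (parameters and optimizer state) is an explicit function argument, the same step function can later be wrapped in `functorch.vmap` to train an ensemble of models in parallel, which is what the diff below sets up.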
164 changes: 0 additions & 164 deletions examples/functorch/parallel_train.py

This file was deleted.

171 changes: 101 additions & 70 deletions examples/functorch/parallel_train_torchopt.py
@@ -18,6 +18,7 @@
from collections import namedtuple
from typing import Any, NamedTuple

import functorch
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -56,71 +57,89 @@ def forward(self, x):
return x


class Net(nn.Module):
def __init__(self, dim):
super().__init__()
self.fc = nn.Linear(dim, 1, bias=True)
nn.init.ones_(self.fc.weight)
nn.init.zeros_(self.fc.bias)

def forward(self, x):
return self.fc(x)


def train_step_fn(training_state, batch, targets):
weights, opt_state = training_state

def compute_loss(weights, batch, targets):
output = func_model(weights, batch)
loss = loss_fn(output, targets)
return loss

grads, loss = grad_and_value(compute_loss)(weights, batch, targets)

# functional optimizer API is here now
# new_opt_state0 = opt_state[0]._asdict()
# for k, v in new_opt_state0.items():
# if type(v) is tuple:
# new_opt_state0[k] = tuple(v_el.clone() for v_el in v)
# new_opt_state = (opt_state[0]._make(new_opt_state0.values()), opt_state[1])

updates, new_opt_state = optimizer.update(grads, opt_state)
new_weights = torchopt.apply_updates(weights, updates)
# Default `inplace=True` gave me an error
# weights = torchopt.apply_updates(weights, updates, inplace=False)
return loss, (new_weights, new_opt_state)


def step4(weights, opt_state):
for i in range(2000):
loss, (weights, opt_state) = train_step_fn((weights, opt_state), points, labels)
if i % 100 == 0:
print(loss)


def init_fn(model_idx):
print(model_idx)
# models = [MLPClassifier().to(DEVICE) for _ in range(model_idx)]
# print(len(models))
# print(models)
# _, weights, _ = combine_state_for_ensemble(models)
# print(weights)
_, weights = make_functional(Net(4).to(DEVICE))
opt_state = optimizer.init(weights)
print(weights)
# print(opt_state)
print(opt_state)
return weights, opt_state


def step6(num_models):
parallel_init_fn = vmap(init_fn, randomness='same')
parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, None, None))
weights, opt_state = parallel_init_fn(torch.ones(num_models, 1))
for i in range(2000):
loss, (weights, opt_states) = parallel_train_step_fn((weights, opt_state), points, labels)
if i % 200 == 0:
print(loss)
class ParallelTrainFunctorchOriginal:
def __init__(self, loss_fn, lr):
self.loss_fn = loss_fn
self.lr = lr
self.func_model, _ = make_functional(MLPClassifier().to(DEVICE))

def init_fn(self, num_models):
models = [MLPClassifier().to(DEVICE) for _ in range(num_models)]
_, batched_weights, _ = combine_state_for_ensemble(models)
return batched_weights

def train_step_fn(self, weights, batch, targets):
def compute_loss(weights, batch, targets):
output = self.func_model(weights, batch)
loss = self.loss_fn(output, targets)
return loss

grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets)
# NB: PyTorch is missing a "functional optimizer API" (possibly coming soon)
# so we are going to re-implement SGD here.
new_weights = []
with torch.no_grad():
for grad_weight, weight in zip(grad_weights, weights):
new_weights.append(weight - grad_weight * self.lr)

return loss, new_weights

def test_train_step_fn(self, weights, points, labels):
for i in range(2000):
loss, weights = self.train_step_fn(weights, points, labels)
if i % 100 == 0:
print(loss)

def test_parallel_train_step_fn(self, num_models):
parallel_train_step_fn = vmap(self.train_step_fn, in_dims=(0, None, None))
batched_weights = self.init_fn(num_models=num_models)
for i in range(2000):
loss, batched_weights = parallel_train_step_fn(batched_weights, points, labels)
if i % 200 == 0:
print(loss)


class ParallelTrainFunctorchTorchOpt:
def __init__(self, loss_fn, optimizer):
self.loss_fn = loss_fn
self.optimizer = optimizer
self.func_model, _ = make_functional(MLPClassifier().to(DEVICE))

def init_fn(self, model_idx):
_, weights = make_functional(MLPClassifier().to(DEVICE))
opt_state = self.optimizer.init(weights)
return weights, opt_state

def train_step_fn(self, training_state, batch, targets):
weights, opt_state = training_state

def compute_loss(weights, batch, targets):
output = self.func_model(weights, batch)
loss = self.loss_fn(output, targets)
return loss

grads, loss = grad_and_value(compute_loss)(weights, batch, targets)
# functional optimizer API is here now
        updates, new_opt_state = self.optimizer.update(grads, opt_state, inplace=False)
new_weights = torchopt.apply_updates(weights, updates, inplace=False)
return loss, (new_weights, new_opt_state)

def test_train_step_fn(self, weights, opt_state, points, labels):
for i in range(2000):
loss, (weights, opt_state) = self.train_step_fn((weights, opt_state), points, labels)
if i % 100 == 0:
print(loss)

def test_parallel_train_step_fn(self, num_models):
parallel_init_fn = vmap(self.init_fn, randomness='same')
parallel_train_step_fn = vmap(self.train_step_fn, in_dims=(0, None, None))
weights, opt_state = parallel_init_fn(torch.ones(num_models, 1))
for i in range(2000):
            loss, (weights, opt_state) = parallel_train_step_fn(
                (weights, opt_state), points, labels
            )
if i % 200 == 0:
print(loss)


if __name__ == '__main__':
@@ -136,7 +155,7 @@ def step6(num_models):
# }

# GOAL: Demonstrate that it is possible to use eager-mode vmap
# to parallelize training over models.

parser = argparse.ArgumentParser(description='Functorch Ensembled Models with TorchOpt')
parser.add_argument(
'--device',
@@ -153,17 +172,29 @@ def step6(num_models):
loss_fn = nn.NLLLoss()
# Step 3: Make the model functional(!!) and define a training function.
func_model, weights = make_functional(MLPClassifier().to(DEVICE))

# original functorch implementation
functorch_original = ParallelTrainFunctorchOriginal(loss_fn=loss_fn, lr=0.2)
# Step 4: Let's verify this actually trains.
# We should see the loss decrease.
functorch_original.test_train_step_fn(weights, points, labels)
# Step 6: Now, can we try multiple models at the same time?
    # The answer is: yes! `loss` now holds one entry per model, and we can see that
    # every entry keeps decreasing.
functorch_original.test_parallel_train_step_fn(num_models=2)

# functorch + torchopt implementation
optimizer = torchopt.adam(lr=0.2)
opt_state = optimizer.init(weights)
    functorch_torchopt = ParallelTrainFunctorchTorchOpt(loss_fn=loss_fn, optimizer=optimizer)
# Step 4: Let's verify this actually trains.
# We should see the loss decrease.
step4(weights, opt_state)
# Step 5: We're ready for multiple models. Let's define an init_fn
# that, given a number of models, returns to us all of the weights.
    functorch_torchopt.test_train_step_fn(weights, opt_state, points, labels)
# Step 6: Now, can we try multiple models at the same time?
    # The answer is: yes! `loss` now holds one entry per model, and we can see that
    # every entry keeps decreasing.
step6(5)
    functorch_torchopt.test_parallel_train_step_fn(num_models=2)

# Step 7: Now, the flaw with step 6 is that we were training on the same exact
# data. This can lead to all of the models in the ensemble overfitting in the
# same way. The solution that http://willwhitney.com/parallel-training-jax.html
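The Step 7 comment above is cut off in this view, but the idea it points at (via the linked parallel-training-jax post) is to stop feeding every ensemble member the same batch. A hedged, toy illustration of how `vmap` supports that: map over the data dimensions as well, so each model sees its own minibatch. Everything here (`per_model_loss`, the shapes, the linear "model") is hypothetical and only meant to show the `in_dims=(0, 0, 0)` pattern; it is not code from this PR.

import functorch
import torch

num_models, batch_size, dim = 2, 32, 4

# Stand-ins: one stacked weight vector per model, plus a private minibatch per model.
batched_weights = torch.randn(num_models, dim)
batched_points = torch.randn(num_models, batch_size, dim)
batched_labels = torch.randn(num_models, batch_size)


def per_model_loss(weights, points, labels):
    preds = points @ weights  # toy linear "model"
    return torch.mean((preds - labels) ** 2)


# in_dims=(0, 0, 0): vmap over the model axis of the weights AND of the data,
# so ensemble member i is evaluated only on its own slice of the batch.
losses = functorch.vmap(per_model_loss, in_dims=(0, 0, 0))(
    batched_weights, batched_points, batched_labels
)
print(losses)  # one loss per ensemble member

In the classes above, this would correspond to vmapping `train_step_fn` with a per-model batch axis for `batch` and `targets` instead of `in_dims=(0, None, None)`.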