feat: functorch integration #6


Merged
merged 26 commits on Sep 11, 2022

Commits
766ca2f
init
vmoens Apr 13, 2022
0635bc9
amend
vmoens Apr 13, 2022
e7ee747
amend
vmoens Apr 13, 2022
b9c0373
amend
vmoens Apr 13, 2022
589de77
feat: resolve conflicts and init func optimizer
Benjamin-eecs Aug 12, 2022
639c4dc
feat: init func optimizer
Benjamin-eecs Aug 12, 2022
4ea3fe5
feat: init func optimizer
Benjamin-eecs Aug 12, 2022
53e9194
feat: init func optimizer
Benjamin-eecs Aug 12, 2022
f0bdbdf
fix: keep the tensor types after update
XuehaiPan Aug 14, 2022
788fd7e
Merge branch 'main' into functorch_functional
XuehaiPan Aug 14, 2022
ced2114
Merge branch 'main' into functorch_functional
Benjamin-eecs Sep 5, 2022
1358386
Merge branch 'main' into functorch_functional
Benjamin-eecs Sep 7, 2022
5b94aab
fix: pass lint
Benjamin-eecs Sep 7, 2022
e2ff85f
Merge branch 'main' into functorch_functional
Benjamin-eecs Sep 9, 2022
d434103
Merge branch 'main' into functorch_functional
Benjamin-eecs Sep 10, 2022
f7c6858
fix: revert nn.Parameter fix
Benjamin-eecs Sep 11, 2022
3d39c30
chore: cleanup
XuehaiPan Sep 11, 2022
402e48e
feat: update `FuncOptimizer`
XuehaiPan Sep 11, 2022
ff4533c
docs: add docs for `FuncOptimizer`
XuehaiPan Sep 11, 2022
e376123
docs(CHANGELOG): update CHANGELOG.md
XuehaiPan Sep 11, 2022
847549b
chore: cleanup
XuehaiPan Sep 11, 2022
f9ab259
chore: cleanup
XuehaiPan Sep 11, 2022
8f8d3ef
chore: cleanup
XuehaiPan Sep 11, 2022
ad71291
chore: handle corner case
XuehaiPan Sep 11, 2022
474dee5
test: add tests for `FuncOptimizer`
XuehaiPan Sep 11, 2022
fa2a38c
chore: add type check
XuehaiPan Sep 11, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add wrapper class for functional optimizers and examples of `functorch` integration by [@vmoens](https://github.com/vmoens) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#6](https://github.com/metaopt/torchopt/pull/6).
- Implicit differentiation support by [@JieRen98](https://github.com/JieRen98) and [@waterhorse1](https://github.com/waterhorse1) and [@XuehaiPan](https://github.com/XuehaiPan) in [#41](https://github.com/metaopt/torchopt/pull/41).

### Changed
16 changes: 16 additions & 0 deletions README.md
@@ -77,6 +77,22 @@ updates, opt_state = optimizer.update(grads, opt_state) # get updates
params = torchopt.apply_updates(params, updates) # update network parameters
```

We also provide a wrapper `torchopt.FuncOptimizer` to make maintaining the optimizer state easier:

```python
net = Net() # init
loader = Loader()
optimizer = torchopt.FuncOptimizer(torchopt.adam()) # wrap with `torchopt.FuncOptimizer`

model, params = functorch.make_functional(net)  # use functorch to extract network parameters

for xs, ys in loader: # get data
pred = model(params, xs) # forward
loss = F.cross_entropy(pred, ys) # compute loss

params = optimizer.step(loss, params) # update network parameters
```
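
`FuncOptimizer` keeps the optimizer state internally, so the explicit `opt_state` bookkeeping from the Optax-like snippet above disappears; it also accepts an `inplace` keyword (exercised by the new `test_FuncOptimizer` below) controlling whether updates are applied in place. For a model with buffers (e.g. BatchNorm running statistics), a minimal sketch mirroring that test, using `functorch.make_functional_with_buffers`, might look like:

```python
net = Net()                                          # init
loader = Loader()
optimizer = torchopt.FuncOptimizer(torchopt.adam())  # wrap with `torchopt.FuncOptimizer`

fmodel, params, buffers = functorch.make_functional_with_buffers(net)

for xs, ys in loader:                                # get data
    pred = fmodel(params, buffers, xs)               # forward: buffers are passed to the model
    loss = F.cross_entropy(pred, ys)                 # compute loss

    params = optimizer.step(loss, params)            # only the parameters are updated
```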

### PyTorch-Like API

We also offer the original PyTorch APIs (e.g., `zero_grad()` or `step()`) by wrapping our Optax-Like API for traditional PyTorch users:
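
A minimal sketch of this style, assuming `torchopt.Adam` mirrors the `torch.optim.Adam` constructor (parameters first, then hyperparameters), might look like:

```python
net = Net()                           # init
loader = Loader()
optimizer = torchopt.Adam(net.parameters(), lr=1e-3)

for xs, ys in loader:                 # get data
    pred = net(xs)                    # forward
    loss = F.cross_entropy(pred, ys)  # compute loss

    optimizer.zero_grad()             # zero gradients
    loss.backward()                   # backward
    optimizer.step()                  # update network parameters
```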
7 changes: 7 additions & 0 deletions docs/source/api/api.rst
@@ -29,11 +29,18 @@ Functional Optimizers

.. autosummary::

FuncOptimizer
adam
sgd
rmsprop
adamw

Wrapper for Functional Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: FuncOptimizer
:members:

Functional Adam Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~

196 changes: 196 additions & 0 deletions examples/MAML-RL/func_maml.py
@@ -0,0 +1,196 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import argparse
from typing import NamedTuple

import functorch
import gym
import numpy as np
import torch
import torch.optim as optim

import torchopt
from helpers.policy import CategoricalMLPPolicy


TASK_NUM = 40
TRAJ_NUM = 20
TRAJ_LEN = 10

STATE_DIM = 10
ACTION_DIM = 5

GAMMA = 0.99
LAMBDA = 0.95

outer_iters = 500
inner_iters = 1


class Traj(NamedTuple):
obs: np.ndarray
acs: np.ndarray
next_obs: np.ndarray
rews: np.ndarray
gammas: np.ndarray


def sample_traj(env, task, fpolicy, params):
env.reset_task(task)
obs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM, STATE_DIM), dtype=np.float32)
next_obs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM, STATE_DIM), dtype=np.float32)
acs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.int8)
rews_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.float32)
gammas_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.float32)
with torch.no_grad():
for batch in range(TRAJ_NUM):
ob = env.reset()
for step in range(TRAJ_LEN):
ob_tensor = torch.from_numpy(ob)
pi, _ = fpolicy(params, ob_tensor)
ac_tensor = pi.sample()
ac = ac_tensor.cpu().numpy()
next_ob, rew, done, info = env.step(ac)

obs_buf[step][batch] = ob
next_obs_buf[step][batch] = next_ob
acs_buf[step][batch] = ac
rews_buf[step][batch] = rew
gammas_buf[step][batch] = done * GAMMA
ob = next_ob
return Traj(obs=obs_buf, acs=acs_buf, next_obs=next_obs_buf, rews=rews_buf, gammas=gammas_buf)


def a2c_loss(traj, fpolicy, params, value_coef):
lambdas = np.ones_like(traj.gammas) * LAMBDA
_, next_values = fpolicy(params, torch.from_numpy(traj.next_obs))
next_values = torch.squeeze(next_values, -1).detach().numpy()
# Work backwards to compute `G_{T-1}`, ..., `G_0`.
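# With a constant lambda (LAMBDA) and per-step discounts gammas[t], this loop computes
# the lambda-returns G_t = rews[t] + gammas[t] * ((1 - lambda) * V(s_{t+1}) + lambda * G_{t+1}),
# bootstrapped from G_T = V(s_T), i.e. the last row of `next_values`.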
returns = []
g = next_values[-1, :]
for i in reversed(range(next_values.shape[0])):
g = traj.rews[i, :] + traj.gammas[i, :] * (
(1 - lambdas[i, :]) * next_values[i, :] + lambdas[i, :] * g
)
returns.insert(0, g)
lambda_returns = torch.from_numpy(np.array(returns))
pi, values = fpolicy(params, torch.from_numpy(traj.obs))
log_probs = pi.log_prob(torch.from_numpy(traj.acs))
advs = lambda_returns - torch.squeeze(values, -1)
action_loss = -(advs.detach() * log_probs).mean()
value_loss = advs.pow(2).mean()

loss = action_loss + value_coef * value_loss
return loss


def evaluate(env, seed, task_num, fpolicy, params):
pre_reward_ls = []
post_reward_ls = []
inner_opt = torchopt.MetaSGD(lr=0.5)
env = gym.make(
'TabularMDP-v0',
**dict(
num_states=STATE_DIM, num_actions=ACTION_DIM, max_episode_steps=TRAJ_LEN, seed=seed
),
)
tasks = env.sample_tasks(num_tasks=task_num)

for idx in range(task_num):
for _ in range(inner_iters):
pre_trajs = sample_traj(env, tasks[idx], fpolicy, params)

inner_loss = a2c_loss(pre_trajs, fpolicy, params, value_coef=0.5)
params = inner_opt.step(inner_loss, params)
post_trajs = sample_traj(env, tasks[idx], fpolicy, params)

# Logging
pre_reward_ls.append(np.sum(pre_trajs.rews, axis=0).mean())
post_reward_ls.append(np.sum(post_trajs.rews, axis=0).mean())

return pre_reward_ls, post_reward_ls


def main(args):
# init training
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
# Env
env = gym.make(
'TabularMDP-v0',
**dict(
num_states=STATE_DIM, num_actions=ACTION_DIM, max_episode_steps=TRAJ_LEN, seed=args.seed
),
)
# Policy
policy = CategoricalMLPPolicy(input_size=STATE_DIM, output_size=ACTION_DIM)
fpolicy, params = functorch.make_functional(policy)

inner_opt = torchopt.MetaSGD(lr=0.5)
outer_opt = optim.Adam(params, lr=1e-3)
train_pre_reward = []
train_post_reward = []
test_pre_reward = []
test_post_reward = []

for i in range(outer_iters):
tasks = env.sample_tasks(num_tasks=TASK_NUM)
train_pre_reward_ls = []
train_post_reward_ls = []

outer_opt.zero_grad()

param_orig = [p.detach().clone().requires_grad_() for p in params]
_params = list(params)
for idx in range(TASK_NUM):

for _ in range(inner_iters):
pre_trajs = sample_traj(env, tasks[idx], fpolicy, _params)
inner_loss = a2c_loss(pre_trajs, fpolicy, _params, value_coef=0.5)
_params = inner_opt.step(inner_loss, _params)
post_trajs = sample_traj(env, tasks[idx], fpolicy, _params)
outer_loss = a2c_loss(post_trajs, fpolicy, _params, value_coef=0.5)
outer_loss.backward()
_params = [p.detach().clone().requires_grad_() for p in param_orig]

# Logging
train_pre_reward_ls.append(np.sum(pre_trajs.rews, axis=0).mean())
train_post_reward_ls.append(np.sum(post_trajs.rews, axis=0).mean())
outer_opt.step()

test_pre_reward_ls, test_post_reward_ls = evaluate(
env, args.seed, TASK_NUM, fpolicy, params
)

train_pre_reward.append(sum(train_pre_reward_ls) / TASK_NUM)
train_post_reward.append(sum(train_post_reward_ls) / TASK_NUM)
test_pre_reward.append(sum(test_pre_reward_ls) / TASK_NUM)
test_post_reward.append(sum(test_post_reward_ls) / TASK_NUM)

print('Train_iters', i)
print('train_pre_reward', sum(train_pre_reward_ls) / TASK_NUM)
print('train_post_reward', sum(train_post_reward_ls) / TASK_NUM)
print('test_pre_reward', sum(test_pre_reward_ls) / TASK_NUM)
print('test_post_reward', sum(test_post_reward_ls) / TASK_NUM)


if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Reinforcement learning with Model-Agnostic Meta-Learning (MAML) - Train'
)
parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
args = parser.parse_args()
main(args)
5 changes: 3 additions & 2 deletions examples/MAML-RL/maml.py
@@ -99,8 +99,9 @@ def a2c_loss(traj, policy, value_coef):
advs = lambda_returns - torch.squeeze(values, -1)
action_loss = -(advs.detach() * log_probs).mean()
value_loss = advs.pow(2).mean()
a2c_loss = action_loss + value_coef * value_loss
return a2c_loss

loss = action_loss + value_coef * value_loss
return loss


def evaluate(env, seed, task_num, policy):
56 changes: 55 additions & 1 deletion tests/test_optimizer.py
@@ -13,8 +13,9 @@
# limitations under the License.
# ==============================================================================

from typing import Tuple
from typing import Callable, Tuple

import functorch
import pytest
import torch
import torch.nn.functional as F
@@ -364,3 +365,56 @@ def test_RMSProp(
optim_ref.step()

helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype)


@helpers.parametrize(
dtype=[torch.float64, torch.float32],
lr=[1e-2, 1e-3],
optimizers=[
(torchopt.sgd, torch.optim.SGD),
(torchopt.adam, torch.optim.Adam),
(torchopt.adamw, torch.optim.AdamW),
(torchopt.rmsprop, torch.optim.RMSprop),
],
inplace=[True, False],
weight_decay=[0.0, 1e-2],
)
def test_FuncOptimizer(
dtype: torch.dtype,
lr: float,
optimizers: Tuple[Callable, torch.optim.Optimizer],
inplace: bool,
weight_decay: float,
) -> None:
model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)

torchopt_optimizer, torch_optimizer = optimizers

fmodel, params, buffers = functorch.make_functional_with_buffers(model)
optim = torchopt.FuncOptimizer(
torchopt_optimizer(
lr=lr,
weight_decay=weight_decay,
),
inplace=inplace,
)
optim_ref = torch_optimizer(
model_ref.parameters(),
lr,
weight_decay=weight_decay,
)

for xs, ys in loader:
xs = xs.to(dtype=dtype)
pred = fmodel(params, buffers, xs)
pred_ref = model_ref(xs)
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

params = optim.step(loss, params)

optim_ref.zero_grad()
loss_ref.backward()
optim_ref.step()

helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype)
2 changes: 2 additions & 0 deletions torchopt/__init__.py
@@ -28,6 +28,7 @@
from torchopt._src.clip import clip_grad_norm
from torchopt._src.combine import chain
from torchopt._src.optimizer import SGD, Adam, AdamW, Optimizer, RMSProp, RMSprop, meta
from torchopt._src.optimizer.func import FuncOptimizer
from torchopt._src.optimizer.meta import (
MetaAdam,
MetaAdamW,
@@ -68,6 +69,7 @@
'MetaAdamW',
'MetaRMSProp',
'MetaRMSprop',
'FuncOptimizer',
'apply_updates',
'extract_state_dict',
'recover_state_dict',
12 changes: 12 additions & 0 deletions torchopt/_src/alias.py
@@ -178,6 +178,9 @@ def adam(

Returns:
The corresponding :class:`GradientTransformation` instance.

See Also:
The functional optimizer wrapper :class:`torchopt.FuncOptimizer`.
"""
b1, b2 = betas
# pylint: disable=unneeded-not
Expand Down Expand Up @@ -270,6 +273,9 @@ def adamw(

Returns:
The corresponding :class:`GradientTransformation` instance.

See Also:
The functional optimizer wrapper :class:`torchopt.FuncOptimizer`.
"""
b1, b2 = betas
# pylint: disable=unneeded-not
Expand Down Expand Up @@ -361,6 +367,9 @@ def rmsprop(

Returns:
The corresponding :class:`GradientTransformation` instance.

See Also:
The functional optimizer wrapper :class:`torchopt.FuncOptimizer`.
"""
# pylint: disable=unneeded-not
if not (callable(lr) or 0.0 <= lr):
Expand Down Expand Up @@ -437,6 +446,9 @@ def sgd(

Returns:
The corresponding :class:`GradientTransformation` instance.

See Also:
The functional optimizer wrapper :class:`torchopt.FuncOptimizer`.
"""
# pylint: disable=unneeded-not
if not (callable(lr) or 0.0 <= lr):
1 change: 1 addition & 0 deletions torchopt/_src/optimizer/__init__.py
@@ -17,5 +17,6 @@
from torchopt._src.optimizer.adam import Adam
from torchopt._src.optimizer.adamw import AdamW
from torchopt._src.optimizer.base import Optimizer
from torchopt._src.optimizer.func import FuncOptimizer
from torchopt._src.optimizer.rmsprop import RMSProp, RMSprop
from torchopt._src.optimizer.sgd import SGD
3 changes: 3 additions & 0 deletions torchopt/_src/optimizer/base.py
@@ -37,6 +37,9 @@ def __init__(self, params: Iterable[torch.Tensor], impl: GradientTransformation)
Note that using ``Optimizer(sgd())`` or ``Optimizer(chain(sgd()))`` is equivalent to
:class:`torchopt.SGD`.
"""
if not isinstance(impl, GradientTransformation):
raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation')

self.impl = impl
self.param_groups = [] # type: ignore
self.param_tree_groups = [] # type: ignore
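
The added check turns a confusing downstream failure into an immediate `TypeError` when a functional optimizer is passed uncalled. A hedged sketch of the two cases (the toy `nn.Linear` module is for illustration only):

```python
import torch.nn as nn

import torchopt

net = nn.Linear(4, 2)  # toy module, for illustration only

# OK: `torchopt.adam(...)` returns a `GradientTransformation` instance.
optimizer = torchopt.Optimizer(net.parameters(), torchopt.adam(lr=1e-3))

# Raises TypeError: the factory function itself is not a `GradientTransformation`.
optimizer = torchopt.Optimizer(net.parameters(), torchopt.adam)
```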