
Commit 48c660c

Merge branch 'main' into docs/implicit_gradient
2 parents: 04d169a + ac2f0db

File tree: 16 files changed (+508, -35 lines)

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- Add wrapper class for functional optimizers and examples of `functorch` integration by [@vmoens](https://github.com/vmoens) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#6](https://github.com/metaopt/torchopt/pull/6).
 - Implicit differentiation support by [@JieRen98](https://github.com/JieRen98) and [@waterhorse1](https://github.com/waterhorse1) and [@XuehaiPan](https://github.com/XuehaiPan) in [#41](https://github.com/metaopt/torchopt/pull/41).
 
 ### Changed

Makefile

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ PYTHON ?= $(shell command -v python3 || command -v python)
 default: install
 
 install:
-	$(PYTHON) -m pip install .
+	$(PYTHON) -m pip install -vvv .
 
 install-editable:
 	$(PYTHON) -m pip install --upgrade pip

README.md

Lines changed: 16 additions & 0 deletions
@@ -77,6 +77,22 @@ updates, opt_state = optimizer.update(grads, opt_state)  # get updates
 params = torchopt.apply_updates(params, updates)  # update network parameters
 ```
 
+We also provide a wrapper `torchopt.FuncOptimizer` to make maintaining the optimizer state easier:
+
+```python
+net = Net()  # init
+loader = Loader()
+optimizer = torchopt.FuncOptimizer(torchopt.adam())  # wrap with `torchopt.FuncOptimizer`
+
+model, params = functorch.make_functional(net)  # use functorch extract network parameters
+
+for xs, ys in loader:  # get data
+    pred = model(params, xs)  # forward
+    loss = F.cross_entropy(pred, ys)  # compute loss
+
+    params = optimizer.step(loss, params)  # update network parameters
+```
+
 ### PyTorch-Like API
 
 We also offer origin PyTorch APIs (e.g. `zero_grad()` or `step()`) by wrapping our Optax-Like API for traditional PyTorch user:
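
Conceptually, `FuncOptimizer` just carries the optimizer state across `step` calls so the user no longer threads `opt_state` through the loop by hand. A minimal sketch of what one `step` amounts to, pieced together from the Optax-like snippet above and assuming gradients are taken with `torch.autograd.grad` (an illustration, not the actual `FuncOptimizer` implementation):

```python
import torch
import torchopt

def func_step(optimizer, opt_state, loss, params):
    # Differentiate the scalar loss w.r.t. the functional parameters,
    # run the wrapped functional optimizer, then apply the updates.
    grads = torch.autograd.grad(loss, params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = torchopt.apply_updates(params, updates)
    return params, opt_state
```

The wrapper presumably initializes the state from `params` on its first call (e.g. via the functional `optimizer.init(params)`); either way, the visible contract is the one the README loop uses: `params = optimizer.step(loss, params)`.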

docs/source/api/api.rst

Lines changed: 7 additions & 0 deletions
@@ -29,11 +29,18 @@ Functional Optimizers
 
 .. autosummary::
 
+    FuncOptimizer
     adam
     sgd
     rmsprop
     adamw
 
+Wrapper for Function Optimizer
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: FuncOptimizer
+    :members:
+
 Functional Adam Optimizer
 ~~~~~~~~~~~~~~~~~~~~~~~~~

examples/MAML-RL/func_maml.py

Lines changed: 196 additions & 0 deletions
@@ -0,0 +1,196 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import argparse
from typing import NamedTuple

import functorch
import gym
import numpy as np
import torch
import torch.optim as optim

import torchopt
from helpers.policy import CategoricalMLPPolicy


TASK_NUM = 40
TRAJ_NUM = 20
TRAJ_LEN = 10

STATE_DIM = 10
ACTION_DIM = 5

GAMMA = 0.99
LAMBDA = 0.95

outer_iters = 500
inner_iters = 1


class Traj(NamedTuple):
    obs: np.ndarray
    acs: np.ndarray
    next_obs: np.ndarray
    rews: np.ndarray
    gammas: np.ndarray


def sample_traj(env, task, fpolicy, params):
    env.reset_task(task)
    obs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM, STATE_DIM), dtype=np.float32)
    next_obs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM, STATE_DIM), dtype=np.float32)
    acs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.int8)
    rews_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.float32)
    gammas_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.float32)
    with torch.no_grad():
        for batch in range(TRAJ_NUM):
            ob = env.reset()
            for step in range(TRAJ_LEN):
                ob_tensor = torch.from_numpy(ob)
                pi, _ = fpolicy(params, ob_tensor)
                ac_tensor = pi.sample()
                ac = ac_tensor.cpu().numpy()
                next_ob, rew, done, info = env.step(ac)

                obs_buf[step][batch] = ob
                next_obs_buf[step][batch] = next_ob
                acs_buf[step][batch] = ac
                rews_buf[step][batch] = rew
                gammas_buf[step][batch] = done * GAMMA
                ob = next_ob
    return Traj(obs=obs_buf, acs=acs_buf, next_obs=next_obs_buf, rews=rews_buf, gammas=gammas_buf)


def a2c_loss(traj, fpolicy, params, value_coef):
    lambdas = np.ones_like(traj.gammas) * LAMBDA
    _, next_values = fpolicy(params, torch.from_numpy(traj.next_obs))
    next_values = torch.squeeze(next_values, -1).detach().numpy()
    # Work backwards to compute `G_{T-1}`, ..., `G_0`.
    returns = []
    g = next_values[-1, :]
    for i in reversed(range(next_values.shape[0])):
        g = traj.rews[i, :] + traj.gammas[i, :] * (
            (1 - lambdas[i, :]) * next_values[i, :] + lambdas[i, :] * g
        )
        returns.insert(0, g)
    lambda_returns = torch.from_numpy(np.array(returns))
    pi, values = fpolicy(params, torch.from_numpy(traj.obs))
    log_probs = pi.log_prob(torch.from_numpy(traj.acs))
    advs = lambda_returns - torch.squeeze(values, -1)
    action_loss = -(advs.detach() * log_probs).mean()
    value_loss = advs.pow(2).mean()

    loss = action_loss + value_coef * value_loss
    return loss


def evaluate(env, seed, task_num, fpolicy, params):
    pre_reward_ls = []
    post_reward_ls = []
    inner_opt = torchopt.MetaSGD(lr=0.5)
    env = gym.make(
        'TabularMDP-v0',
        **dict(
            num_states=STATE_DIM, num_actions=ACTION_DIM, max_episode_steps=TRAJ_LEN, seed=args.seed
        ),
    )
    tasks = env.sample_tasks(num_tasks=task_num)

    for idx in range(task_num):
        for _ in range(inner_iters):
            pre_trajs = sample_traj(env, tasks[idx], fpolicy, params)

            inner_loss = a2c_loss(pre_trajs, fpolicy, params, value_coef=0.5)
            params = inner_opt.step(inner_loss, params)
        post_trajs = sample_traj(env, tasks[idx], fpolicy, params)

        # Logging
        pre_reward_ls.append(np.sum(pre_trajs.rews, axis=0).mean())
        post_reward_ls.append(np.sum(post_trajs.rews, axis=0).mean())

    return pre_reward_ls, post_reward_ls


def main(args):
    # init training
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    # Env
    env = gym.make(
        'TabularMDP-v0',
        **dict(
            num_states=STATE_DIM, num_actions=ACTION_DIM, max_episode_steps=TRAJ_LEN, seed=args.seed
        ),
    )
    # Policy
    policy = CategoricalMLPPolicy(input_size=STATE_DIM, output_size=ACTION_DIM)
    fpolicy, params = functorch.make_functional(policy)

    inner_opt = torchopt.MetaSGD(lr=0.5)
    outer_opt = optim.Adam(params, lr=1e-3)
    train_pre_reward = []
    train_post_reward = []
    test_pre_reward = []
    test_post_reward = []

    for i in range(outer_iters):
        tasks = env.sample_tasks(num_tasks=TASK_NUM)
        train_pre_reward_ls = []
        train_post_reward_ls = []

        outer_opt.zero_grad()

        param_orig = [p.detach().clone().requires_grad_() for p in params]
        _params = list(params)
        for idx in range(TASK_NUM):

            for _ in range(inner_iters):
                pre_trajs = sample_traj(env, tasks[idx], fpolicy, _params)
                inner_loss = a2c_loss(pre_trajs, fpolicy, _params, value_coef=0.5)
                _params = inner_opt.step(inner_loss, _params)
            post_trajs = sample_traj(env, tasks[idx], fpolicy, _params)
            outer_loss = a2c_loss(post_trajs, fpolicy, _params, value_coef=0.5)
            outer_loss.backward()
            _params = [p.detach().clone().requires_grad_() for p in param_orig]

            # Logging
            train_pre_reward_ls.append(np.sum(pre_trajs.rews, axis=0).mean())
            train_post_reward_ls.append(np.sum(post_trajs.rews, axis=0).mean())
        outer_opt.step()

        test_pre_reward_ls, test_post_reward_ls = evaluate(
            env, args.seed, TASK_NUM, fpolicy, params
        )

        train_pre_reward.append(sum(train_pre_reward_ls) / TASK_NUM)
        train_post_reward.append(sum(train_post_reward_ls) / TASK_NUM)
        test_pre_reward.append(sum(test_pre_reward_ls) / TASK_NUM)
        test_post_reward.append(sum(test_post_reward_ls) / TASK_NUM)

        print('Train_iters', i)
        print('train_pre_reward', sum(train_pre_reward_ls) / TASK_NUM)
        print('train_post_reward', sum(train_post_reward_ls) / TASK_NUM)
        print('test_pre_reward', sum(test_pre_reward_ls) / TASK_NUM)
        print('test_post_reward', sum(test_post_reward_ls) / TASK_NUM)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Reinforcement learning with Model-Agnostic Meta-Learning (MAML) - Train'
    )
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    args = parser.parse_args()
    main(args)
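
For context on running the new example (stated as assumptions, since the surrounding files are not part of this diff): the `TabularMDP-v0` environment registration and `CategoricalMLPPolicy` appear to come from the `helpers` package imported at the top, so the script would be launched from `examples/MAML-RL/`, e.g. `python func_maml.py --seed 1`; `--seed` is its only command-line option.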

examples/MAML-RL/maml.py

Lines changed: 3 additions & 2 deletions
@@ -99,8 +99,9 @@ def a2c_loss(traj, policy, value_coef):
     advs = lambda_returns - torch.squeeze(values, -1)
     action_loss = -(advs.detach() * log_probs).mean()
     value_loss = advs.pow(2).mean()
-    a2c_loss = action_loss + value_coef * value_loss
-    return a2c_loss
+
+    loss = action_loss + value_coef * value_loss
+    return loss
 
 
 def evaluate(env, seed, task_num, policy):

setup.py

Lines changed: 31 additions & 9 deletions
@@ -14,9 +14,13 @@
 from setuptools.command.build_ext import build_ext
 
 HERE = pathlib.Path(__file__).absolute().parent
+VERSION_FILE = HERE / 'torchopt' / 'version.py'
 
-sys.path.insert(0, str(HERE / 'torchopt'))
-import version  # noqa
+try:
+    from torchopt import version  # noqa
+except ImportError:
+    sys.path.insert(0, str(VERSION_FILE.parent))
+    import version  # noqa
 
 
 class CMakeExtension(Extension):
@@ -81,10 +85,28 @@ def build_extension(self, ext):
         os.chdir(HERE)
 
 
-setup(
-    version=version.__version__,
-    package_data={'sharedlib': ['*.so', '*.pyd']},
-    include_package_data=True,
-    cmdclass={'build_ext': cmake_build_ext},
-    ext_modules=[CMakeExtension('torchopt._C', source_dir=HERE)],
-)
+VERSION_CONTENT = None
+if not version.__release__:
+    import re
+
+    VERSION_CONTENT = VERSION_FILE.read_text(encoding='UTF-8')
+    VERSION_FILE.write_text(
+        data=re.sub(
+            r"""__version__\s*=\s*('[^']+'|"[^"]+")""",
+            r"__version__ = '{}'".format(version.__version__),
+            string=VERSION_CONTENT,
+        ),
+        encoding='UTF-8',
+    )
+
+try:
+    setup(
+        version=version.__version__,
+        package_data={'sharedlib': ['*.so', '*.pyd']},
+        include_package_data=True,
+        cmdclass={'build_ext': cmake_build_ext},
+        ext_modules=[CMakeExtension('torchopt._C', source_dir=HERE)],
+    )
+finally:
+    if VERSION_CONTENT is not None:
+        VERSION_FILE.write_text(data=VERSION_CONTENT, encoding='UTF-8')
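
For readers skimming the new `setup.py` logic: on a non-release build it temporarily rewrites the `__version__` assignment inside `torchopt/version.py`, runs `setup()`, and then restores the original file contents in the `finally` block. A minimal, self-contained illustration of just the substitution step (the file contents and the `0.4.1.dev0` string below are made-up placeholders, not values from this commit):

```python
import re

# Hypothetical contents of torchopt/version.py before the build.
content = "__version__ = '0.4.1'\n__release__ = False\n"

# Same pattern as in setup.py: replace whatever string __version__ currently holds.
patched = re.sub(
    r"""__version__\s*=\s*('[^']+'|"[^"]+")""",
    "__version__ = '0.4.1.dev0'",  # stand-in for version.__version__
    content,
)
print(patched)
# __version__ = '0.4.1.dev0'
# __release__ = False
```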

tests/test_optimizer.py

Lines changed: 55 additions & 1 deletion
@@ -13,8 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 
-from typing import Tuple
+from typing import Callable, Tuple
 
+import functorch
 import pytest
 import torch
 import torch.nn.functional as F
@@ -364,3 +365,56 @@ def test_RMSProp(
         optim_ref.step()
 
     helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype)
+
+
+@helpers.parametrize(
+    dtype=[torch.float64, torch.float32],
+    lr=[1e-2, 1e-3],
+    optimizers=[
+        (torchopt.sgd, torch.optim.SGD),
+        (torchopt.adam, torch.optim.Adam),
+        (torchopt.adamw, torch.optim.AdamW),
+        (torchopt.rmsprop, torch.optim.RMSprop),
+    ],
+    inplace=[True, False],
+    weight_decay=[0.0, 1e-2],
+)
+def test_FuncOptimizer(
+    dtype: torch.dtype,
+    lr: float,
+    optimizers: Tuple[Callable, torch.optim.Optimizer],
+    inplace: bool,
+    weight_decay: float,
+) -> None:
+    model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)
+
+    torchopt_optimizer, torch_optimizer = optimizers
+
+    fmodel, params, buffers = functorch.make_functional_with_buffers(model)
+    optim = torchopt.FuncOptimizer(
+        torchopt_optimizer(
+            lr=lr,
+            weight_decay=weight_decay,
+        ),
+        inplace=inplace,
+    )
+    optim_ref = torch_optimizer(
+        model_ref.parameters(),
+        lr,
+        weight_decay=weight_decay,
+    )
+
+    for xs, ys in loader:
+        xs = xs.to(dtype=dtype)
+        pred = fmodel(params, buffers, xs)
+        pred_ref = model_ref(xs)
+        loss = F.cross_entropy(pred, ys)
+        loss_ref = F.cross_entropy(pred_ref, ys)
+
+        params = optim.step(loss, params)
+
+        optim_ref.zero_grad()
+        loss_ref.backward()
+        optim_ref.step()
+
+    helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype)
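
Assuming the repository's standard pytest setup (with the `helpers` module on the test path and `torchopt` built), the new parametrized case can be run on its own with `pytest tests/test_optimizer.py -k test_FuncOptimizer`.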
