# Copyright 2022 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import argparse
from typing import NamedTuple

import functorch
import gym
import numpy as np
import torch
import torch.optim as optim

import torchopt
from helpers.policy import CategoricalMLPPolicy

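# Hyper-parameters for the MAML-RL example on the TabularMDP environment.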
TASK_NUM = 40
TRAJ_NUM = 20
TRAJ_LEN = 10

STATE_DIM = 10
ACTION_DIM = 5

GAMMA = 0.99
LAMBDA = 0.95

outer_iters = 500
inner_iters = 1


class Traj(NamedTuple):
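    """A batch of trajectories stored as arrays of shape (TRAJ_LEN, TRAJ_NUM, ...)."""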
    obs: np.ndarray
    acs: np.ndarray
    next_obs: np.ndarray
    rews: np.ndarray
    gammas: np.ndarray


def sample_traj(env, task, fpolicy, params):
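    """Roll out TRAJ_NUM trajectories of length TRAJ_LEN on `task` under the current policy parameters."""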
    env.reset_task(task)
    obs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM, STATE_DIM), dtype=np.float32)
    next_obs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM, STATE_DIM), dtype=np.float32)
    acs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.int8)
    rews_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.float32)
    gammas_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.float32)
    with torch.no_grad():
        for batch in range(TRAJ_NUM):
            ob = env.reset()
            for step in range(TRAJ_LEN):
                ob_tensor = torch.from_numpy(ob)
                pi, _ = fpolicy(params, ob_tensor)
                ac_tensor = pi.sample()
                ac = ac_tensor.cpu().numpy()
                next_ob, rew, done, info = env.step(ac)

                obs_buf[step][batch] = ob
                next_obs_buf[step][batch] = next_ob
                acs_buf[step][batch] = ac
                rews_buf[step][batch] = rew
                # Zero out the discount at episode termination; otherwise discount by GAMMA.
                gammas_buf[step][batch] = (1.0 - done) * GAMMA
                ob = next_ob
    return Traj(obs=obs_buf, acs=acs_buf, next_obs=next_obs_buf, rews=rews_buf, gammas=gammas_buf)


def a2c_loss(traj, fpolicy, params, value_coef):
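    """Compute the A2C loss (policy-gradient term plus value regression) from TD(lambda) returns."""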
    lambdas = np.ones_like(traj.gammas) * LAMBDA
    _, next_values = fpolicy(params, torch.from_numpy(traj.next_obs))
    next_values = torch.squeeze(next_values, -1).detach().numpy()
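    # TD(lambda) recursion: G_t = r_t + gamma_t * ((1 - lambda) * V(s_{t+1}) + lambda * G_{t+1}).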
    # Work backwards to compute `G_{T-1}`, ..., `G_0`.
    returns = []
    g = next_values[-1, :]
    for i in reversed(range(next_values.shape[0])):
        g = traj.rews[i, :] + traj.gammas[i, :] * (
            (1 - lambdas[i, :]) * next_values[i, :] + lambdas[i, :] * g
        )
        returns.insert(0, g)
    lambda_returns = torch.from_numpy(np.array(returns))
    pi, values = fpolicy(params, torch.from_numpy(traj.obs))
    log_probs = pi.log_prob(torch.from_numpy(traj.acs))
    advs = lambda_returns - torch.squeeze(values, -1)
    action_loss = -(advs.detach() * log_probs).mean()
    value_loss = advs.pow(2).mean()

    loss = action_loss + value_coef * value_loss
    return loss


def evaluate(env, seed, task_num, fpolicy, params):
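    """Evaluate pre- and post-adaptation returns on freshly sampled test tasks."""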
    pre_reward_ls = []
    post_reward_ls = []
    inner_opt = torchopt.MetaSGD(lr=0.5)
    env = gym.make(
        'TabularMDP-v0',
        **dict(
            num_states=STATE_DIM, num_actions=ACTION_DIM, max_episode_steps=TRAJ_LEN, seed=seed
        ),
    )
    tasks = env.sample_tasks(num_tasks=task_num)

    for idx in range(task_num):
        # Adapt a fresh copy of the meta-parameters for each task so tasks are evaluated independently.
        _params = list(params)
        for _ in range(inner_iters):
            pre_trajs = sample_traj(env, tasks[idx], fpolicy, _params)
            inner_loss = a2c_loss(pre_trajs, fpolicy, _params, value_coef=0.5)
            _params = inner_opt.step(inner_loss, _params)
        post_trajs = sample_traj(env, tasks[idx], fpolicy, _params)

        # Logging
        pre_reward_ls.append(np.sum(pre_trajs.rews, axis=0).mean())
        post_reward_ls.append(np.sum(post_trajs.rews, axis=0).mean())

    return pre_reward_ls, post_reward_ls


def main(args):
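    """Meta-train the policy with MAML: differentiable inner-loop adaptation and an Adam outer loop."""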
    # Initialize random seeds
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    # Env
    env = gym.make(
        'TabularMDP-v0',
        **dict(
            num_states=STATE_DIM, num_actions=ACTION_DIM, max_episode_steps=TRAJ_LEN, seed=args.seed
        ),
    )
    # Policy
    policy = CategoricalMLPPolicy(input_size=STATE_DIM, output_size=ACTION_DIM)
    fpolicy, params = functorch.make_functional(policy)

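    # Differentiable inner-loop optimizer (MetaSGD) and standard outer-loop optimizer (Adam) over the meta-parameters.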
    inner_opt = torchopt.MetaSGD(lr=0.5)
    outer_opt = optim.Adam(params, lr=1e-3)
    train_pre_reward = []
    train_post_reward = []
    test_pre_reward = []
    test_post_reward = []

    for i in range(outer_iters):
        tasks = env.sample_tasks(num_tasks=TASK_NUM)
        train_pre_reward_ls = []
        train_post_reward_ls = []

        outer_opt.zero_grad()

        for idx in range(TASK_NUM):
            # Start each task's inner-loop adaptation from the current meta-parameters,
            # so the outer-loss gradients of every task accumulate into `params`.
            _params = list(params)
            for _ in range(inner_iters):
                pre_trajs = sample_traj(env, tasks[idx], fpolicy, _params)
                inner_loss = a2c_loss(pre_trajs, fpolicy, _params, value_coef=0.5)
                _params = inner_opt.step(inner_loss, _params)
            post_trajs = sample_traj(env, tasks[idx], fpolicy, _params)
            outer_loss = a2c_loss(post_trajs, fpolicy, _params, value_coef=0.5)
            outer_loss.backward()

            # Logging
            train_pre_reward_ls.append(np.sum(pre_trajs.rews, axis=0).mean())
            train_post_reward_ls.append(np.sum(post_trajs.rews, axis=0).mean())
        outer_opt.step()

        test_pre_reward_ls, test_post_reward_ls = evaluate(
            env, args.seed, TASK_NUM, fpolicy, params
        )

        train_pre_reward.append(sum(train_pre_reward_ls) / TASK_NUM)
        train_post_reward.append(sum(train_post_reward_ls) / TASK_NUM)
        test_pre_reward.append(sum(test_pre_reward_ls) / TASK_NUM)
        test_post_reward.append(sum(test_post_reward_ls) / TASK_NUM)

        print('Train_iters', i)
        print('train_pre_reward', sum(train_pre_reward_ls) / TASK_NUM)
        print('train_post_reward', sum(train_post_reward_ls) / TASK_NUM)
        print('test_pre_reward', sum(test_pre_reward_ls) / TASK_NUM)
        print('test_post_reward', sum(test_post_reward_ls) / TASK_NUM)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Reinforcement learning with Model-Agnostic Meta-Learning (MAML) - Train'
    )
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    args = parser.parse_args()
    main(args)