Commit 13fdeca
Merge branch 'main' into feature/adagrad
2 parents: f095974 + 6223010

28 files changed: 258 additions & 106 deletions
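The hunks rendered below (a subset of the 28 changed files) bump the ruff pre-commit hook from v0.0.256 to v0.0.257 and reformat multi-argument call sites across the example scripts, splitting them to one argument per line and adding trailing commas; the reformatting does not change behavior.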

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ repos:
       - id: clang-format
         stages: [commit, push, manual]
   - repo: https://github.com/charliermarsh/ruff-pre-commit
-    rev: v0.0.256
+    rev: v0.0.257
     hooks:
       - id: ruff
         args: [--fix, --exit-non-zero-on-fix]
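After a rev bump like this, the hooks are typically re-run locally with pre-commit run --all-files (a standard pre-commit command, not something this diff adds) so the updated ruff version re-checks the whole tree.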

examples/FuncTorch/maml_omniglot_vmap.py

Lines changed: 7 additions & 4 deletions
@@ -79,7 +79,10 @@ def main():
     argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15)
     argparser.add_argument('--device', type=str, help='device', default='cuda')
     argparser.add_argument(
-        '--task_num', type=int, help='meta batch size, namely task num', default=32
+        '--task_num',
+        type=int,
+        help='meta batch size, namely task num',
+        default=32,
     )
     argparser.add_argument('--seed', type=int, help='random seed', default=1)
     args = argparser.parse_args()
@@ -199,7 +202,7 @@ def train(db, net, device, meta_opt, epoch, log):

         if batch_idx % 4 == 0:
             print(
-                f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
+                f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}',
             )
         log.append(
             {
@@ -208,7 +211,7 @@ def train(db, net, device, meta_opt, epoch, log):
                 'acc': qry_accs,
                 'mode': 'train',
                 'time': time.time(),
-            }
+            },
         )


@@ -257,7 +260,7 @@ def test(db, net, device, epoch, log):
                 'acc': qry_accs,
                 'mode': 'test',
                 'time': time.time(),
-            }
+            },
         )

examples/FuncTorch/parallel_train_torchopt.py

Lines changed: 6 additions & 2 deletions
@@ -135,7 +135,9 @@ def test_parallel_train_step_fn(self, num_models):
         weights, opt_state = parallel_init_fn(torch.ones(num_models, 1))
         for i in range(2000):
             loss, (weights, opt_states) = parallel_train_step_fn(
-                (weights, opt_state), points, labels
+                (weights, opt_state),
+                points,
+                labels,
             )
             if i % 200 == 0:
                 print(loss)
@@ -186,7 +188,9 @@ def test_parallel_train_step_fn(self, num_models):
         optimizer = torchopt.adam(lr=0.2)
         opt_state = optimizer.init(weights)
         functorch_original = ParallelTrainFunctorchTorchOpt(
-            loss_fn=loss_fn, optimizer=optimizer, device=DEVICE
+            loss_fn=loss_fn,
+            optimizer=optimizer,
+            device=DEVICE,
         )
         # Step 4: Let's verify this actually trains.
         # We should see the loss decrease.
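For context, the second hunk above constructs a TorchOpt functional optimizer and its state. Below is a minimal single-model sketch of that API; the toy data, loss, and parameter are illustrative assumptions, and the actual example vmaps this pattern across several models.

import torch
import torchopt

# Hypothetical toy problem (not from the example file): fit one weight w so that
# w * x matches y, using TorchOpt's functional API (adam constructor, init, update).
w = torch.zeros(1, requires_grad=True)
params = (w,)

optimizer = torchopt.adam(lr=0.2)   # same constructor as in the diff
opt_state = optimizer.init(params)  # optimizer state is explicit, not stored on a module

x, y = torch.tensor([2.0]), torch.tensor([6.0])
for _ in range(100):
    loss = ((params[0] * x - y) ** 2).mean()
    grads = torch.autograd.grad(loss, params)          # gradients as a tuple, like params
    updates, opt_state = optimizer.update(grads, opt_state)
    params = torchopt.apply_updates(params, updates)   # returns the updated parameters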

examples/L2R/helpers/model.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def __init__(self, args):
         )
         self.args = args
         self.meta_weights = torch.zeros(self.args.batch_size, requires_grad=True).to(
-            self.args.device
+            self.args.device,
         )
         self.criterion = nn.BCELoss()

examples/L2R/l2r.py

Lines changed: 5 additions & 2 deletions
@@ -199,8 +199,11 @@ def run_L2R(args, mnist_train, mnist_test):
             running_train_mean = np.mean(np.array(running_train_loss))
             print(
                 'EPOCH: {}, BATCH: {}, WEIGHTED_TRAIN_LOSS: {}, VALID_LOSS: {}'.format(
-                    _epoch, idx, running_train_mean, running_valid_mean
-                )
+                    _epoch,
+                    idx,
+                    running_train_mean,
+                    running_valid_mean,
+                ),
             )
             running_valid_loss = []
             running_train_loss = []

examples/LOLA/helpers/utils.py

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ def dice_objective(self, use_baseline=True):
         if use_baseline:
             # variance_reduction:
             baseline_term = torch.mean(
-                torch.sum((1 - magic_box(stochastic_nodes)) * discounted_values, dim=1)
+                torch.sum((1 - magic_box(stochastic_nodes)) * discounted_values, dim=1),
             )
             dice_objective = dice_objective + baseline_term
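For reference, the magic_box operator used in this hunk is the DiCE estimator's central trick. The definition below is the standard formulation found in DiCE implementations and is assumed to match the helper in this repository.

import torch

def magic_box(x: torch.Tensor) -> torch.Tensor:
    # Evaluates to 1 in the forward pass, but d/dx magic_box(x) = magic_box(x),
    # so score-function gradients flow through the log-probabilities in x.
    return torch.exp(x - x.detach())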

examples/MAML-RL/func_maml.py

Lines changed: 6 additions & 2 deletions
@@ -173,7 +173,11 @@ def main(args):
         outer_opt.step()

         test_pre_reward_ls, test_post_reward_ls = evaluate(
-            env, args.seed, TASK_NUM, fpolicy, params
+            env,
+            args.seed,
+            TASK_NUM,
+            fpolicy,
+            params,
         )

         train_pre_reward.append(sum(train_pre_reward_ls) / TASK_NUM)
@@ -190,7 +194,7 @@ def main(args):

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
-        description='Reinforcement learning with Model-Agnostic Meta-Learning (MAML) - Train'
+        description='Reinforcement learning with Model-Agnostic Meta-Learning (MAML) - Train',
     )
     parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
     args = parser.parse_args()

examples/MAML-RL/helpers/tabular_mdp.py

Lines changed: 11 additions & 4 deletions
@@ -49,7 +49,10 @@ def __init__(self, num_states, num_actions, max_episode_steps, seed, task=None):

         self.action_space = spaces.Discrete(num_actions)
         self.observation_space = spaces.Box(
-            low=0.0, high=1.0, shape=(num_states,), dtype=np.float32
+            low=0.0,
+            high=1.0,
+            shape=(num_states,),
+            dtype=np.float32,
         )

         self._task = task
@@ -62,7 +65,8 @@ def __init__(self, num_states, num_actions, max_episode_steps, seed, task=None):
             ),
         )
         self._rewards_mean = task.get(
-            'rewards_mean', np.zeros((num_states, num_actions), dtype=np.float32)
+            'rewards_mean',
+            np.zeros((num_states, num_actions), dtype=np.float32),
         )
         self._state = 0
         self._elapsed_steps = None
@@ -79,7 +83,9 @@ def sample_tasks(self, num_tasks):
             size=(num_tasks, self.num_states, self.num_actions),
         )
         rewards_mean = self.np_random.normal(
-            1.0, 1.0, size=(num_tasks, self.num_states, self.num_actions)
+            1.0,
+            1.0,
+            size=(num_tasks, self.num_states, self.num_actions),
         )
         tasks = [
             {'transitions': transition, 'rewards_mean': reward_mean}
@@ -106,7 +112,8 @@ def step(self, action):
         reward = self.np_random.normal(mean, 1.0)

         self._state = self.np_random.choice(
-            self.num_states, p=self._transitions[self._state, action]
+            self.num_states,
+            p=self._transitions[self._state, action],
         )
         observation = np.zeros(self.num_states, dtype=np.float32)
         observation[self._state] = 1.0
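As context for these hunks, here is a minimal sketch of the sampling pattern they touch: per-(state, action) reward means drawn from normal(1.0, 1.0), rewards drawn around those means in step(), and the next state drawn with np_random.choice from the per-(state, action) transition row. The concrete sizes and the Dirichlet prior for the transition matrix are illustrative assumptions, not taken from the file.

import numpy as np

rng = np.random.default_rng(1)      # stands in for self.np_random
num_states, num_actions = 10, 5      # illustrative sizes

# Reward means per (state, action), as in sample_tasks above
rewards_mean = rng.normal(1.0, 1.0, size=(num_states, num_actions))
# Next-state distribution per (state, action); the Dirichlet prior is an assumption
transitions = rng.dirichlet(np.ones(num_states), size=(num_states, num_actions))

state, action = 0, 3
reward = rng.normal(rewards_mean[state, action], 1.0)               # as in step()
next_state = rng.choice(num_states, p=transitions[state, action])   # as in step()

# One-hot observation of the new state, matching the last two context lines above
observation = np.zeros(num_states, dtype=np.float32)
observation[next_state] = 1.0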

examples/MAML-RL/maml.py

Lines changed: 1 addition & 1 deletion
@@ -193,7 +193,7 @@ def main(args):

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
-        description='Reinforcement learning with Model-Agnostic Meta-Learning (MAML) - Train'
+        description='Reinforcement learning with Model-Agnostic Meta-Learning (MAML) - Train',
     )
     parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
     args = parser.parse_args()

examples/MAML-RL/maml_torchrl.py

Lines changed: 2 additions & 2 deletions
@@ -234,15 +234,15 @@ def lambda_env():
             f'train_pre_reward: {train_pre_reward[-1]: 4.4f}, '
             f'train_post_reward: {train_post_reward[-1]: 4.4f}, '
             f'test_pre_reward: {test_pre_reward[-1]: 4.4f}, '
-            f'test_post_reward: {test_post_reward[-1]: 4.4f}, '
+            f'test_post_reward: {test_post_reward[-1]: 4.4f}, ',
         )

     env.close()


 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
-        description='Reinforcement learning with Model-Agnostic Meta-Learning (MAML) - Train'
+        description='Reinforcement learning with Model-Agnostic Meta-Learning (MAML) - Train',
     )
     parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
     parser.add_argument('--parallel', action='store_true', help='run envs in parallel')
