fix(diff/implicit): fix memory leak of OOP APIs #113


Merged · 11 commits · Jan 15, 2023

2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -54,7 +54,7 @@ repos:
           ^setup.py$
       )
   - repo: https://github.com/pycqa/pydocstyle
-    rev: 6.2.2
+    rev: 6.2.3
     hooks:
       - id: pydocstyle
         additional_dependencies: ['.[toml]']
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -21,7 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Fixed

+- Fix memory leak in implicit MAML omniglot few-shot classification example with OOP APIs by [@XuehaiPan](https://github.com/XuehaiPan) in [#113](https://github.com/metaopt/torchopt/pull/113).

 ### Removed

1 change: 1 addition & 0 deletions docs/source/spelling_wordlist.txt
@@ -145,3 +145,4 @@ jvp
 ATen
 samplable
 conj
+reparameterize
4 changes: 1 addition & 3 deletions examples/FuncTorch/maml_omniglot_vmap.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -196,7 +196,6 @@ def train(db, net, device, meta_opt, epoch, log):
         qry_accs = 100.0 * torch.mean(torch.stack(qry_accs)).item()
         i = epoch + float(batch_idx) / n_train_iter
         iter_time = time.time() - start_time
-        torch.cuda.empty_cache()

         if batch_idx % 4 == 0:
             print(
@@ -249,7 +248,6 @@ def test(db, net, device, epoch, log):

     qry_losses = torch.mean(torch.stack(qry_losses)).item()
     qry_accs = 100.0 * torch.mean(torch.stack(qry_accs)).item()
-    torch.cuda.empty_cache()

     print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
     log.append(
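
Aside (a minimal sketch, not code from this PR): the `torch.cuda.empty_cache()` calls removed throughout these examples were masking the leak rather than fixing it. `empty_cache()` only releases cached blocks that are no longer referenced; it cannot reclaim memory that is still reachable from a live autograd graph. The toy loop below illustrates the difference, using a hypothetical stand-in model:

```python
import torch
import torch.nn as nn

model = nn.Linear(1000, 1000)  # hypothetical stand-in for the example's ConvNet

leaky, fixed = [], []
for step in range(10):
    loss = model(torch.randn(64, 1000)).pow(2).mean()

    # Leak: `loss` is still attached to its autograd graph, so every
    # intermediate tensor from this step stays reachable, and
    # torch.cuda.empty_cache() cannot free memory that is still referenced.
    leaky.append(loss)

    # Fix: store a detached Python number so the step's graph can be freed.
    fixed.append(loss.item())
```

Once no dangling graph references remain, the caching allocator reuses memory across iterations on its own, so the explicit `empty_cache()` calls become unnecessary overhead.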
4 changes: 1 addition & 3 deletions examples/distributed/few-shot/maml_omniglot.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -229,7 +229,6 @@ def train(db: OmniglotNShot, net: nn.Module, meta_opt: optim.Adam, epoch: int, l
         qry_acc = 100.0 * qry_acc
         i = epoch + float(batch_idx) / n_train_iter
         iter_time = time.time() - start_time
-        torch.cuda.empty_cache()

         print(
             f'[Epoch {i:.2f}] Train Loss: {qry_loss:.2f} | Acc: {qry_acc:.2f} | Time: {iter_time:.2f}'
@@ -272,7 +271,6 @@ def test(db, net, epoch, log):

     qry_losses = np.mean(qry_losses)
     qry_accs = 100.0 * np.mean(qry_accs)
-    torch.cuda.empty_cache()

     print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
     log.append(
4 changes: 1 addition & 3 deletions examples/distributed/few-shot/maml_omniglot_local_loader.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -272,7 +272,6 @@ def train(net: nn.Module, meta_opt: optim.Adam, epoch: int, log: list):
         qry_acc = 100.0 * qry_acc
         i = epoch + float(batch_idx) / n_train_iter
         iter_time = time.time() - start_time
-        torch.cuda.empty_cache()

         print(
             f'[Epoch {i:.2f}] Train Loss: {qry_loss:.2f} | Acc: {qry_acc:.2f} | Time: {iter_time:.2f}'
@@ -316,7 +315,6 @@ def test(net, epoch, log):

     qry_losses = np.mean(qry_losses)
     qry_accs = 100.0 * np.mean(qry_accs)
-    torch.cuda.empty_cache()

     print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
     log.append(
4 changes: 1 addition & 3 deletions examples/few-shot/maml_omniglot.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -176,7 +176,6 @@ def train(db, net, meta_opt, epoch, log):
         qry_accs = 100.0 * np.mean(qry_accs)
         i = epoch + float(batch_idx) / n_train_iter
         iter_time = time.time() - start_time
-        torch.cuda.empty_cache()

         print(
             f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
@@ -237,7 +236,6 @@ def test(db, net, epoch, log):

     qry_losses = np.mean(qry_losses)
     qry_accs = 100.0 * np.mean(qry_accs)
-    torch.cuda.empty_cache()

     print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
     log.append(
27 changes: 14 additions & 13 deletions examples/iMAML/imaml_omniglot.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -51,6 +51,13 @@ def __init__(self, meta_net, n_inner_iter, reg_param):
         self.net = torchopt.module_clone(meta_net, by='deepcopy', detach_buffers=True)
         self.n_inner_iter = n_inner_iter
         self.reg_param = reg_param
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        with torch.no_grad():
+            for p1, p2 in zip(self.parameters(), self.meta_parameters()):
+                p1.data.copy_(p2.data)
+                p1.detach_().requires_grad_()

     def forward(self, x):
         return self.net(x)
@@ -145,21 +152,16 @@ def main():

 def train(db, net, meta_opt, epoch, log, args):
     n_train_iter = db.x_train.shape[0] // db.batchsz
-    # Given this module we've created, rip out the parameters and buffers
-    # and return a functional version of the module. `fnet` is stateless
-    # and can be called with `fnet(params, buffers, args, kwargs)`
-    # fnet, params, buffers = functorch.make_functional_with_buffers(net)
+    n_inner_iter = args.inner_steps
+    reg_param = args.reg_params
+    task_num = args.task_num

+    inner_nets = [InnerNet(net, n_inner_iter, reg_param) for _ in range(task_num)]
     for batch_idx in range(n_train_iter):
         start_time = time.time()
         # Sample a batch of support and query images and labels.
         x_spt, y_spt, x_qry, y_qry = db.next()

-        task_num = x_spt.size(0)
-
-        n_inner_iter = args.inner_steps
-        reg_param = args.reg_params
-
         qry_losses = []
         qry_accs = []
         meta_opt.zero_grad()
@@ -169,7 +171,8 @@ def train(db, net, meta_opt, epoch, log, args):
             # gradient steps w.r.t. the model's parameters.
             # This adapts the model's meta-parameters to the task.

-            inner_net = InnerNet(net, n_inner_iter, reg_param)
+            inner_net = inner_nets[i]
+            inner_net.reset_parameters()
             optimal_inner_net = inner_net.solve(x_spt[i], y_spt[i])

             # The final set of adapted parameters will induce some
@@ -188,7 +191,6 @@ def train(db, net, meta_opt, epoch, log, args):
         qry_accs = 100.0 * np.mean(qry_accs)
         i = epoch + float(batch_idx) / n_train_iter
         iter_time = time.time() - start_time
-        torch.cuda.empty_cache()

         print(
             f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
@@ -243,7 +245,6 @@ def test(db, net, epoch, log, args):

     qry_losses = np.mean(qry_losses)
     qry_accs = 100.0 * np.mean(qry_accs)
-    torch.cuda.empty_cache()

     print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
     log.append(
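
The core of the fix is visible in the two hunks above: instead of deep-copying the meta-network into a fresh `InnerNet` for every task in every batch (each copy dragging its own autograd history along), the training loop now pre-allocates one `InnerNet` per task slot and re-initializes it in place each batch. A simplified, self-contained sketch of the pattern, in which plain `nn.Module` and `copy.deepcopy` stand in for torchopt's `MetaGradientModule` and `module_clone`:

```python
import copy

import torch
import torch.nn as nn


class InnerNet(nn.Module):
    """Reusable per-task clone of a meta-network (illustrative sketch)."""

    def __init__(self, meta_net: nn.Module):
        super().__init__()
        self.meta_net = meta_net
        self.net = copy.deepcopy(meta_net)  # one clone, reused across batches
        self.reset_parameters()

    def reset_parameters(self):
        with torch.no_grad():
            for p, meta_p in zip(self.net.parameters(), self.meta_net.parameters()):
                p.data.copy_(meta_p.data)     # refill from the meta-parameters
                p.detach_().requires_grad_()  # sever the previous batch's graph

    def forward(self, x):
        return self.net(x)


meta_net = nn.Linear(4, 4)  # hypothetical meta-network
inner_nets = [InnerNet(meta_net) for _ in range(8)]  # allocate once, one per task

for batch in range(100):
    for inner_net in inner_nets:
        inner_net.reset_parameters()  # reset in place instead of re-cloning
        # ... run the inner solve and accumulate the outer loss ...
```

The in-place `detach_().requires_grad_()` is what disconnects the cloned parameters from any graph built in the previous batch, letting that graph and its activations be garbage-collected.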
4 changes: 1 addition & 3 deletions examples/iMAML/imaml_omniglot_functional.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -165,7 +165,6 @@ def train(db, model, meta_opt_and_state, epoch, log, args):
         qry_accs = 100.0 * np.mean(qry_accs)
         i = epoch + float(batch_idx) / n_train_iter
         iter_time = time.time() - start_time
-        torch.cuda.empty_cache()

         print(
             f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
@@ -227,7 +226,6 @@ def test(db, model, epoch, log, args):

     qry_losses = np.mean(qry_losses)
     qry_accs = 100.0 * np.mean(qry_accs)
-    torch.cuda.empty_cache()

     print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
     log.append(
3 changes: 1 addition & 2 deletions torchopt/__init__.py
@@ -32,7 +32,7 @@
 from torchopt.clip import clip_grad_norm
 from torchopt.combine import chain
 from torchopt.hook import register_hook
-from torchopt.optim import SGD, Adam, AdamW, Optimizer, RMSProp, RMSprop, meta
+from torchopt.optim import SGD, Adam, AdamW, Optimizer, RMSProp, RMSprop
 from torchopt.optim.func import FuncOptimizer
 from torchopt.optim.meta import (
     MetaAdam,
@@ -56,7 +56,6 @@

 __all__ = [
     'accelerated_op_available',
-    'diff',
     'adam',
     'adamw',
     'rmsprop',
3 changes: 2 additions & 1 deletion torchopt/diff/implicit/decorator.py
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -289,6 +289,7 @@ def forward( # type: ignore[override] # pylint: disable=arguments-differ
                 f'solver_fn should be a torch.Tensor or a tuple of torch.Tensor. '
                 f'Got {output}'
             )
+        output = tuple(t.data for t in output)

         (
             args_treespec,
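
The library-side fix is the single added line above: each tensor returned by the inner solver is replaced by `t.data`, a view on the same storage that carries no `grad_fn`. A hedged sketch of why this matters inside a custom `torch.autograd.Function` (simplified and illustrative, not torchopt's actual class):

```python
import torch


class ImplicitSolve(torch.autograd.Function):
    """Sketch of an implicit-differentiation Function (illustrative only)."""

    @staticmethod
    def forward(ctx, x):
        with torch.enable_grad():
            # Mimic an inner solver; in a real inner optimization loop this
            # graph of intermediates can be very large.
            params = x.detach().requires_grad_()
            solution = params * 2.0

        # `solution` is attached to the inner graph; saving or returning it
        # as-is keeps that whole graph alive for as long as the output lives.
        # `t.data` is a detached view onto the same storage.
        output = tuple(t.data for t in (solution,))
        ctx.save_for_backward(*output)
        return output[0]

    @staticmethod
    def backward(ctx, grad_output):
        # A real implementation would solve a linear system derived from the
        # optimality conditions here; this sketch just passes gradients through.
        return grad_output


y = ImplicitSolve.apply(torch.ones(3, requires_grad=True))
y.sum().backward()  # uses the custom backward, not the inner solver's graph
```

Because the detached views share storage with the solver's output, the numerical result is unchanged; only the reference to the inner autograd graph is dropped, which is what lets it be freed.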