feat: implicit differentiation integration #41



Merged: 62 commits, merged on Sep 8, 2022

Changes from 1 commit

Commits (62)
6a22b93
feat(third_party/pybind11): switch to 2.9.2
JieRen98 Apr 17, 2022
efc8dd7
Merge branch 'metaopt:main' into dev-implicit
JieRen98 Apr 17, 2022
5cfed82
feat(implicit): add new implicit implementation
JieRen98 May 1, 2022
99a89c5
fix(linear_solver): fix _make_rmatvec
JieRen98 Jul 25, 2022
3155e4a
feat(implicit): allow multi input
JieRen98 Jul 30, 2022
8665895
feat(all): use jax.tree_util.tree_map
JieRen98 Jul 30, 2022
53c31b5
feat(implicit): rename linear_solve.py
JieRen98 Jul 30, 2022
12c30a2
feat(implicit): allow argnums
JieRen98 Aug 9, 2022
f4adb11
fix(implicit): allow has_aux
JieRen98 Aug 9, 2022
76a4a8a
fix(implicit): fix memory leak
JieRen98 Aug 9, 2022
addce03
fix(implicit): add assert
JieRen98 Aug 9, 2022
9de8c4b
fix(implicit): fix cg convergence problem
JieRen98 Aug 12, 2022
1f2bac0
style(implicit): clean codes
JieRen98 Aug 12, 2022
362fbb4
fix(implicit): fix 0-d tensor
JieRen98 Aug 12, 2022
aa29460
fix(implicit): fix 0-d tensor
JieRen98 Aug 12, 2022
309b25d
fix(implicit): fix res is tensor
JieRen98 Aug 13, 2022
ddaa619
fix(implicit): make aux does not require grad
JieRen98 Aug 13, 2022
5af9181
Merge branch 'merge-remote' into dev-implicit
JieRen98 Aug 20, 2022
7d4b1a0
fix(implicit): import implicit functions
JieRen98 Aug 20, 2022
63b5643
style(implicit): reformat
JieRen98 Aug 20, 2022
1114bef
fix(implicit): lint
JieRen98 Aug 20, 2022
abd38fe
Merge branch 'metaopt:main' into dev-implicit
JieRen98 Aug 28, 2022
c055e0d
implicit test
waterhorse1 Aug 29, 2022
074d7b8
docs(implicit): add docs
JieRen98 Aug 29, 2022
9fca772
Merge branch 'metaopt:main' into dev-implicit
JieRen98 Sep 3, 2022
7edfd07
fix(implicit): fix bug
JieRen98 Sep 3, 2022
48f1f88
fix(implicit): fix bug
JieRen98 Sep 3, 2022
f0f3be1
fix(implicit): install jax&jaxopt for test
JieRen98 Sep 4, 2022
060dc5b
fix(implicit): chmod
JieRen98 Sep 4, 2022
33ba3e8
fix(implicit): disable invalid-name
JieRen98 Sep 4, 2022
37ef230
fix(implicit): depend on functorch
JieRen98 Sep 4, 2022
50e5b11
docs(implicit): add description
JieRen98 Sep 4, 2022
24a0e02
fix(implicit): lint bug
JieRen98 Sep 4, 2022
76c6e55
fix(implicit): build bug
JieRen98 Sep 5, 2022
c455ac3
chore: update dependencies
XuehaiPan Sep 5, 2022
2f0ef6c
Merge remote-tracking branch 'upstream/main' into dev-implicit
XuehaiPan Sep 5, 2022
2f9041c
Merge branch 'main' into dev-implicit
XuehaiPan Sep 5, 2022
90f9700
Merge branch 'main' into dev-implicit
XuehaiPan Sep 6, 2022
3185c9d
wip
XuehaiPan Sep 6, 2022
9760fef
wip
XuehaiPan Sep 6, 2022
5fe9129
wip
XuehaiPan Sep 6, 2022
1eace13
wip
XuehaiPan Sep 6, 2022
52a9d60
wip
XuehaiPan Sep 6, 2022
ea24cd1
wip
XuehaiPan Sep 6, 2022
3b00248
wip
XuehaiPan Sep 6, 2022
2e5e882
wip
XuehaiPan Sep 6, 2022
5a26ec2
wip
XuehaiPan Sep 6, 2022
7d79e9e
wip
XuehaiPan Sep 6, 2022
aadfb5d
feat: test other torch specs only once
XuehaiPan Sep 6, 2022
39a1002
fix(implicit): fix scalar bug
JieRen98 Sep 6, 2022
1528f87
fix(implicit): fix scalar bug
JieRen98 Sep 7, 2022
fc56956
fix(implicit): fix solving bug
JieRen98 Sep 7, 2022
75fafcd
fix(implicit): remove unused import
JieRen98 Sep 7, 2022
2dfe754
chore: update type hint
XuehaiPan Sep 7, 2022
3b627f4
fix: update tree inner product
XuehaiPan Sep 7, 2022
028c607
fix: update tree inner product
XuehaiPan Sep 7, 2022
b333d8b
chore: expose pytree
XuehaiPan Sep 7, 2022
02d28b9
style(implicit): change to a compact style
JieRen98 Sep 7, 2022
d8abe9a
new test implicit gradient
waterhorse1 Sep 7, 2022
bde87b1
style(implicit): reformat
JieRen98 Sep 7, 2022
dfd4c30
fix: fix tests
XuehaiPan Sep 8, 2022
ec3b2d3
lint: appease linters
XuehaiPan Sep 8, 2022
fix(implicit): fix memory leak
JieRen98 committed Aug 9, 2022
commit 76a4a8a29751e1ec95a060f868f0e634bb7f81a5
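
The diff below drops the pattern of stashing the forward results and arguments on an ad-hoc `ctx` attribute (`ctx.aux`) and instead passes the tensors through `ctx.save_for_backward`, so autograd can track and release them; holding tensors on `ctx` directly keeps them (and the graph they reference) alive, which is the memory leak this commit fixes. A minimal sketch of the pattern, using a hypothetical `SquareSketch` class rather than the PR's actual code:

```python
# Minimal sketch (not TorchOpt code): save tensors needed by backward via
# ctx.save_for_backward instead of stashing them on a ctx attribute.
import torch
from torch.autograd import Function


class SquareSketch(Function):  # hypothetical example class
    @staticmethod
    def forward(ctx, x):
        result = x * x
        ctx.save_for_backward(x)  # tracked by autograd and freed after backward
        # ctx.x = x               # leak-prone pattern this commit removes
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2.0 * x * grad_output


x = torch.randn(3, requires_grad=True)
SquareSketch.apply(x).sum().backward()
print(x.grad)  # 2 * x
```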
72 changes: 59 additions & 13 deletions TorchOpt/_src/implicit_diff.py
@@ -73,7 +73,7 @@ def __init__(self, argnums, *args) -> None:
         pre_filled = []
         post_filled = []
         for idx, arg in enumerate(args):
-            if idx + 1 in argnums:
+            if idx + 1 in argnums: # plus 1 because we exclude the first argument
                 post_filled.append(arg)
             else:
                 pre_filled.append(arg)
@@ -84,11 +84,13 @@ def __init__(self, argnums, *args) -> None:
     def __call__(self, *args) -> Any:
         args = list(args)
         true_args = []
+        pre_filled_counter = 0
         for idx in range(self.len_args):
-            if idx + 1 in self.argnums:
-                arg = args.pop(0)
+            if idx + 1 in self.argnums: # plus 1 because we exclude the first argument
+                arg = args[idx]
             else:
-                arg = self.pre_filled.pop(0)
+                arg = self.pre_filled[pre_filled_counter]
+                pre_filled_counter += 1
             true_args.append(arg)
         return optimality_fun(sol, *true_args)

@@ -104,8 +106,8 @@ def __call__(self, *args) -> Any:
     result = list(result)
     true_result = [None]
     for idx in range(fun_args.len_args):
-        if idx + 1 in argnums:
-            true_result.append(result.pop(0))
+        if idx + 1 in argnums: # plus 1 because we exclude the first argument
+            true_result.append(result[idx])
         else:
             true_result.append(None)

@@ -159,6 +161,35 @@ def map_back(out_args):
     return out_args, out_kwargs, map_back


+def _split_tensor_and_others(mixed_tuple):
+    flat_tuple, tree = jax.tree_flatten(mixed_tuple)
+    tensor_tuple = []
+    non_tensor_tuple = []
+    tensor_mask = []
+    for item in flat_tuple:
+        if isinstance(item, torch.Tensor):
+            tensor_tuple.append(item)
+            tensor_mask.append(True)
+        else:
+            non_tensor_tuple.append(item)
+            tensor_mask.append(False)
+    return tree, tuple(tensor_mask), tuple(tensor_tuple), tuple(non_tensor_tuple)
+
+
+def _merge_tensor_and_others(tree, tensor_mask, tensor_tuple, non_tensor_tuple):
+    tensor_counter = 0
+    non_tensor_counter = 0
+    result_tuple = []
+    for is_tensor in tensor_mask:
+        if is_tensor:
+            result_tuple.append(tensor_tuple[tensor_counter])
+            tensor_counter += 1
+        else:
+            result_tuple.append(non_tensor_tuple[non_tensor_counter])
+            non_tensor_counter += 1
+    result_tuple = tuple(result_tuple)
+    return tree.unflatten(result_tuple)
+
 def _custom_root(solver_fun, optimality_fun, solve, argnums, has_aux,
                  reference_signature=None):
     # When caling through `jax.custom_vjp`, jax attempts to resolve all
@@ -205,24 +236,39 @@ def forward(ctx, *flat_args):

             args, kwargs = _extract_kwargs(kwarg_keys, args)
             res = solver_fun(*args, **kwargs)
+            args_tree, args_tensor_mask, args_tensor, args_non_tensor = _split_tensor_and_others(args)
+            ctx.args_tree = args_tree
+            ctx.args_tensor_mask = args_tensor_mask
+            ctx.args_non_tensor = args_non_tensor
             if has_aux:
                 aux = res[1]
                 res = res[0]
-                ctx.aux = (res, args)
+                if isinstance(res, torch.Tensor):
+                    ctx.save_for_backward(res, *args_tensor)
+                else:
+                    ctx.save_for_backward(*res, *args_tensor)
                 return *res, aux
             else:
-                ctx.aux = (res, args)
+                if isinstance(res, torch.Tensor):
+                    ctx.save_for_backward(res, *args_tensor)
+                else:
+                    ctx.save_for_backward(*res, *args_tensor)
                 return res

         @staticmethod
         def backward(ctx, *cotangent):
-            res, flat_args = ctx.aux
-            ctx.aux = None
-            args, kwargs = _extract_kwargs(kwarg_keys, flat_args)
-
             # solver_fun can return auxiliary data if has_aux = True.
             if has_aux:
                 cotangent = cotangent[:-1]

+            saved_tensors = ctx.saved_tensors
+            res, args_tensor = saved_tensors[:len(cotangent)], saved_tensors[len(cotangent):]
+            args_tree = ctx.args_tree
+            args_tensor_mask = ctx.args_tensor_mask
+            args_non_tensor = ctx.args_non_tensor
+            args = _merge_tensor_and_others(args_tree, args_tensor_mask, args_tensor, args_non_tensor)
+
+            args, kwargs = _extract_kwargs(kwarg_keys, args)
+
             sol = res
             ba_args, ba_kwargs, map_back = _signature_bind_and_match(
                 reference_signature, *args, **kwargs)
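
For reference, a rough standalone sketch of the split/merge idea introduced above: only tensor leaves go through `ctx.save_for_backward` (which accepts tensors only), non-tensor arguments ride along on `ctx`, and the original argument order is rebuilt in `backward` from a boolean mask. The helper names below are illustrative and handle a flat tuple only; the committed `_split_tensor_and_others`/`_merge_tensor_and_others` additionally flatten nested structures with a pytree utility.

```python
# Illustrative sketch, not the TorchOpt API: split a mixed argument tuple into
# tensor and non-tensor parts, then merge them back in the original order.
from typing import Any, Sequence, Tuple

import torch


def split_tensor_and_others(args: Sequence[Any]):
    mask = tuple(isinstance(a, torch.Tensor) for a in args)
    tensors = tuple(a for a in args if isinstance(a, torch.Tensor))
    others = tuple(a for a in args if not isinstance(a, torch.Tensor))
    return mask, tensors, others


def merge_tensor_and_others(mask: Tuple[bool, ...], tensors, others):
    tensor_iter, other_iter = iter(tensors), iter(others)
    # Walk the mask and pull from the matching pool to restore the order.
    return tuple(next(tensor_iter) if is_tensor else next(other_iter) for is_tensor in mask)


args = (torch.ones(2), 3, 'lr', torch.zeros(1))
mask, tensors, others = split_tensor_and_others(args)
restored = merge_tensor_and_others(mask, tensors, others)
assert all(r is a for r, a in zip(restored, args))  # round trip preserves order
```

In `forward`, the tensor part is what would be passed to `ctx.save_for_backward(*tensors)`, while the mask and non-tensor part are stored as plain attributes, mirroring what the diff does with `args_tensor_mask` and `args_non_tensor`.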