
Commit 089a8cd

wip
1 parent 6604625 commit 089a8cd

8 files changed: +772 −61 lines changed


torchopt/diff/implicit/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -14,4 +14,9 @@
 # ==============================================================================
 """Implicit Meta-Gradient."""
 
+from torchopt.diff.implicit import nn
 from torchopt.diff.implicit.decorator import custom_root
+from torchopt.diff.implicit.nn import ImplicitMetaGradientModule
+
+
+__all__ = ['custom_root', 'ImplicitMetaGradientModule']
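After this change, both the functional and the module-based entry points are importable directly from the subpackage. A minimal sketch (only the names exported by the diff above are used):

# New public surface of `torchopt.diff.implicit` after this commit:
from torchopt.diff.implicit import ImplicitMetaGradientModule, custom_root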

torchopt/diff/implicit/decorator.py

Lines changed: 55 additions & 52 deletions

@@ -39,13 +39,13 @@ def __init__(
         self,
         optimality_fn: Callable,
         solution: Any,
-        result_is_tensor: bool,
+        output_is_tensor: bool,
         argnums: Tuple[int, ...],
         *args,
     ) -> None:
         self.optimality_fn = optimality_fn
         self.solution = solution
-        self.result_is_tensor = result_is_tensor
+        self.output_is_tensor = output_is_tensor
         self.argnums = argnums
 
         pre_filled = []
@@ -69,7 +69,7 @@ def __call__(self, *args) -> Any:
                 arg = self.pre_filled[pre_filled_counter]
                 pre_filled_counter += 1
             true_args.append(arg)
-        if self.result_is_tensor:
+        if self.output_is_tensor:
             return self.optimality_fn(self.solution[0], *true_args)
         return self.optimality_fn(self.solution, *true_args)
 
@@ -80,12 +80,12 @@ def _root_vjp(
     solution: Any,
     args: Args,
     grad_outputs: Any,
-    result_is_tensor: bool,
+    output_is_tensor: bool,
     argnums: Tuple[int, ...],
     solve: Callable = linear_solve.solve_normal_cg(),
 ) -> Tuple[Any, ...]:
 
-    if result_is_tensor:
+    if output_is_tensor:
 
         def optimality_cond(solution):
             return optimality_fn(solution[0], *args)
@@ -98,7 +98,7 @@ def optimality_cond(solution):
     _, vjp_optimality_cond, *_ = functorch.vjp(optimality_cond, solution)
 
     # Compute the multiplication A^T u = (u^T A)^T.
-    if result_is_tensor:
+    if output_is_tensor:
 
         def matvec(u):
             return vjp_optimality_cond(u[0])[0]
@@ -115,32 +115,32 @@ def matvec(u):
     u = solve(matvec, v)
 
     masked_optimality_fn = MaskedOptimalityFn(
-        optimality_fn, solution, result_is_tensor, argnums, *args
+        optimality_fn, solution, output_is_tensor, argnums, *args
     )
 
     if getattr(solve, 'is_sdp', False):
-        if result_is_tensor:
-            result = u[0]
+        if output_is_tensor:
+            output = u[0]
         else:
-            result = u
+            output = u
     else:
         _, vjp_optimality_fn, *_ = functorch.vjp(
             masked_optimality_fn, *masked_optimality_fn.post_filled
         )
 
-        if result_is_tensor:
-            result = vjp_optimality_fn(u[0])
+        if output_is_tensor:
+            output = vjp_optimality_fn(u[0])
         else:
-            result = vjp_optimality_fn(u)
+            output = vjp_optimality_fn(u)
 
-    true_result = [None]
+    true_output = [None]
     for idx in range(masked_optimality_fn.len_args):
         if idx + 1 in argnums:  # plus 1 because we exclude the first argument
-            true_result.append(result[idx])
+            true_output.append(output[idx])
         else:
-            true_result.append(None)
+            true_output.append(None)
 
-    return tuple(true_result)
+    return tuple(true_output)
 
 
 def _extract_kwargs(kwarg_keys: Sequence[str], flat_args: Tuple[Any, ...]) -> Tuple[Args, KwArgs]:
@@ -251,6 +251,8 @@ def make_custom_vjp_solver_fn(solver_fn, kwarg_keys, args_sign):
     class ImplicitMetaGradient(Function):
         @staticmethod
         def forward(ctx, *flat_args):  # pylint: disable=arguments-differ
+            output, aux, output_is_tensor = None, None, False
+
             args = []
             for idx, (start_point, is_tuple) in enumerate(args_sign):
                 if is_tuple:
@@ -260,7 +262,23 @@ def forward(ctx, *flat_args):  # pylint: disable=arguments-differ
             args = tuple(args)
 
             args, kwargs = _extract_kwargs(kwarg_keys, args)
-            res = solver_fn(*args, **kwargs)
+            output = solver_fn(*args, **kwargs)
+            if has_aux:
+                if not (isinstance(output, tuple) and len(output) == 2):
+                    raise RuntimeError(
+                        "custom_root(optimality_fn)(solver_fn)(*args): output of function "
+                        "solver_fn should be a tuple: (output, aux) if has_aux is True"
+                    )
+                output, aux = output
+            if isinstance(output, torch.Tensor):
+                output_is_tensor = True
+                output = (output,)
+            elif not (isinstance(output, tuple) and all(map(torch.is_tensor, output))):
+                raise RuntimeError(
+                    "custom_root(optimality_fn)(solver_fn)(*args): output of function "
+                    "solver_fn should be a torch.Tensor or a tuple of torch.Tensor"
+                )
+
             (
                 args_treedef,
                 args_is_tensor_mask,
@@ -270,34 +288,19 @@ def forward(ctx, *flat_args):  # pylint: disable=arguments-differ
             ctx.args_treedef = args_treedef
             ctx.args_is_tensor_mask = args_is_tensor_mask
             ctx.args_non_tensors = args_non_tensors
-            if has_aux:
-                res, aux = res
-                if torch.is_tensor(res):
-                    ctx.save_for_backward(res, *args_tensors)
-                    ctx.result_is_tensor = True
-                    return (res, aux, True, torch.tensor)
-
-                ctx.save_for_backward(*res, *args_tensors)
-                ctx.result_is_tensor = False
-                return (*res, aux, False, type(res))
-
-            if isinstance(res, torch.Tensor):
-                ctx.save_for_backward(res, *args_tensors)
-            else:
-                ctx.save_for_backward(*res, *args_tensors)
-            ctx.result_is_tensor = isinstance(res, torch.Tensor)
-            return res
+
+            ctx.save_for_backward(*output, *args_tensors)
+            ctx.output_is_tensor = output_is_tensor
+
+            return (*output, aux, output_is_tensor, type(output))
 
         @staticmethod
         def backward(ctx, *grad_outputs):  # pylint: disable=too-many-locals
-            if has_aux:
-                grad_outputs = grad_outputs[:-3]
+            grad_outputs = grad_outputs[:-3]
 
             saved_tensors = ctx.saved_tensors
-            res, args_tensors = (
-                saved_tensors[: len(grad_outputs)],
-                saved_tensors[len(grad_outputs) :],
-            )
+            output = saved_tensors[: len(grad_outputs)]
+            args_tensors = saved_tensors[len(grad_outputs) :]
             args_treedef = ctx.args_treedef
             args_is_tensor_mask = ctx.args_is_tensor_mask
             args_non_tensors = ctx.args_non_tensors
@@ -307,7 +310,6 @@ def backward(ctx, *grad_outputs):  # pylint: disable=too-many-locals
 
             args, kwargs = _extract_kwargs(kwarg_keys, args)
 
-            solution = res
             bound_args, bound_kwargs, map_args_back = _signature_bind_and_match(
                 reference_signature, *args, **kwargs  # type: ignore[arg-type]
             )
@@ -323,10 +325,10 @@
             # Compute VJPs w.r.t. args.
             vjps = _root_vjp(
                 optimality_fn=optimality_fn,
-                solution=solution,
+                solution=output,
                 args=bound_args[1:],
                 grad_outputs=grad_outputs,
-                result_is_tensor=ctx.result_is_tensor,
+                output_is_tensor=ctx.output_is_tensor,
                 argnums=argnums,
                 solve=solve,
             )
@@ -374,20 +376,21 @@ def wrapped_solver_fn(*args, **kwargs):
         flat_args = tuple(flat_args)
 
         result = make_custom_vjp_solver_fn(solver_fn, keys, args_sign).apply(*flat_args, *vals)
+        *output, aux, output_is_tensor, output_type = result
+        if output_is_tensor:
+            output = output[0]
+        else:
+            output = output_type(output)
         if has_aux:
-            *res, aux, result_is_tensor, res_type = result
-            if result_is_tensor:
-                return res[0], aux
-            res = res_type(res)
-            return res, aux
-        return result
+            return output, aux
+        return output
 
     return wrapped_solver_fn
 
 
 def custom_root(
     optimality_fn: Callable,
-    argnums: Union[int, Tuple[int, ...]] = 0,
+    argnums: Union[int, Tuple[int, ...]],
     has_aux: bool = False,
     solve: Callable = linear_solve.solve_normal_cg(),
 ) -> Callable[[Callable], Callable]:
@@ -417,7 +420,7 @@ def solver_fn(params, arg1, arg2, ...):
         optimality_fn: (callable)
             An equation function, ``optimality_fn(params, *args)``. The invariant is
            ``optimality_fn(solution, *args) == 0`` at the solution / root of ``solution``.
-        argnums: (int or tuple of int, default: :const:`0`)
+        argnums: (int or tuple of ints)
             Specifies arguments to compute gradients with respect to. The ``argnums`` can be an
             integer or a tuple of integers, which respect to the zero-based indices of the arguments
             of the ``solver_fn(params, *args)`` function. The argument ``params`` is included
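For orientation: at a root theta* of optimality_fn(theta*, phi) == 0, the implicit function theorem gives d theta*/d phi = -(dF/d theta)^(-1) (dF/d phi), and `_root_vjp` above realizes the corresponding vector-Jacobian product with one linear `solve` plus `functorch.vjp` calls; that is what the generated `backward` relies on. Below is a rough usage sketch of the decorator after this change. The objective, inner solver, and tensor shapes are illustrative assumptions; only `custom_root` itself and the now-required `argnums` argument come from the diff above.

# A rough usage sketch of `custom_root`; the objective, solver, and shapes below are
# illustrative assumptions, not part of this commit.
import functorch
import torch

from torchopt.diff.implicit import custom_root


def inner_objective(params, meta_params, data):
    # Inner-level objective: fit `params` to `data`, regularized toward `meta_params`.
    x, y = data
    return torch.nn.functional.mse_loss(x @ params, y) + 0.5 * ((params - meta_params) ** 2).sum()


# Stationarity of the inner objective is the optimality condition:
# optimality_fn(solution, meta_params, data) == 0 at the inner solution.
optimality_fn = functorch.grad(inner_objective, argnums=0)


# `argnums=1` requests implicit gradients w.r.t. `meta_params` (zero-based, counting
# `params` as argument 0); this commit makes `argnums` a required argument.
@custom_root(optimality_fn, argnums=1)
def inner_solver(params, meta_params, data):
    # Any black-box inner solver works; a few steps of plain gradient descent here.
    for _ in range(100):
        params = params - 0.1 * optimality_fn(params, meta_params, data)
    return params


meta_params = torch.randn(5, requires_grad=True)
data = (torch.randn(32, 5), torch.randn(32))
solution = inner_solver(torch.zeros(5), meta_params, data)
meta_grads = torch.autograd.grad(solution.sum(), meta_params)  # via the implicit function theorem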

torchopt/diff/implicit/nn/__init__.py

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The base class for differentiable implicit meta-gradient models."""
+
+from torchopt.diff.implicit.nn.module import ImplicitMetaGradientModule
+
+
+__all__ = ['ImplicitMetaGradientModule']
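This new subpackage only re-exports `ImplicitMetaGradientModule`; its implementation lives in `torchopt/diff/implicit/nn/module.py`, which is not part of this excerpt. As orientation only, the sketch below shows how such an object-oriented module is typically subclassed in later torchopt releases; the `objective`/`solve` method names and the meta-parameter handling are assumptions, not confirmed by this commit.

# Illustrative only: the base-class contract (objective/solve) is assumed from later
# torchopt releases and is NOT shown in this commit's diff.
import torch

from torchopt.diff.implicit import ImplicitMetaGradientModule


class InnerNet(ImplicitMetaGradientModule):
    def __init__(self, meta_param: torch.Tensor) -> None:
        super().__init__()
        # Inner parameter optimized by `solve`; how meta-parameters are registered is
        # defined in nn/module.py and may differ from this sketch.
        self.param = torch.nn.Parameter(torch.zeros_like(meta_param))
        self.meta_param = meta_param

    def objective(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Inner objective whose stationarity defines the implicit solution.
        residual = torch.nn.functional.mse_loss(x @ self.param, y)
        proximity = 0.5 * ((self.param - self.meta_param) ** 2).sum()
        return residual + proximity

    def solve(self, x: torch.Tensor, y: torch.Tensor) -> 'InnerNet':
        # Any black-box inner solver goes here; meta-gradients are then recovered by the
        # implicit-function-theorem machinery in `custom_root`.
        ...
        return self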
