@@ -39,13 +39,13 @@ def __init__(
         self,
         optimality_fn: Callable,
         solution: Any,
-        result_is_tensor: bool,
+        output_is_tensor: bool,
         argnums: Tuple[int, ...],
         *args,
     ) -> None:
         self.optimality_fn = optimality_fn
         self.solution = solution
-        self.result_is_tensor = result_is_tensor
+        self.output_is_tensor = output_is_tensor
         self.argnums = argnums
 
         pre_filled = []
@@ -69,7 +69,7 @@ def __call__(self, *args) -> Any:
                 arg = self.pre_filled[pre_filled_counter]
                 pre_filled_counter += 1
             true_args.append(arg)
-        if self.result_is_tensor:
+        if self.output_is_tensor:
             return self.optimality_fn(self.solution[0], *true_args)
         return self.optimality_fn(self.solution, *true_args)
 
@@ -80,12 +80,12 @@ def _root_vjp(
     solution: Any,
     args: Args,
     grad_outputs: Any,
-    result_is_tensor: bool,
+    output_is_tensor: bool,
     argnums: Tuple[int, ...],
     solve: Callable = linear_solve.solve_normal_cg(),
 ) -> Tuple[Any, ...]:
 
-    if result_is_tensor:
+    if output_is_tensor:
 
         def optimality_cond(solution):
             return optimality_fn(solution[0], *args)
@@ -98,7 +98,7 @@ def optimality_cond(solution):
     _, vjp_optimality_cond, *_ = functorch.vjp(optimality_cond, solution)
 
     # Compute the multiplication A^T u = (u^T A)^T.
-    if result_is_tensor:
+    if output_is_tensor:
 
         def matvec(u):
             return vjp_optimality_cond(u[0])[0]
@@ -115,32 +115,32 @@ def matvec(u):
     u = solve(matvec, v)
 
     masked_optimality_fn = MaskedOptimalityFn(
-        optimality_fn, solution, result_is_tensor, argnums, *args
+        optimality_fn, solution, output_is_tensor, argnums, *args
     )
 
     if getattr(solve, 'is_sdp', False):
-        if result_is_tensor:
-            result = u[0]
+        if output_is_tensor:
+            output = u[0]
         else:
-            result = u
+            output = u
     else:
         _, vjp_optimality_fn, *_ = functorch.vjp(
             masked_optimality_fn, *masked_optimality_fn.post_filled
         )
 
-        if result_is_tensor:
-            result = vjp_optimality_fn(u[0])
+        if output_is_tensor:
+            output = vjp_optimality_fn(u[0])
         else:
-            result = vjp_optimality_fn(u)
+            output = vjp_optimality_fn(u)
 
-    true_result = [None]
+    true_output = [None]
     for idx in range(masked_optimality_fn.len_args):
         if idx + 1 in argnums:  # plus 1 because we exclude the first argument
-            true_result.append(result[idx])
+            true_output.append(output[idx])
         else:
-            true_result.append(None)
+            true_output.append(None)
 
-    return tuple(true_result)
+    return tuple(true_output)
 
 
 def _extract_kwargs(kwarg_keys: Sequence[str], flat_args: Tuple[Any, ...]) -> Tuple[Args, KwArgs]:
@@ -251,6 +251,8 @@ def make_custom_vjp_solver_fn(solver_fn, kwarg_keys, args_sign):
     class ImplicitMetaGradient(Function):
         @staticmethod
         def forward(ctx, *flat_args):  # pylint: disable=arguments-differ
+            output, aux, output_is_tensor = None, None, False
+
             args = []
             for idx, (start_point, is_tuple) in enumerate(args_sign):
                 if is_tuple:
@@ -260,7 +262,23 @@ def forward(ctx, *flat_args):  # pylint: disable=arguments-differ
             args = tuple(args)
 
             args, kwargs = _extract_kwargs(kwarg_keys, args)
-            res = solver_fn(*args, **kwargs)
+            output = solver_fn(*args, **kwargs)
+            if has_aux:
+                if not (isinstance(output, tuple) and len(output) == 2):
+                    raise RuntimeError(
+                        "custom_root(optimality_fn)(solver_fn)(*args): output of function "
+                        "solver_fn should be a tuple: (output, aux) if has_aux is True"
+                    )
+                output, aux = output
+            if isinstance(output, torch.Tensor):
+                output_is_tensor = True
+                output = (output,)
+            elif not (isinstance(output, tuple) and all(map(torch.is_tensor, output))):
+                raise RuntimeError(
+                    "custom_root(optimality_fn)(solver_fn)(*args): output of function "
+                    "solver_fn should be a torch.Tensor or a tuple of torch.Tensor"
+                )
+
             (
                 args_treedef,
                 args_is_tensor_mask,
@@ -270,34 +288,19 @@ def forward(ctx, *flat_args):  # pylint: disable=arguments-differ
             ctx.args_treedef = args_treedef
             ctx.args_is_tensor_mask = args_is_tensor_mask
             ctx.args_non_tensors = args_non_tensors
-            if has_aux:
-                res, aux = res
-                if torch.is_tensor(res):
-                    ctx.save_for_backward(res, *args_tensors)
-                    ctx.result_is_tensor = True
-                    return (res, aux, True, torch.tensor)
-
-                ctx.save_for_backward(*res, *args_tensors)
-                ctx.result_is_tensor = False
-                return (*res, aux, False, type(res))
-
-            if isinstance(res, torch.Tensor):
-                ctx.save_for_backward(res, *args_tensors)
-            else:
-                ctx.save_for_backward(*res, *args_tensors)
-            ctx.result_is_tensor = isinstance(res, torch.Tensor)
-            return res
+
+            ctx.save_for_backward(*output, *args_tensors)
+            ctx.output_is_tensor = output_is_tensor
+
+            return (*output, aux, output_is_tensor, type(output))
 
         @staticmethod
         def backward(ctx, *grad_outputs):  # pylint: disable=too-many-locals
-            if has_aux:
-                grad_outputs = grad_outputs[:-3]
+            grad_outputs = grad_outputs[:-3]
 
             saved_tensors = ctx.saved_tensors
-            res, args_tensors = (
-                saved_tensors[: len(grad_outputs)],
-                saved_tensors[len(grad_outputs) :],
-            )
+            output = saved_tensors[: len(grad_outputs)]
+            args_tensors = saved_tensors[len(grad_outputs) :]
             args_treedef = ctx.args_treedef
             args_is_tensor_mask = ctx.args_is_tensor_mask
             args_non_tensors = ctx.args_non_tensors
@@ -307,7 +310,6 @@ def backward(ctx, *grad_outputs):  # pylint: disable=too-many-locals
 
             args, kwargs = _extract_kwargs(kwarg_keys, args)
 
-            solution = res
             bound_args, bound_kwargs, map_args_back = _signature_bind_and_match(
                 reference_signature, *args, **kwargs  # type: ignore[arg-type]
             )
@@ -323,10 +325,10 @@ def backward(ctx, *grad_outputs):  # pylint: disable=too-many-locals
             # Compute VJPs w.r.t. args.
             vjps = _root_vjp(
                 optimality_fn=optimality_fn,
-                solution=solution,
+                solution=output,
                 args=bound_args[1:],
                 grad_outputs=grad_outputs,
-                result_is_tensor=ctx.result_is_tensor,
+                output_is_tensor=ctx.output_is_tensor,
                 argnums=argnums,
                 solve=solve,
             )
@@ -374,20 +376,21 @@ def wrapped_solver_fn(*args, **kwargs):
         flat_args = tuple(flat_args)
 
         result = make_custom_vjp_solver_fn(solver_fn, keys, args_sign).apply(*flat_args, *vals)
+        *output, aux, output_is_tensor, output_type = result
+        if output_is_tensor:
+            output = output[0]
+        else:
+            output = output_type(output)
         if has_aux:
-            *res, aux, result_is_tensor, res_type = result
-            if result_is_tensor:
-                return res[0], aux
-            res = res_type(res)
-            return res, aux
-        return result
+            return output, aux
+        return output
 
     return wrapped_solver_fn
 
 
 def custom_root(
     optimality_fn: Callable,
-    argnums: Union[int, Tuple[int, ...]] = 0,
+    argnums: Union[int, Tuple[int, ...]],
     has_aux: bool = False,
     solve: Callable = linear_solve.solve_normal_cg(),
 ) -> Callable[[Callable], Callable]:
@@ -417,7 +420,7 @@ def solver_fn(params, arg1, arg2, ...):
         optimality_fn: (callable)
             An equation function, ``optimality_fn(params, *args)``. The invariant is
            ``optimality_fn(solution, *args) == 0`` at the solution / root of ``solution``.
-        argnums: (int or tuple of int, default: :const:`0`)
+        argnums: (int or tuple of ints)
            Specifies arguments to compute gradients with respect to. The ``argnums`` can be an
            integer or a tuple of integers, which respect to the zero-based indices of the arguments
            of the ``solver_fn(params, *args)`` function. The argument ``params`` is included
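A minimal usage sketch of ``custom_root`` as it reads after this change. It is hedged: ``ridge_optimality``, ``solve_ridge``, and the data below are illustrative, and the import path is an assumption; only the ``custom_root(optimality_fn, argnums, ...)`` signature and the ``solver_fn(params, *args)`` / ``optimality_fn(params, *args)`` convention come from the code and docstring above.

import torch

# Assumed import path; adjust to however this module exposes `custom_root`.
from torchopt.diff.implicit import custom_root


def ridge_optimality(params, l2, X, y):
    # Stationarity condition of the ridge objective:
    # X^T (X @ params - y) + l2 * params == 0 exactly at the inner solution.
    return X.T @ (X @ params - y) + l2 * params


# argnums=1 selects `l2`, the second argument of solver_fn (zero-based, with
# `params` counted as argument 0, per the docstring above).
@custom_root(ridge_optimality, argnums=1)
def solve_ridge(params, l2, X, y):
    # Inner solver: closed-form ridge regression. `params` only fixes the
    # calling convention and is unused by this particular solver.
    A = X.T @ X + l2 * torch.eye(X.shape[1])
    return torch.linalg.solve(A, X.T @ y)


X = torch.randn(20, 3)
y = torch.randn(20)
l2 = torch.tensor(1.0, requires_grad=True)

solution = solve_ridge(torch.zeros(3), l2, X, y)
solution.sum().backward()  # d(sum of solution)/d(l2) via the implicit VJP
print(l2.grad)

Per the docstring, differentiating with respect to additional solver arguments would simply mean passing a tuple such as ``argnums=(1, 2, 3)``.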