|
36 | 36 |
|
37 | 37 |
|
38 | 38 | def _stateless_objective_fn(
|
39 |
| - __flat_params: TupleOfTensors, |
40 |
| - __flat_meta_params: TupleOfTensors, |
41 |
| - __params_names: Iterable[str], |
42 |
| - __meta_params_names: Iterable[str], |
| 39 | + flat_params: TupleOfTensors, |
| 40 | + flat_meta_params: TupleOfTensors, |
| 41 | + params_names: Iterable[str], |
| 42 | + meta_params_names: Iterable[str], |
43 | 43 | self: ImplicitMetaGradientModule,
|
| 44 | + /, |
44 | 45 | *input: Any,
|
45 | 46 | **kwargs: Any,
|
46 | 47 | ) -> torch.Tensor:
|
47 | 48 | with reparametrize(
|
48 | 49 | self,
|
49 | 50 | itertools.chain(
|
50 |
| - zip(__params_names, __flat_params), |
51 |
| - zip(__meta_params_names, __flat_meta_params), |
| 51 | + zip(params_names, flat_params), |
| 52 | + zip(meta_params_names, flat_meta_params), |
52 | 53 | ),
|
53 | 54 | ):
|
54 | 55 | return self.objective(*input, **kwargs)
|
55 | 56 |
|
56 | 57 |
|
57 | 58 | def _stateless_optimality_fn(
|
58 |
| - __flat_params: TupleOfTensors, |
59 |
| - __flat_meta_params: TupleOfTensors, |
60 |
| - __params_names: Iterable[str], |
61 |
| - __meta_params_names: Iterable[str], |
| 59 | + flat_params: TupleOfTensors, |
| 60 | + flat_meta_params: TupleOfTensors, |
| 61 | + params_names: Iterable[str], |
| 62 | + meta_params_names: Iterable[str], |
62 | 63 | self: ImplicitMetaGradientModule,
|
| 64 | + /, |
63 | 65 | *input: Any,
|
64 | 66 | **kwargs: Any,
|
65 | 67 | ) -> TupleOfTensors:
|
66 | 68 | with reparametrize(
|
67 | 69 | self,
|
68 | 70 | itertools.chain(
|
69 |
| - zip(__params_names, __flat_params), |
70 |
| - zip(__meta_params_names, __flat_meta_params), |
| 71 | + zip(params_names, flat_params), |
| 72 | + zip(meta_params_names, flat_meta_params), |
71 | 73 | ),
|
72 | 74 | ):
|
73 | 75 | return self.optimality(*input, **kwargs)
|
@@ -121,12 +123,13 @@ def enable_implicit_gradients(
|
121 | 123 | @custom_root(_stateless_optimality_fn, argnums=1, has_aux=True, **solve_kwargs)
|
122 | 124 | def stateless_solver_fn(
|
123 | 125 | # pylint: disable=unused-argument
|
124 |
| - __flat_params: TupleOfTensors, |
125 |
| - __flat_meta_params: TupleOfTensors, |
126 |
| - __params_names: Iterable[str], |
127 |
| - __meta_params_names: Iterable[str], |
| 126 | + flat_params: TupleOfTensors, |
| 127 | + flat_meta_params: TupleOfTensors, |
| 128 | + params_names: Iterable[str], |
| 129 | + meta_params_names: Iterable[str], |
128 | 130 | # pylint: enable=unused-argument
|
129 | 131 | self: ImplicitMetaGradientModule,
|
| 132 | + /, |
130 | 133 | *input: Any,
|
131 | 134 | **kwargs: Any,
|
132 | 135 | ) -> tuple[TupleOfTensors, Any]:
|
|
0 commit comments