 __all__ = ['ImplicitMetaGradientModule']


-def make_residual_from_objective(
+def make_optimality_from_objective(
     objective: Callable[..., torch.Tensor]
 ) -> Callable[..., TupleOfTensors]:
-    """Make a function that computes the optimality residual of the objective function."""
+    """Make a function that computes the optimality function of the objective function."""
     # pylint: disable-next=redefined-builtin
-    def residual(self: 'ImplicitMetaGradientModule', *input, **kwargs) -> TupleOfTensors:
+    def optimality(self: 'ImplicitMetaGradientModule', *input, **kwargs) -> TupleOfTensors:
         params_containers = extract_module_containers(self, with_buffers=False)[0]
         params_containers_backups = [container.copy() for container in params_containers]
         flat_params: TupleOfTensors
@@ -69,7 +69,7 @@ def objective_fn(flat_params: TupleOfTensors, *input, **kwargs) -> torch.Tensor:
         flat_grads = objective_grad_fn(flat_params, *input, **kwargs)
         return flat_grads

-    return residual
+    return optimality


 def enable_implicit_gradients(
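When a subclass provides only objective(), the derived optimality() is just the gradient of that objective with respect to the flattened module parameters. A minimal, self-contained sketch of the idea, assuming PyTorch 2.x's torch.func.grad; the helper make_optimality_sketch and the quadratic objective below are illustrative only, not part of the library:

import torch
from torch.func import grad


def make_optimality_sketch(objective_fn):
    # Derive an optimality function as the gradient of a scalar objective:
    # at the inner optimum, d(objective)/d(params) should vanish.
    def optimality(flat_params, *args, **kwargs):
        return grad(objective_fn)(flat_params, *args, **kwargs)

    return optimality


# Toy objective f(x) = ||x - target||^2, whose optimality residual is 2 * (x - target).
target = torch.tensor([1.0, 2.0])
optimality_fn = make_optimality_sketch(lambda x: ((x - target) ** 2).sum())
print(optimality_fn(torch.tensor([1.0, 2.0])))  # tensor([0., 0.]) -- at the optimum
print(optimality_fn(torch.tensor([0.0, 0.0])))  # tensor([-2., -4.]) -- away from it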
@@ -131,7 +131,7 @@ def optimality_fn(
             ):
                 container.update(grad_tracking_container)

-            return self.residual(*input, **kwargs)
+            return self.optimality(*input, **kwargs)
         finally:
             for container, container_backup in itertools.chain(
                 zip(params_containers, params_containers_backups),
@@ -160,28 +160,28 @@ def solve_fn(
 class ImplicitMetaGradientModule(torchopt.nn.MetaGradientModule):
     """The base class for differentiable implicit meta-gradient models."""

-    _custom_residual: bool
+    _custom_optimality: bool
     _custom_objective: bool

     def __init_subclass__(cls) -> None:
         """Initialize the subclass."""
         super().__init_subclass__()

-        residual = getattr(cls, 'residual', ImplicitMetaGradientModule.residual)
+        optimality = getattr(cls, 'optimality', ImplicitMetaGradientModule.optimality)
         objective = getattr(cls, 'objective', ImplicitMetaGradientModule.objective)
-        cls._custom_residual = residual is not ImplicitMetaGradientModule.residual
+        cls._custom_optimality = optimality is not ImplicitMetaGradientModule.optimality
         cls._custom_objective = objective is not ImplicitMetaGradientModule.objective

-        if cls._custom_residual:
-            if isinstance(residual, staticmethod):
-                raise TypeError('residual() must not be a staticmethod.')
-            if isinstance(residual, classmethod):
-                raise TypeError('residual() must not be a classmethod.')
-            if not callable(residual):
-                raise TypeError('residual() must be callable.')
+        if cls._custom_optimality:
+            if isinstance(optimality, staticmethod):
+                raise TypeError('optimality() must not be a staticmethod.')
+            if isinstance(optimality, classmethod):
+                raise TypeError('optimality() must not be a classmethod.')
+            if not callable(optimality):
+                raise TypeError('optimality() must be callable.')
         elif not cls._custom_objective:
             raise TypeError(
-                'ImplicitMetaGradientModule requires either an residual() or an objective() function'
+                'ImplicitMetaGradientModule requires either an optimality() or an objective() function'
             )
         else:
             if isinstance(objective, staticmethod):
@@ -191,7 +191,7 @@ def __init_subclass__(cls) -> None:
             if not callable(objective):
                 raise TypeError('objective() must be callable.')

-            cls.residual = make_residual_from_objective(objective)  # type: ignore[assignment]
+            cls.optimality = make_optimality_from_objective(objective)  # type: ignore[assignment]

         cls.solve = enable_implicit_gradients(cls.solve)  # type: ignore[assignment]

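The __init_subclass__ hook above runs at class-creation time: it records whether the subclass supplied its own optimality() or objective(), rejects staticmethod/classmethod definitions, derives optimality() from objective() when needed, and finally wraps solve() with enable_implicit_gradients(). A small self-contained sketch of the override-detection part, with illustrative names (Base, hook) that are not TorchOpt's:

class Base:
    _custom_hook: bool

    def __init_subclass__(cls) -> None:
        super().__init_subclass__()
        # The subclass overrides hook() iff the resolved attribute differs from Base's.
        cls._custom_hook = getattr(cls, 'hook', Base.hook) is not Base.hook
        # Reject staticmethod/classmethod definitions; the raw class dict still
        # exposes the descriptor, unlike the getattr() lookup above.
        if cls._custom_hook and isinstance(cls.__dict__.get('hook'), (staticmethod, classmethod)):
            raise TypeError('hook() must be a plain instance method.')

    def hook(self):
        raise NotImplementedError


class Child(Base):
    def hook(self):
        return 'custom'


assert Child._custom_hook  # True: Child provided its own hook()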
@@ -228,53 +228,53 @@ def solve(self, batch, labels):
         raise NotImplementedError  # update parameters

     # pylint: disable-next=redefined-builtin
-    def residual(self, *input, **kwargs) -> TensorTree:
+    def optimality(self, *input, **kwargs) -> TensorTree:
         r"""Computes the optimality residual.

-        This method stands for the residual to the optimal parameters after solving the inner
-        optimization problem (:meth:`solve`), i.e.:
+        This method stands for the optimality residual to the optimal parameters after solving the
+        inner optimization problem (:meth:`solve`), i.e.:

         .. code-block:: python

             module.solve(*input, **kwargs)
-            module.residual(*input, **kwargs)  # -> 0
+            module.optimality(*input, **kwargs)  # -> 0

-        1. For gradient-based optimization, the :meth:`residual` function is the KKT condition,
+        1. For gradient-based optimization, the :meth:`optimality` function is the KKT condition,
            usually it is the gradients of the :meth:`objective` function with respect to the module
            parameters (not the meta-parameters). If this method is not implemented, it will be
            automatically derived from the gradient of the :meth:`objective` function.

            .. math::

-               \text{residual} = \nabla_{\boldsymbol{x}} f (\boldsymbol{x}, \boldsymbol{\theta}) \to \boldsymbol{0}
+               \text{optimality residual} = \nabla_{\boldsymbol{x}} f (\boldsymbol{x}, \boldsymbol{\theta}) \to \boldsymbol{0}

            where :math:`\boldsymbol{x}` is the joint vector of the module parameters and
            :math:`\boldsymbol{\theta}` is the joint vector of the meta-parameters.

            References:
                - Karush-Kuhn-Tucker (KKT) conditions: https://en.wikipedia.org/wiki/Karush-Kuhn-Tucker_conditions

-        2. For fixed point iteration, the :meth:`residual` function can be the residual of the
+        2. For fixed point iteration, the :meth:`optimality` function can be the residual of the
            parameters between iterations, i.e.:

            .. math::

-               \text{residual} = f (\boldsymbol{x}, \boldsymbol{\theta}) - \boldsymbol{x} \to \boldsymbol{0}
+               \text{optimality residual} = f (\boldsymbol{x}, \boldsymbol{\theta}) - \boldsymbol{x} \to \boldsymbol{0}

            where :math:`\boldsymbol{x}` is the joint vector of the module parameters and
            :math:`\boldsymbol{\theta}` is the joint vector of the meta-parameters.

         Returns:
-            A tree of tensors, the residual to the optimal parameters after solving the inner
-            optimization problem.
+            A tree of tensors, the optimality residual to the optimal parameters after solving the
+            inner optimization problem.
         """
         raise NotImplementedError

     # pylint: disable-next=redefined-builtin
     def objective(self, *input, **kwargs) -> torch.Tensor:
         """Computes the objective function value.

-        This method is used to calculate the :meth:`residual` if it is not implemented.
+        This method is used to calculate the :meth:`optimality` if it is not implemented.
         Otherwise, this method is optional.

         Returns:
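As a concrete illustration of the fixed-point case described in the optimality() docstring (the residual f(x, theta) - x should vanish at the solution), here is a short self-contained PyTorch sketch; the contraction map and names are illustrative and independent of TorchOpt:

import torch

theta = torch.tensor(3.0)  # meta-parameter


def f(x, theta):
    # A contraction in x; its unique fixed point is x* = 2 * theta.
    return 0.5 * x + theta


# Inner solve: run the fixed-point iteration until convergence.
x = torch.tensor(0.0)
for _ in range(50):
    x = f(x, theta)

optimality_residual = f(x, theta) - x
print(x)                    # ~6.0 (= 2 * theta)
print(optimality_residual)  # ~0.0, as module.optimality(...) should be after solve()

Implicit differentiation relies on exactly this residual: because it remains zero as theta varies, the implicit function theorem yields the derivative of the inner solution with respect to the meta-parameters.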