Commit 4a520c6

refactor: rename variables
1 parent 429e06f commit 4a520c6

1 file changed: +28 -28 lines changed

torchopt/diff/implicit/nn/module.py

Lines changed: 28 additions & 28 deletions
@@ -31,12 +31,12 @@
 __all__ = ['ImplicitMetaGradientModule']


-def make_residual_from_objective(
+def make_optimality_from_objective(
     objective: Callable[..., torch.Tensor]
 ) -> Callable[..., TupleOfTensors]:
-    """Make a function that computes the optimality residual of the objective function."""
+    """Make a function that computes the optimality function of the objective function."""
     # pylint: disable-next=redefined-builtin
-    def residual(self: 'ImplicitMetaGradientModule', *input, **kwargs) -> TupleOfTensors:
+    def optimality(self: 'ImplicitMetaGradientModule', *input, **kwargs) -> TupleOfTensors:
         params_containers = extract_module_containers(self, with_buffers=False)[0]
         params_containers_backups = [container.copy() for container in params_containers]
         flat_params: TupleOfTensors
@@ -69,7 +69,7 @@ def objective_fn(flat_params: TupleOfTensors, *input, **kwargs) -> torch.Tensor:
         flat_grads = objective_grad_fn(flat_params, *input, **kwargs)
         return flat_grads

-    return residual
+    return optimality


 def enable_implicit_gradients(
@@ -131,7 +131,7 @@ def optimality_fn(
                 ):
                     container.update(grad_tracking_container)

-                return self.residual(*input, **kwargs)
+                return self.optimality(*input, **kwargs)
             finally:
                 for container, container_backup in itertools.chain(
                     zip(params_containers, params_containers_backups),
@@ -160,28 +160,28 @@ def solve_fn(
 class ImplicitMetaGradientModule(torchopt.nn.MetaGradientModule):
     """The base class for differentiable implicit meta-gradient models."""

-    _custom_residual: bool
+    _custom_optimality: bool
     _custom_objective: bool

     def __init_subclass__(cls) -> None:
         """Initialize the subclass."""
         super().__init_subclass__()

-        residual = getattr(cls, 'residual', ImplicitMetaGradientModule.residual)
+        optimality = getattr(cls, 'optimality', ImplicitMetaGradientModule.optimality)
         objective = getattr(cls, 'objective', ImplicitMetaGradientModule.objective)
-        cls._custom_residual = residual is not ImplicitMetaGradientModule.residual
+        cls._custom_optimality = optimality is not ImplicitMetaGradientModule.optimality
         cls._custom_objective = objective is not ImplicitMetaGradientModule.objective

-        if cls._custom_residual:
-            if isinstance(residual, staticmethod):
-                raise TypeError('residual() must not be a staticmethod.')
-            if isinstance(residual, classmethod):
-                raise TypeError('residual() must not be a classmethod.')
-            if not callable(residual):
-                raise TypeError('residual() must be callable.')
+        if cls._custom_optimality:
+            if isinstance(optimality, staticmethod):
+                raise TypeError('optimality() must not be a staticmethod.')
+            if isinstance(optimality, classmethod):
+                raise TypeError('optimality() must not be a classmethod.')
+            if not callable(optimality):
+                raise TypeError('optimality() must be callable.')
         elif not cls._custom_objective:
             raise TypeError(
-                'ImplicitMetaGradientModule requires either an residual() or an objective() function'
+                'ImplicitMetaGradientModule requires either an optimality() or an objective() function'
             )
         else:
             if isinstance(objective, staticmethod):
@@ -191,7 +191,7 @@ def __init_subclass__(cls) -> None:
             if not callable(objective):
                 raise TypeError('objective() must be callable.')

-            cls.residual = make_residual_from_objective(objective)  # type: ignore[assignment]
+            cls.optimality = make_optimality_from_objective(objective)  # type: ignore[assignment]

         cls.solve = enable_implicit_gradients(cls.solve)  # type: ignore[assignment]

@@ -228,53 +228,53 @@ def solve(self, batch, labels):
         raise NotImplementedError  # update parameters

     # pylint: disable-next=redefined-builtin
-    def residual(self, *input, **kwargs) -> TensorTree:
+    def optimality(self, *input, **kwargs) -> TensorTree:
         r"""Computes the optimality residual.

-        This method stands for the residual to the optimal parameters after solving the inner
-        optimization problem (:meth:`solve`), i.e.:
+        This method stands for the optimality residual to the optimal parameters after solving the
+        inner optimization problem (:meth:`solve`), i.e.:

         .. code-block:: python

             module.solve(*input, **kwargs)
-            module.residual(*input, **kwargs)  # -> 0
+            module.optimality(*input, **kwargs)  # -> 0

-        1. For gradient-based optimization, the :meth:`residual` function is the KKT condition,
+        1. For gradient-based optimization, the :meth:`optimality` function is the KKT condition,
         usually it is the gradients of the :meth:`objective` function with respect to the module
         parameters (not the meta-parameters). If this method is not implemented, it will be
         automatically derived from the gradient of the :meth:`objective` function.

         .. math::

-            \text{residual} = \nabla_{\boldsymbol{x}} f (\boldsymbol{x}, \boldsymbol{\theta}) \to \boldsymbol{0}
+            \text{optimality residual} = \nabla_{\boldsymbol{x}} f (\boldsymbol{x}, \boldsymbol{\theta}) \to \boldsymbol{0}

         where :math:`\boldsymbol{x}` is the joint vector of the module parameters and
         :math:`\boldsymbol{\theta}` is the joint vector of the meta-parameters.

         References:
             - Karush-Kuhn-Tucker (KKT) conditions: https://en.wikipedia.org/wiki/Karush-Kuhn-Tucker_conditions

-        2. For fixed point iteration, the :meth:`residual` function can be the residual of the
+        2. For fixed point iteration, the :meth:`optimality` function can be the residual of the
         parameters between iterations, i.e.:

         .. math::

-            \text{residual} = f (\boldsymbol{x}, \boldsymbol{\theta}) - \boldsymbol{x} \to \boldsymbol{0}
+            \text{optimality residual} = f (\boldsymbol{x}, \boldsymbol{\theta}) - \boldsymbol{x} \to \boldsymbol{0}

         where :math:`\boldsymbol{x}` is the joint vector of the module parameters and
         :math:`\boldsymbol{\theta}` is the joint vector of the meta-parameters.

         Returns:
-            A tree of tensors, the residual to the optimal parameters after solving the inner
-            optimization problem.
+            A tree of tensors, the optimality residual to the optimal parameters after solving the
+            inner optimization problem.
         """
         raise NotImplementedError

     # pylint: disable-next=redefined-builtin
     def objective(self, *input, **kwargs) -> torch.Tensor:
         """Computes the objective function value.

-        This method is used to calculate the :meth:`residual` if it is not implemented.
+        This method is used to calculate the :meth:`optimality` if it is not implemented.
         Otherwise, this method is optional.

         Returns:
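
For context on how the renamed API is meant to be used: a subclass that implements only objective() and solve() gets optimality() generated automatically by make_optimality_from_objective(), while a subclass may instead override optimality() directly (for example as a fixed-point residual, the second case in the docstring above). Below is a minimal, hypothetical sketch of the auto-derived case on a toy quadratic problem; it is not part of this commit. The class name InnerQuadratic, the closed-form solve(), and the use of register_meta_parameter() for the meta-parameter are assumptions based on the general MetaGradientModule API rather than code from this diff.

# A minimal, self-contained sketch (not from this commit). It assumes the
# documented MetaGradientModule behavior: nn.Parameter attributes become inner
# parameters and register_meta_parameter() registers meta-parameters.
import torch
import torchopt


class InnerQuadratic(torchopt.nn.ImplicitMetaGradientModule):
    """Hypothetical subclass that only implements objective() and solve().

    optimality() is derived automatically as the gradient of objective()
    with respect to the inner parameter, i.e. roughly x - theta here.
    """

    def __init__(self, theta: torch.Tensor) -> None:
        super().__init__()
        # Inner parameter, optimized by solve().
        self.x = torch.nn.Parameter(torch.zeros_like(theta))
        # Meta-parameter, differentiated through implicitly
        # (register_meta_parameter is assumed from the MetaGradientModule API).
        self.register_meta_parameter('theta', theta)

    def objective(self) -> torch.Tensor:
        # Inner objective: 0.5 * ||x - theta||^2, minimized at x* = theta.
        return 0.5 * torch.sum(torch.square(self.x - self.theta))

    def solve(self) -> 'InnerQuadratic':
        # Closed-form inner solution for this toy problem.
        with torch.no_grad():
            self.x.copy_(self.theta)
        return self


theta = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
module = InnerQuadratic(theta)
module.solve()
print(module.optimality())   # expected to be (approximately) zero at the solution

outer_loss = module.x.sum()  # any outer objective of the inner solution
outer_loss.backward()        # implicit differentiation attaches the meta-gradient
print(theta.grad)            # expected to be close to [1., 1.] since x*(theta) = theta

At the inner solution x* = theta, optimality() returns x* - theta, which goes to zero exactly as the docstring describes; the implicit function theorem then gives dx*/dtheta = I, so under these assumptions the outer loss sum(x*) should produce a gradient of ones with respect to theta.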
