From 955446bf41497b284c29f15c5f0ee4b0fba8f966 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 13 Oct 2022 23:46:01 +0800 Subject: [PATCH 01/11] feat(nn.module): add object-oriented modules support for implicit meta-gradient --- torchopt/diff/implicit/__init__.py | 5 + torchopt/diff/implicit/decorator.py | 107 +++--- torchopt/diff/implicit/nn/__init__.py | 20 ++ torchopt/diff/implicit/nn/module.py | 292 ++++++++++++++++ torchopt/nn/__init__.py | 17 + torchopt/nn/module.py | 458 ++++++++++++++++++++++++++ torchopt/optim/meta/base.py | 6 +- torchopt/utils.py | 11 +- 8 files changed, 855 insertions(+), 61 deletions(-) create mode 100644 torchopt/diff/implicit/nn/__init__.py create mode 100644 torchopt/diff/implicit/nn/module.py create mode 100644 torchopt/nn/__init__.py create mode 100644 torchopt/nn/module.py diff --git a/torchopt/diff/implicit/__init__.py b/torchopt/diff/implicit/__init__.py index 2f3f184c..4e50b615 100644 --- a/torchopt/diff/implicit/__init__.py +++ b/torchopt/diff/implicit/__init__.py @@ -14,4 +14,9 @@ # ============================================================================== """Implicit Meta-Gradient.""" +from torchopt.diff.implicit import nn from torchopt.diff.implicit.decorator import custom_root +from torchopt.diff.implicit.nn import ImplicitMetaGradientModule + + +__all__ = ['custom_root', 'ImplicitMetaGradientModule'] diff --git a/torchopt/diff/implicit/decorator.py b/torchopt/diff/implicit/decorator.py index 7fdf58a8..8a5a7f52 100644 --- a/torchopt/diff/implicit/decorator.py +++ b/torchopt/diff/implicit/decorator.py @@ -39,13 +39,13 @@ def __init__( self, optimality_fn: Callable, solution: Any, - result_is_tensor: bool, + output_is_tensor: bool, argnums: Tuple[int, ...], *args, ) -> None: self.optimality_fn = optimality_fn self.solution = solution - self.result_is_tensor = result_is_tensor + self.output_is_tensor = output_is_tensor self.argnums = argnums pre_filled = [] @@ -69,7 +69,7 @@ def __call__(self, *args) -> Any: arg = self.pre_filled[pre_filled_counter] pre_filled_counter += 1 true_args.append(arg) - if self.result_is_tensor: + if self.output_is_tensor: return self.optimality_fn(self.solution[0], *true_args) return self.optimality_fn(self.solution, *true_args) @@ -80,12 +80,12 @@ def _root_vjp( solution: Any, args: Args, grad_outputs: Any, - result_is_tensor: bool, + output_is_tensor: bool, argnums: Tuple[int, ...], solve: Callable = linear_solve.solve_normal_cg(), ) -> Tuple[Any, ...]: - if result_is_tensor: + if output_is_tensor: def optimality_cond(solution): return optimality_fn(solution[0], *args) @@ -98,7 +98,7 @@ def optimality_cond(solution): _, vjp_optimality_cond, *_ = functorch.vjp(optimality_cond, solution) # Compute the multiplication A^T u = (u^T A)^T. 
- if result_is_tensor: + if output_is_tensor: def matvec(u): return vjp_optimality_cond(u[0])[0] @@ -115,32 +115,32 @@ def matvec(u): u = solve(matvec, v) masked_optimality_fn = MaskedOptimalityFn( - optimality_fn, solution, result_is_tensor, argnums, *args + optimality_fn, solution, output_is_tensor, argnums, *args ) if getattr(solve, 'is_sdp', False): - if result_is_tensor: - result = u[0] + if output_is_tensor: + output = u[0] else: - result = u + output = u else: _, vjp_optimality_fn, *_ = functorch.vjp( masked_optimality_fn, *masked_optimality_fn.post_filled ) - if result_is_tensor: - result = vjp_optimality_fn(u[0]) + if output_is_tensor: + output = vjp_optimality_fn(u[0]) else: - result = vjp_optimality_fn(u) + output = vjp_optimality_fn(u) - true_result = [None] + true_output = [None] for idx in range(masked_optimality_fn.len_args): if idx + 1 in argnums: # plus 1 because we exclude the first argument - true_result.append(result[idx]) + true_output.append(output[idx]) else: - true_result.append(None) + true_output.append(None) - return tuple(true_result) + return tuple(true_output) def _extract_kwargs(kwarg_keys: Sequence[str], flat_args: Tuple[Any, ...]) -> Tuple[Args, KwArgs]: @@ -251,6 +251,8 @@ def make_custom_vjp_solver_fn(solver_fn, kwarg_keys, args_sign): class ImplicitMetaGradient(Function): @staticmethod def forward(ctx, *flat_args): # pylint: disable=arguments-differ + output, aux, output_is_tensor = None, None, False + args = [] for idx, (start_point, is_tuple) in enumerate(args_sign): if is_tuple: @@ -260,7 +262,23 @@ def forward(ctx, *flat_args): # pylint: disable=arguments-differ args = tuple(args) args, kwargs = _extract_kwargs(kwarg_keys, args) - res = solver_fn(*args, **kwargs) + output = solver_fn(*args, **kwargs) + if has_aux: + if not (isinstance(output, tuple) and len(output) == 2): + raise RuntimeError( + 'custom_root(optimality_fn)(solver_fn)(*args): output of function ' + 'solver_fn should be a tuple: (output, aux) if has_aux is True' + ) + output, aux = output + if isinstance(output, torch.Tensor): + output_is_tensor = True + output = (output,) + elif not (isinstance(output, tuple) and all(map(torch.is_tensor, output))): + raise RuntimeError( + 'custom_root(optimality_fn)(solver_fn)(*args): output of function ' + 'solver_fn should be a torch.Tensor or a tuple of torch.Tensor' + ) + ( args_treedef, args_is_tensor_mask, @@ -270,34 +288,19 @@ def forward(ctx, *flat_args): # pylint: disable=arguments-differ ctx.args_treedef = args_treedef ctx.args_is_tensor_mask = args_is_tensor_mask ctx.args_non_tensors = args_non_tensors - if has_aux: - res, aux = res - if torch.is_tensor(res): - ctx.save_for_backward(res, *args_tensors) - ctx.result_is_tensor = True - return (res, aux, True, torch.tensor) - - ctx.save_for_backward(*res, *args_tensors) - ctx.result_is_tensor = False - return (*res, aux, False, type(res)) - - if isinstance(res, torch.Tensor): - ctx.save_for_backward(res, *args_tensors) - else: - ctx.save_for_backward(*res, *args_tensors) - ctx.result_is_tensor = isinstance(res, torch.Tensor) - return res + + ctx.save_for_backward(*output, *args_tensors) + ctx.output_is_tensor = output_is_tensor + + return (*output, aux, output_is_tensor, type(output)) @staticmethod def backward(ctx, *grad_outputs): # pylint: disable=too-many-locals - if has_aux: - grad_outputs = grad_outputs[:-3] + grad_outputs = grad_outputs[:-3] saved_tensors = ctx.saved_tensors - res, args_tensors = ( - saved_tensors[: len(grad_outputs)], - saved_tensors[len(grad_outputs) :], - ) + output = 
saved_tensors[: len(grad_outputs)] + args_tensors = saved_tensors[len(grad_outputs) :] args_treedef = ctx.args_treedef args_is_tensor_mask = ctx.args_is_tensor_mask args_non_tensors = ctx.args_non_tensors @@ -307,7 +310,6 @@ def backward(ctx, *grad_outputs): # pylint: disable=too-many-locals args, kwargs = _extract_kwargs(kwarg_keys, args) - solution = res bound_args, bound_kwargs, map_args_back = _signature_bind_and_match( reference_signature, *args, **kwargs # type: ignore[arg-type] ) @@ -323,10 +325,10 @@ def backward(ctx, *grad_outputs): # pylint: disable=too-many-locals # Compute VJPs w.r.t. args. vjps = _root_vjp( optimality_fn=optimality_fn, - solution=solution, + solution=output, args=bound_args[1:], grad_outputs=grad_outputs, - result_is_tensor=ctx.result_is_tensor, + output_is_tensor=ctx.output_is_tensor, argnums=argnums, solve=solve, ) @@ -374,20 +376,21 @@ def wrapped_solver_fn(*args, **kwargs): flat_args = tuple(flat_args) result = make_custom_vjp_solver_fn(solver_fn, keys, args_sign).apply(*flat_args, *vals) + *output, aux, output_is_tensor, output_type = result + if output_is_tensor: + output = output[0] + else: + output = output_type(output) if has_aux: - *res, aux, result_is_tensor, res_type = result - if result_is_tensor: - return res[0], aux - res = res_type(res) - return res, aux - return result + return output, aux + return output return wrapped_solver_fn def custom_root( optimality_fn: Callable, - argnums: Union[int, Tuple[int, ...]] = 0, + argnums: Union[int, Tuple[int, ...]], has_aux: bool = False, solve: Callable = linear_solve.solve_normal_cg(), ) -> Callable[[Callable], Callable]: @@ -417,7 +420,7 @@ def solver_fn(params, arg1, arg2, ...): optimality_fn: (callable) An equation function, ``optimality_fn(params, *args)``. The invariant is ``optimality_fn(solution, *args) == 0`` at the solution / root of ``solution``. - argnums: (int or tuple of int, default: :const:`0`) + argnums: (int or tuple of ints) Specifies arguments to compute gradients with respect to. The ``argnums`` can be an integer or a tuple of integers, which respect to the zero-based indices of the arguments of the ``solver_fn(params, *args)`` function. The argument ``params`` is included diff --git a/torchopt/diff/implicit/nn/__init__.py b/torchopt/diff/implicit/nn/__init__.py new file mode 100644 index 00000000..6d30ffc2 --- /dev/null +++ b/torchopt/diff/implicit/nn/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The base class for differentiable implicit meta-gradient models.""" + +from torchopt.diff.implicit.nn.module import ImplicitMetaGradientModule + + +__all__ = ['ImplicitMetaGradientModule'] diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py new file mode 100644 index 00000000..7cda4c80 --- /dev/null +++ b/torchopt/diff/implicit/nn/module.py @@ -0,0 +1,292 @@ +# Copyright 2022 MetaOPT Team. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The base class for differentiable implicit meta-gradient models.""" + +import functools +import itertools +from typing import Callable, Dict, List, Optional, Tuple + +import functorch +import torch + +import torchopt.nn +from torchopt import pytree +from torchopt.diff.implicit.decorator import custom_root +from torchopt.typing import TensorTree +from torchopt.utils import extract_module_containers + + +__all__ = ['ImplicitMetaGradientModule'] + + +def make_residual_from_objective( + objective: Callable[..., torch.Tensor] +) -> Callable[..., Tuple[torch.Tensor, ...]]: + """Make a function that computes the optimality residual of the objective function.""" + # pylint: disable-next=redefined-builtin + def residual(self: 'ImplicitMetaGradientModule', *input, **kwargs) -> Tuple[torch.Tensor, ...]: + params_containers = extract_module_containers(self, with_buffers=False)[0] + params_containers_backups = [container.copy() for container in params_containers] + flat_params_or_none: List[Optional[torch.Tensor]] + flat_params_or_none, params_containers_treedef = pytree.tree_flatten( + params_containers # type: ignore[arg-type] + ) + + # pylint: disable-next=redefined-builtin + def objective_fn(flat_params: Tuple[torch.Tensor, ...], *input, **kwargs) -> torch.Tensor: + grad_tracking_params_iter = iter(flat_params) + flat_grad_tracking_params_or_none = [ + next(grad_tracking_params_iter) + if isinstance(old_param_or_none, torch.Tensor) + else old_param_or_none + for old_param_or_none in flat_params_or_none + ] + grad_tracking_params_containers: Tuple[Dict[str, Optional[torch.Tensor]], ...] 
+ grad_tracking_params_containers = pytree.tree_unflatten( # type: ignore[assignment] + params_containers_treedef, flat_grad_tracking_params_or_none + ) + + try: + for container, grad_tracking_container in zip( + params_containers, grad_tracking_params_containers + ): + container.update(grad_tracking_container) + + return objective(self, *input, **kwargs) + finally: + for container, container_backup in zip( + params_containers, params_containers_backups + ): + container.update(container_backup) + + objective_grad_fn = functorch.grad(objective_fn, argnums=0) + flat_params = tuple(filter(torch.is_tensor, flat_params_or_none)) + flat_grads = objective_grad_fn(flat_params, *input, **kwargs) + return flat_grads + + return residual + + +def enable_implicit_gradients( + solve: Callable[..., 'ImplicitMetaGradientModule'] +) -> Callable[..., 'ImplicitMetaGradientModule']: + """Enable implicit gradients for the :func:`solve` function.""" + if getattr(solve, '__implicit_gradients_enabled__', False): + raise ValueError('Implicit gradients are already enabled for the solve function.') + + @functools.wraps(solve) + def wrapped( # pylint: disable=too-many-locals + self: 'ImplicitMetaGradientModule', *input, **kwargs # pylint: disable=redefined-builtin + ) -> 'ImplicitMetaGradientModule': + """Solve the optimization problem.""" + params_containers = extract_module_containers(self, with_buffers=False)[0] + meta_params_containers = [self._meta_parameters] # pylint: disable=protected-access + for meta_module in self.meta_children(): + meta_params_containers.extend( + extract_module_containers(meta_module, with_buffers=False)[0] + ) + meta_params_containers = tuple(meta_params_containers) + params_containers_backups = tuple(container.copy() for container in params_containers) + meta_params_containers_backups = tuple( + container.copy() for container in meta_params_containers + ) + + flat_params_or_none: List[Optional[torch.Tensor]] + flat_meta_params_or_none: List[Optional[torch.Tensor]] + flat_params_or_none, params_containers_treedef = pytree.tree_flatten( + params_containers # type: ignore[arg-type] + ) + flat_meta_params_or_none, meta_params_containers_treedef = pytree.tree_flatten( + meta_params_containers # type: ignore[arg-type] + ) + + def optimality_fn( + flat_params: Tuple[torch.Tensor, ...], + flat_meta_params: Tuple[torch.Tensor, ...], + *input, # pylint: disable=redefined-builtin + **kwargs, + ) -> Tuple[torch.Tensor, ...]: + grad_tracking_params_iter = iter(flat_params) + flat_grad_tracking_params_or_none = [ + next(grad_tracking_params_iter) + if isinstance(old_params_or_none, torch.Tensor) + else old_params_or_none + for old_params_or_none in flat_params_or_none + ] + grad_tracking_params_containers: Tuple[Dict[str, Optional[torch.Tensor]], ...] + grad_tracking_params_containers = pytree.tree_unflatten( # type: ignore[assignment] + params_containers_treedef, flat_grad_tracking_params_or_none + ) + grad_tracking_meta_params_iter = iter(flat_meta_params) + flat_grad_tracking_meta_params_or_none = [ + next(grad_tracking_meta_params_iter) + if isinstance(old_meta_param_or_none, torch.Tensor) + else old_meta_param_or_none + for old_meta_param_or_none in flat_meta_params_or_none + ] + grad_tracking_meta_params_containers: Tuple[Dict[str, Optional[torch.Tensor]], ...] 
+ grad_tracking_meta_params_containers = pytree.tree_unflatten( # type: ignore[assignment] + meta_params_containers_treedef, flat_grad_tracking_meta_params_or_none + ) + + try: + for container, grad_tracking_container in itertools.chain( + zip(params_containers, grad_tracking_params_containers), + zip(meta_params_containers, grad_tracking_meta_params_containers), + ): + container.update(grad_tracking_container) + + return self.residual(*input, **kwargs) # type: ignore[return-value] + finally: + for container, container_backup in itertools.chain( + zip(params_containers, params_containers_backups), + zip(meta_params_containers, meta_params_containers_backups), + ): + container.update(container_backup) + + @custom_root(optimality_fn, argnums=1) + def solve_fn( + flat_params: Tuple[torch.Tensor, ...], # pylint: disable=unused-argument + flat_meta_params: Tuple[torch.Tensor, ...], # pylint: disable=unused-argument + *input, # pylint: disable=redefined-builtin + **kwargs, + ) -> Tuple[torch.Tensor, ...]: + solve(self, *input, **kwargs) + return tuple(filter(torch.is_tensor, pytree.tree_leaves(params_containers))) # type: ignore[arg-type] + + flat_params = tuple(filter(torch.is_tensor, flat_params_or_none)) + flat_meta_params = tuple(filter(torch.is_tensor, flat_meta_params_or_none)) + # pylint: disable-next=unused-variable + flat_optimal_params = solve_fn(flat_params, flat_meta_params, *input, **kwargs) + return self + + wrapped.__implicit_gradients_enabled__ = True # type: ignore[attr-defined] + return wrapped + + +class ImplicitMetaGradientModule(torchopt.nn.MetaGradientModule): + """The base class for differentiable implicit meta-gradient models.""" + + _custom_residual: bool + _custom_objective: bool + + def __init_subclass__(cls) -> None: + """Initialize the subclass.""" + super().__init_subclass__() + + residual = getattr(cls, 'residual', ImplicitMetaGradientModule.residual) + objective = getattr(cls, 'objective', ImplicitMetaGradientModule.objective) + cls._custom_residual = residual is not ImplicitMetaGradientModule.residual + cls._custom_objective = objective is not ImplicitMetaGradientModule.objective + + if cls._custom_residual: + if isinstance(residual, staticmethod): + raise TypeError('residual() must not be a staticmethod.') + if isinstance(residual, classmethod): + raise TypeError('residual() must not be a classmethod.') + if not callable(residual): + raise TypeError('residual() must be callable.') + elif not cls._custom_objective: + raise TypeError( + 'ImplicitMetaGradientModule requires either an residual() or an objective() function' + ) + else: + if isinstance(objective, staticmethod): + raise TypeError('objective() must not be a staticmethod.') + if isinstance(objective, classmethod): + raise TypeError('objective() must not be a classmethod.') + if not callable(objective): + raise TypeError('objective() must be callable.') + + cls.residual = make_residual_from_objective(objective) # type: ignore[assignment] + + cls.solve = enable_implicit_gradients(cls.solve) # type: ignore[assignment] + + # pylint: disable-next=redefined-builtin + def solve(self, *input, **kwargs) -> 'ImplicitMetaGradientModule': + """Solves the inner optimization problem. + + .. warning:: + + For gradient-based optimization methods, the parameter inputs should be explicitly + specified in the :func:`torch.autograd.backward` function as argument ``inputs``. 
+ Otherwise, if not provided, the gradient is accumulated into all the leaf Tensors + (including the meta-parameters) that were used to compute the objective output. + Alternatively, please use :func:`torch.autograd.grad` instead. + + Example:: + + def solve(self, batch, labels): + parameters = tuple(self.parameters()) + optimizer = torch.optim.Adam(parameters, lr=1e-3) + with torch.enable_grad(): + for _ in range(100): + loss = self.objective(batch, labels) + optimizer.zero_grad() + # Only update the `.grad` attribute for parameters + # and leave the meta-parameters unchanged + loss.backward(inputs=parameters) + optimizer.step() + return self + """ + raise NotImplementedError # update parameters + + # pylint: disable-next=redefined-builtin + def residual(self, *input, **kwargs) -> 'TensorTree': + r"""Computes the optimality residual. + + This method stands for the residual to the optimal parameters after solving the inner + optimization problem (:meth:`solve`), i.e.: + + .. code-block:: python + + module.solve(*input, **kwargs) + module.residual(*input, **kwargs) # -> 0 + + 1. For gradient-based optimization, the :meth:`residual` is the KKT condition, usually the + gradients of the :meth`objective` function with respect to the module parameters (not the + meta-parameters). If this method is not implemented, it will be automatically calculated + from the gradient of the :meth:`objective` function. + + .. math:: + + \text{residual} = \nabla_{\boldsymbol{x}} f (\boldsymbol{x}, \boldsymbol{\theta}) \to \boldsymbol{0} + + where :math:`\boldsymbol{x}` is the joint vector of the module parameters and + :math:`\boldsymbol{\theta}` is the joint vector of the meta-parameters. + + References: + - Karush-Kuhn-Tucker (KKT) conditions: https://en.wikipedia.org/wiki/Karush-Kuhn-Tucker_conditions + + 2. For fixed point iteration, the :meth:`residual` can be the residual of the + parameters between iterations, i.e.: + + .. math:: + + \text{residual} = f (\boldsymbol{x}, \boldsymbol{\theta}) - \boldsymbol{x} \to \boldsymbol{0} + + where :math:`\boldsymbol{x}` is the joint vector of the module parameters and + :math:`\boldsymbol{\theta}` is the joint vector of the meta-parameters. + """ + raise NotImplementedError + + # pylint: disable-next=redefined-builtin + def objective(self, *input, **kwargs) -> torch.Tensor: + """Computes the objective function value. + + This method is used to calculate the :meth:`residual` if it is not implemented. + Otherwise, this method is optional. + """ + raise NotImplementedError diff --git a/torchopt/nn/__init__.py b/torchopt/nn/__init__.py new file mode 100644 index 00000000..3cb7fc78 --- /dev/null +++ b/torchopt/nn/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Base class for neural network modules that hold meta-parameters and meta-modules.""" + +from torchopt.nn.module import MetaGradientModule diff --git a/torchopt/nn/module.py b/torchopt/nn/module.py new file mode 100644 index 00000000..982fd57c --- /dev/null +++ b/torchopt/nn/module.py @@ -0,0 +1,458 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for neural network modules that hold meta-parameters and meta-modules.""" + +from collections import OrderedDict +from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Set, Tuple, Union + +import torch +import torch.nn as nn + +from torchopt import pytree + + +class MetaInputsContainer(NamedTuple): + """Container for parameters and modules in the constructor input arguments.""" + + meta_parameters: Set[torch.Tensor] + meta_modules: Set[nn.Module] + + +class MetaGradientModule(nn.Module): # pylint: disable=abstract-method + """Base class for neural network modules that hold meta-parameters and meta-modules.""" + + _meta_inputs: MetaInputsContainer + _meta_parameters: Dict[str, Optional[torch.Tensor]] + _meta_modules: Dict[str, Optional[nn.Module]] + + def __new__(cls, *args, **kwargs) -> 'MetaGradientModule': + """Creates a new module instance.""" + instance = super().__new__(cls) + flat_args: List[Any] + flat_args = pytree.tree_leaves((args, kwargs)) # type: ignore[arg-type] + meta_parameters = set( + x for x in flat_args if isinstance(x, torch.Tensor) and x.requires_grad + ) + meta_modules = set(x for x in flat_args if isinstance(x, nn.Module) and x.training) + for meta_module in tuple(meta_modules): + meta_parameters.update(meta_module.parameters()) + meta_modules.update(meta_module.modules()) + + instance._meta_inputs = MetaInputsContainer(meta_parameters, meta_modules) + instance._meta_parameters: Dict[str, Optional[torch.Tensor]] = OrderedDict() # type: ignore[misc] + instance._meta_modules: Dict[str, Optional[nn.Module]] = OrderedDict() # type: ignore[misc] + return instance + + def __getattr__(self, name: str) -> Union[torch.Tensor, nn.Module]: + """Gets an attribute of the module.""" + if '_parameters' in self.__dict__: + _parameters = self.__dict__['_parameters'] + if name in _parameters: + return _parameters[name] + if '_buffers' in self.__dict__: + _buffers = self.__dict__['_buffers'] + if name in _buffers: + return _buffers[name] + if '_modules' in self.__dict__: + modules = self.__dict__['_modules'] + if name in modules: + return modules[name] + if '_meta_parameters' in self.__dict__: + _meta_parameters = self.__dict__['_meta_parameters'] + if name in _meta_parameters: + return _meta_parameters[name] + if '_meta_modules' in self.__dict__: + _meta_modules = self.__dict__['_meta_modules'] + if name in _meta_modules: + return _meta_modules[name] + raise AttributeError(f"'{type(self).__name__}' 
object has no attribute '{name}'") + + # pylint: disable-next=too-many-branches,too-many-statements + def __setattr__(self, name: str, value: Union[torch.Tensor, nn.Module]) -> None: + """Sets an attribute of the module.""" + + def remove_from(*dicts_or_sets): + for dict_or_set in dicts_or_sets: + if name in dict_or_set: + if isinstance(dict_or_set, dict): + del dict_or_set[name] + else: + dict_or_set.discard(name) + + params = self.__dict__.get('_parameters') + meta_params = self.__dict__.get('_meta_parameters') + if isinstance(value, torch.Tensor) and value.requires_grad: + if params is None: + raise AttributeError('cannot assign parameters before Module.__init__() call') + if meta_params is None: + raise AttributeError( + 'cannot assign meta-parameters before MetaGradientModule.__init__() call' + ) + remove_from( + self.__dict__, + self._buffers, + self._modules, + self._non_persistent_buffers_set, + self._meta_parameters, + self._meta_modules, + ) + if value in self._meta_inputs.meta_parameters: + self.register_meta_parameter(name, value) + else: + self.register_parameter(name, value) + elif params is not None and name in params: + if value is not None: + raise TypeError( + f"cannot assign '{torch.typename(value)}' as parameter '{name}' " + f'(torch.Tensor or None expected)' + ) + self.register_parameter(name, value) # type: ignore[unreachable] + elif meta_params is not None and name in meta_params: + if value is not None: + raise TypeError( + f"cannot assign '{torch.typename(value)}' as meta-parameter '{name}' " + f'(torch.Tensor or None expected)' + ) + self.register_meta_parameter(name, value) # type: ignore[unreachable] + else: + modules = self.__dict__.get('_modules') + meta_modules = self.__dict__.get('_meta_modules') + if isinstance(value, nn.Module): + if modules is None: + raise AttributeError('cannot assign module before Module.__init__() call') + if meta_modules is None: + raise AttributeError( + 'cannot assign module before MetaGradientModule.__init__() call' + ) + remove_from( + self.__dict__, + self._parameters, + self._buffers, + self._non_persistent_buffers_set, + self._meta_parameters, + self._meta_modules, + ) + if value in self._meta_inputs.meta_modules: + meta_modules[name] = value + else: + modules[name] = value + elif modules is not None and name in modules: + if value is not None: + raise TypeError( + f"cannot assign '{torch.typename(value)}' as child module '{name}' " + f'(torch.nn.Module or None expected)' + ) + modules[name] = value # type: ignore[unreachable] + else: + buffers = self.__dict__.get('_buffers') + if buffers is not None and name in buffers: + if value is not None and not isinstance(value, torch.Tensor): + raise TypeError( + f"cannot assign '{torch.typename(value)}' as buffer '{name}' " + f'(torch.Tensor or None expected)' + ) + buffers[name] = value + else: + object.__setattr__(self, name, value) + + def __delattr__(self, name: str) -> None: + """Deletes an attribute of the module.""" + if name in self._parameters: + del self._parameters[name] + elif name in self._buffers: + del self._buffers[name] + self._non_persistent_buffers_set.discard(name) + elif name in self._modules: + del self._modules[name] + elif name in self._meta_parameters: + del self._meta_parameters[name] + elif name in self._meta_modules: + del self._meta_modules[name] + else: + object.__delattr__(self, name) + + def register_parameter(self, name: str, param: Optional[torch.Tensor]) -> None: + r"""Adds a parameter to the module. 
+ + The parameter can be accessed as an attribute using given name. + + Args: + name (string): name of the parameter. The parameter can be accessed + from this module using the given name + param (torch.Tensor or None): parameter to be added to the module. If + ``None``, then operations that run on parameters, such as :attr:`cuda`, + are ignored. If ``None``, the parameter is **not** included in the + module's :attr:`state_dict`. + """ + if '_parameters' not in self.__dict__: + raise AttributeError('cannot assign parameter before Module.__init__() call') + if not isinstance(name, str): + raise TypeError(f'parameter name should be a string. Got {torch.typename(name)}') + if '.' in name: + raise KeyError("parameter name can't contain \".\"") + if name == '': + raise KeyError("parameter name can't be empty string \"\"") + if hasattr(self, name) and name not in self._parameters: + raise KeyError(f"attribute '{name}' already exists") + + if param is None: + self._parameters[name] = None + return + + if not isinstance(param, torch.Tensor): + raise TypeError( + f"cannot assign '{torch.typename(param)}' object to parameter '{name}' " + f'(torch.Tensor or None required)' + ) + if not param.requires_grad: + raise ValueError( + f"cannot assign Tensor that `requires_grad=False` to parameter '{name}'" + ) + if param in self._meta_inputs.meta_parameters: + raise ValueError( + f"cannot assign Tensor that is a meta-parameter to parameter '{name}'. " + f'Use self.register_meta_parameter() instead.' + ) + + self._parameters[name] = param # type: ignore + + def register_meta_parameter(self, name: str, param: Optional[torch.Tensor]) -> None: + r"""Adds a meta-parameter to the module. + + The meta-parameter can be accessed as an attribute using given name. + + Args: + name (string): name of the parameter. The parameter can be accessed + from this module using the given name + param (torch.Tensor or None): parameter to be added to the module. If + ``None``, then operations that run on parameters, such as :attr:`cuda`, + are ignored. If ``None``, the parameter is **not** included in the + module's :attr:`state_dict`. + """ + if '_meta_parameters' not in self.__dict__: + raise AttributeError( + 'cannot assign meta-parameter before MetaGradientModule.__init__() call' + ) + if not isinstance(name, str): + raise TypeError(f'meta-parameter name should be a string. Got {torch.typename(name)}') + if '.' in name: + raise KeyError("meta-parameter name can't contain \".\"") + if name == '': + raise KeyError("meta-parameter name can't be empty string \"\"") + if hasattr(self, name) and name not in self._meta_parameters: + raise KeyError(f"attribute '{name}' already exists") + + if param is None: + self._meta_parameters[name] = None + return + + if not isinstance(param, torch.Tensor): + raise TypeError( + f"cannot assign '{torch.typename(param)}' object to meta-parameter '{name}' " + f'(torch.Tensor or None required)' + ) + if not param.requires_grad: + raise ValueError( + f"cannot assign Tensor that `requires_grad=False` to meta-parameter '{name}'" + ) + + self._meta_parameters[name] = param + + def add_module(self, name: str, module: Optional[nn.Module]) -> None: + r"""Adds a child module to the current module. + + The module can be accessed as an attribute using the given name. + + Args: + name (string): name of the child module. The child module can be + accessed from this module using the given name + module (Module): child module to be added to the module. 
+ """ + if not isinstance(module, nn.Module) and module is not None: + raise TypeError(f'{torch.typename(module)} is not a Module subclass') + if not isinstance(name, str): + raise TypeError(f'module name should be a string. Got {torch.typename(name)}') + if hasattr(self, name) and name not in self._modules: + raise KeyError(f"attribute '{name}' already exists") + if '.' in name: + raise KeyError(f"module name can't contain \".\", got: {name}") + if name == '': + raise KeyError("module name can't be empty string \"\"") + if module in self._meta_inputs.meta_modules: + raise ValueError( + f"cannot add module that is a meta-module to module '{name}'. " + f'Use self.add_meta_module() instead.' + ) + + self._modules[name] = module + + def register_module(self, name: str, module: Optional[nn.Module]) -> None: + r"""Alias for :func:`add_module`.""" + self.add_module(name, module) + + def add_meta_module(self, name: str, meta_module: Optional[nn.Module]) -> None: + r"""Adds a child meta-module to the current module. + + The meta-module can be accessed as an attribute using the given name. + + Args: + name (string): name of the child meta-module. The child meta-module can be + accessed from this module using the given name + meta_module (Module): child meta-module to be added to the module. + """ + if not isinstance(meta_module, nn.Module) and meta_module is not None: + raise TypeError(f'{torch.typename(meta_module)} is not a Module subclass') + if not isinstance(name, str): + raise TypeError(f'meta-module name should be a string. Got {torch.typename(name)}') + if hasattr(self, name) and name not in self._meta_modules: + raise KeyError(f"attribute '{name}' already exists") + if '.' in name: + raise KeyError(f"meta-module name can't contain \".\", got: {name}") + if name == '': + raise KeyError("meta-module name can't be empty string \"\"") + + self._meta_modules[name] = meta_module + + def register_meta_module(self, name: str, meta_module: Optional[nn.Module]) -> None: + r"""Alias for :func:`add_meta_module`.""" + self.add_meta_module(name, meta_module) + + def meta_parameters(self, recurse: bool = True) -> Iterator[torch.Tensor]: + r"""Returns an iterator over module meta-parameters. + + This is typically passed to an optimizer. + + Args: + recurse (bool): if True, then yields parameters of this module + and all submodules. Otherwise, yields only meta-parameters that + are direct members of this module. + + Yields: + Parameter: module meta-parameter + + Example:: + + >>> for param in model.meta_parameters(): + >>> print(type(param), param.size()) + (20L,) + (20L, 1L, 5L, 5L) + + """ + for _, meta_param in self.named_meta_parameters(recurse=recurse): + yield meta_param + + def named_meta_parameters( + self, prefix: str = '', recurse: bool = True + ) -> Iterator[Tuple[str, torch.Tensor]]: + r"""Returns an iterator over module meta-parameters, yielding both the name of the meta-parameter as well as the meta-parameter itself. + + Args: + prefix (str): prefix to prepend to all meta-parameter names. + recurse (bool): if True, then yields meta-parameters of this module + and all submodules. Otherwise, yields only meta-parameters that + are direct members of this module. 
+ + Yields: + (string, Parameter): Tuple containing the name and parameter + + Example:: + + >>> for name, meta_param in self.named_meta_parameters(): + >>> if name in ['bias']: + >>> print(meta_param.size()) + + """ # pylint: disable=line-too-long + memo = set() + for name, param in getattr(self, '_meta_parameters', {}).items(): + if param is None or param in memo: + continue + memo.add(param) + yield prefix + name, param + for name, meta_module in getattr(self, '_meta_modules', {}).items(): + if meta_module is None: + continue + submodule_prefix = prefix + name + yield from meta_module.named_parameters(submodule_prefix, recurse) + + def meta_children(self) -> Iterator[nn.Module]: + r"""Returns an iterator over immediate children meta-modules. + + Yields: + Module: a child meta-module + """ + for _, module in self.named_meta_children(): + yield module + + def named_meta_children(self) -> Iterator[Tuple[str, nn.Module]]: + r"""Returns an iterator over immediate children meta-modules, yielding both the name of the meta-module as well as the meta-module itself. + + Yields: + (string, Module): Tuple containing a name and child meta-module + + Example:: + + >>> for name, meta_module in model.named_meta_children(): + >>> if name in ['conv4', 'conv5']: + >>> print(meta_module) + + """ # pylint: disable=line-too-long + memo = set() + for name, meta_module in self._meta_modules.items(): + if meta_module is not None and meta_module not in memo: + memo.add(meta_module) + yield name, meta_module + + def meta_modules(self) -> Iterator[nn.Module]: + r"""Returns an iterator over all meta-modules in the network. + + Yields: + Module: a meta-module in the network + + Note: + Duplicate meta-modules are returned only once. + """ + for _, meta_module in self.named_meta_modules(): + yield meta_module + + def named_meta_modules( + self, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True + ) -> Iterator[Tuple[str, nn.Module]]: + r"""Returns an iterator over all meta-modules in the network, yielding both the name of the meta-module as well as the meta-module itself. + + Args: + memo: a memo to store the set of meta-modules already added to the result + prefix: a prefix that will be added to the name of the meta-module + remove_duplicate: whether to remove the duplicated meta-module instances in the result + or not + + Yields: + (string, Module): Tuple of name and meta-module + + Note: + Duplicate modules are returned only once. + """ # pylint: disable=line-too-long + if memo is None: + memo = set() + if self in memo: + return + + if remove_duplicate: + memo.add(self) + + for name, meta_module in self._meta_modules.items(): + if meta_module is None: + continue + submodule_prefix = prefix + ('.' 
if prefix else '') + name + yield from meta_module.named_modules(memo, submodule_prefix, remove_duplicate) diff --git a/torchopt/optim/meta/base.py b/torchopt/optim/meta/base.py index 668c46a4..25d8c947 100644 --- a/torchopt/optim/meta/base.py +++ b/torchopt/optim/meta/base.py @@ -22,6 +22,7 @@ from torchopt import pytree from torchopt.typing import GradientTransformation, OptState from torchopt.update import apply_updates +from torchopt.utils import extract_module_containers __all__ = ['MetaOptimizer'] @@ -102,10 +103,7 @@ def step(self, loss: torch.Tensor): # pylint: disable=too-many-locals def add_param_group(self, net: nn.Module) -> None: """Add a param group to the optimizer's :attr:`state_groups`.""" - # pylint: disable-next=import-outside-toplevel - from torchopt.utils import _extract_container - - params_container, _ = _extract_container(net, with_buffers=False) + params_container = extract_module_containers(net, with_buffers=False)[0] flat_params = tuple( filter( torch.is_tensor, # type: ignore[arg-type] diff --git a/torchopt/utils.py b/torchopt/utils.py index 211bcdd6..35214a39 100644 --- a/torchopt/utils.py +++ b/torchopt/utils.py @@ -277,12 +277,13 @@ def get_variable(t): raise RuntimeError(f'Unexpected class of {target}') -def _extract_container( +def extract_module_containers( module: nn.Module, with_buffers: bool = True ) -> Tuple[ Tuple[Dict[str, Optional[torch.Tensor]], ...], Tuple[Dict[str, Optional[torch.Tensor]], ...], ]: + """Extract the references to the containers of parameters and buffers from a module.""" if isinstance(module, nn.Module): params: List[Dict[str, Optional[torch.Tensor]]] = [] buffers: List[Dict[str, Optional[torch.Tensor]]] = [] @@ -329,7 +330,7 @@ def recover_state_dict( if isinstance(target, nn.Module): params, buffers, *_ = state = cast(ModuleState, state) - params_container, buffers_container = _extract_container(target, with_buffers=True) + params_containers, buffers_containers = extract_module_containers(target, with_buffers=True) if state.detach_buffers: @@ -342,8 +343,8 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor: ) for tgt, src in itertools.chain( - zip(params_container, params), - zip(buffers_container, buffers), + zip(params_containers, params), + zip(buffers_containers, buffers), ): tgt.update(src) elif isinstance(target, MetaOptimizer): @@ -426,7 +427,7 @@ def module_clone( if isinstance(target, (nn.Module, MetaOptimizer)): if isinstance(target, nn.Module): - containers = cast(TensorTree, _extract_container(target, with_buffers=True)) + containers = cast(TensorTree, extract_module_containers(target, with_buffers=True)) else: containers = cast(TensorTree, target.state_dict()) tensors = pytree.tree_leaves(containers) From 3111f7b6e601b27a3c4fa9f64bf0fc2151f0a8dc Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 14 Oct 2022 21:26:21 +0800 Subject: [PATCH 02/11] docs(CHANGELOG): update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b537d9c..9cb41810 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Add object-oriented modules support for implicit meta-gradient by [@XuehaiPan](https://github.com/XuehaiPan) in [#101](https://github.com/metaopt/torchopt/pull/101). - Add zero-order gradient estimation by [@JieRen98](https://github.com/JieRen98) in [#93](https://github.com/metaopt/torchopt/pull/93). 
- Add RPC-based distributed training support and add distributed MAML example by [@XuehaiPan](https://github.com/XuehaiPan) in [#83](https://github.com/metaopt/torchopt/pull/83). - Add full type hints by [@XuehaiPan](https://github.com/XuehaiPan) in [#92](https://github.com/metaopt/torchopt/pull/92). From 0e6b6a32abb63ad181e20555fc2fc6b1ec3a0b27 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 15 Oct 2022 00:29:04 +0800 Subject: [PATCH 03/11] docs: update dictionary --- docs/source/spelling_wordlist.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt index f590b546..e8e37ad2 100644 --- a/docs/source/spelling_wordlist.txt +++ b/docs/source/spelling_wordlist.txt @@ -88,3 +88,4 @@ deepcopy deepclone RRef rref +ints From 2e3b5140685a9dfbd90c87f9c05ddee653ba8fd8 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 17 Oct 2022 14:33:01 +0800 Subject: [PATCH 04/11] docs: update docstrings --- torchopt/diff/implicit/nn/module.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py index 7cda4c80..0429c68a 100644 --- a/torchopt/diff/implicit/nn/module.py +++ b/torchopt/diff/implicit/nn/module.py @@ -24,7 +24,7 @@ import torchopt.nn from torchopt import pytree from torchopt.diff.implicit.decorator import custom_root -from torchopt.typing import TensorTree +from torchopt.typing import TensorTree # pylint: disable=unused-import from torchopt.utils import extract_module_containers @@ -240,6 +240,9 @@ def solve(self, batch, labels): loss.backward(inputs=parameters) optimizer.step() return self + + Returns: + The module itself after solving the inner optimization problem. """ raise NotImplementedError # update parameters @@ -255,10 +258,10 @@ def residual(self, *input, **kwargs) -> 'TensorTree': module.solve(*input, **kwargs) module.residual(*input, **kwargs) # -> 0 - 1. For gradient-based optimization, the :meth:`residual` is the KKT condition, usually the - gradients of the :meth`objective` function with respect to the module parameters (not the - meta-parameters). If this method is not implemented, it will be automatically calculated - from the gradient of the :meth:`objective` function. + 1. For gradient-based optimization, the :meth:`residual` function is the KKT condition, + usually it is the gradients of the :meth:`objective` function with respect to the module + parameters (not the meta-parameters). If this method is not implemented, it will be + automatically derived from the gradient of the :meth:`objective` function. .. math:: @@ -270,7 +273,7 @@ def residual(self, *input, **kwargs) -> 'TensorTree': References: - Karush-Kuhn-Tucker (KKT) conditions: https://en.wikipedia.org/wiki/Karush-Kuhn-Tucker_conditions - 2. For fixed point iteration, the :meth:`residual` can be the residual of the + 2. For fixed point iteration, the :meth:`residual` function can be the residual of the parameters between iterations, i.e.: .. math:: @@ -279,6 +282,10 @@ def residual(self, *input, **kwargs) -> 'TensorTree': where :math:`\boldsymbol{x}` is the joint vector of the module parameters and :math:`\boldsymbol{\theta}` is the joint vector of the meta-parameters. + + Returns: + A tree of tensors, the residual to the optimal parameters after solving the inner + optimization problem. 
""" raise NotImplementedError @@ -288,5 +295,8 @@ def objective(self, *input, **kwargs) -> torch.Tensor: This method is used to calculate the :meth:`residual` if it is not implemented. Otherwise, this method is optional. + + Returns: + A scalar tensor (``dim=0``), the objective function value. """ raise NotImplementedError From c27200913ae320d15250ea12d67666a7d248b758 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 17 Oct 2022 15:00:26 +0800 Subject: [PATCH 05/11] docs: add API references --- docs/source/api/api.rst | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/source/api/api.rst b/docs/source/api/api.rst index d0366f60..1ec8e1d2 100644 --- a/docs/source/api/api.rst +++ b/docs/source/api/api.rst @@ -139,12 +139,22 @@ Implicit differentiation .. autosummary:: custom_root + nn.ImplicitMetaGradientModule Custom solvers ~~~~~~~~~~~~~~ .. autofunction:: custom_root + +Implicit Meta-Gradient Module +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. currentmodule:: torchopt.diff.implicit.nn + +.. autoclass:: ImplicitMetaGradientModule + :members: + ------ Linear system solving From f502f8a731841344da40430e630cd38d9aa974eb Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 17 Oct 2022 18:09:50 +0800 Subject: [PATCH 06/11] docs: update dictionary --- docs/source/spelling_wordlist.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt index e8e37ad2..25f11953 100644 --- a/docs/source/spelling_wordlist.txt +++ b/docs/source/spelling_wordlist.txt @@ -89,3 +89,6 @@ deepclone RRef rref ints +Karush +Kuhn +Tucker From 438eb449d9e67a3d21a85a45686602cf856900cd Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 26 Oct 2022 19:39:39 +0800 Subject: [PATCH 07/11] chore: handle `none_is_leaf` change in optree --- torchopt/diff/implicit/nn/module.py | 64 ++++++++++------------------- torchopt/optim/meta/base.py | 2 +- 2 files changed, 23 insertions(+), 43 deletions(-) diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py index 0429c68a..755e9b5c 100644 --- a/torchopt/diff/implicit/nn/module.py +++ b/torchopt/diff/implicit/nn/module.py @@ -39,23 +39,16 @@ def make_residual_from_objective( def residual(self: 'ImplicitMetaGradientModule', *input, **kwargs) -> Tuple[torch.Tensor, ...]: params_containers = extract_module_containers(self, with_buffers=False)[0] params_containers_backups = [container.copy() for container in params_containers] - flat_params_or_none: List[Optional[torch.Tensor]] - flat_params_or_none, params_containers_treedef = pytree.tree_flatten( - params_containers # type: ignore[arg-type] - ) + flat_params: List[torch.Tensor] + flat_params, params_containers_treespec = pytree.tree_flatten(params_containers) # type: ignore[arg-type] # pylint: disable-next=redefined-builtin def objective_fn(flat_params: Tuple[torch.Tensor, ...], *input, **kwargs) -> torch.Tensor: - grad_tracking_params_iter = iter(flat_params) - flat_grad_tracking_params_or_none = [ - next(grad_tracking_params_iter) - if isinstance(old_param_or_none, torch.Tensor) - else old_param_or_none - for old_param_or_none in flat_params_or_none - ] - grad_tracking_params_containers: Tuple[Dict[str, Optional[torch.Tensor]], ...] - grad_tracking_params_containers = pytree.tree_unflatten( # type: ignore[assignment] - params_containers_treedef, flat_grad_tracking_params_or_none + flat_grad_tracking_params = flat_params + grad_tracking_params_containers: Tuple[ + Dict[str, Optional[torch.Tensor]], ... 
+ ] = pytree.tree_unflatten( # type: ignore[assignment] + params_containers_treespec, flat_grad_tracking_params ) try: @@ -72,7 +65,6 @@ def objective_fn(flat_params: Tuple[torch.Tensor, ...], *input, **kwargs) -> tor container.update(container_backup) objective_grad_fn = functorch.grad(objective_fn, argnums=0) - flat_params = tuple(filter(torch.is_tensor, flat_params_or_none)) flat_grads = objective_grad_fn(flat_params, *input, **kwargs) return flat_grads @@ -103,12 +95,12 @@ def wrapped( # pylint: disable=too-many-locals container.copy() for container in meta_params_containers ) - flat_params_or_none: List[Optional[torch.Tensor]] - flat_meta_params_or_none: List[Optional[torch.Tensor]] - flat_params_or_none, params_containers_treedef = pytree.tree_flatten( + flat_params: List[torch.Tensor] + flat_meta_params: List[torch.Tensor] + flat_params, params_containers_treespec = pytree.tree_flatten( params_containers # type: ignore[arg-type] ) - flat_meta_params_or_none, meta_params_containers_treedef = pytree.tree_flatten( + flat_meta_params, meta_params_containers_treespec = pytree.tree_flatten( meta_params_containers # type: ignore[arg-type] ) @@ -118,27 +110,17 @@ def optimality_fn( *input, # pylint: disable=redefined-builtin **kwargs, ) -> Tuple[torch.Tensor, ...]: - grad_tracking_params_iter = iter(flat_params) - flat_grad_tracking_params_or_none = [ - next(grad_tracking_params_iter) - if isinstance(old_params_or_none, torch.Tensor) - else old_params_or_none - for old_params_or_none in flat_params_or_none - ] - grad_tracking_params_containers: Tuple[Dict[str, Optional[torch.Tensor]], ...] - grad_tracking_params_containers = pytree.tree_unflatten( # type: ignore[assignment] - params_containers_treedef, flat_grad_tracking_params_or_none + flat_grad_tracking_params = flat_params + grad_tracking_params_containers: Tuple[ + Dict[str, Optional[torch.Tensor]], ... + ] = pytree.tree_unflatten( # type: ignore[assignment] + params_containers_treespec, flat_grad_tracking_params ) - grad_tracking_meta_params_iter = iter(flat_meta_params) - flat_grad_tracking_meta_params_or_none = [ - next(grad_tracking_meta_params_iter) - if isinstance(old_meta_param_or_none, torch.Tensor) - else old_meta_param_or_none - for old_meta_param_or_none in flat_meta_params_or_none - ] - grad_tracking_meta_params_containers: Tuple[Dict[str, Optional[torch.Tensor]], ...] - grad_tracking_meta_params_containers = pytree.tree_unflatten( # type: ignore[assignment] - meta_params_containers_treedef, flat_grad_tracking_meta_params_or_none + flat_grad_tracking_meta_params = flat_meta_params + grad_tracking_meta_params_containers: Tuple[ + Dict[str, Optional[torch.Tensor]], ... 
+ ] = pytree.tree_unflatten( # type: ignore[assignment] + meta_params_containers_treespec, flat_grad_tracking_meta_params ) try: @@ -148,7 +130,7 @@ def optimality_fn( ): container.update(grad_tracking_container) - return self.residual(*input, **kwargs) # type: ignore[return-value] + return self.residual(*input, **kwargs) finally: for container, container_backup in itertools.chain( zip(params_containers, params_containers_backups), @@ -166,8 +148,6 @@ def solve_fn( solve(self, *input, **kwargs) return tuple(filter(torch.is_tensor, pytree.tree_leaves(params_containers))) # type: ignore[arg-type] - flat_params = tuple(filter(torch.is_tensor, flat_params_or_none)) - flat_meta_params = tuple(filter(torch.is_tensor, flat_meta_params_or_none)) # pylint: disable-next=unused-variable flat_optimal_params = solve_fn(flat_params, flat_meta_params, *input, **kwargs) return self diff --git a/torchopt/optim/meta/base.py b/torchopt/optim/meta/base.py index 30a14e7b..5dd300a7 100644 --- a/torchopt/optim/meta/base.py +++ b/torchopt/optim/meta/base.py @@ -96,7 +96,7 @@ def add_param_group(self, net: nn.Module) -> None: """Add a param group to the optimizer's :attr:`state_groups`.""" params_container = extract_module_containers(net, with_buffers=False)[0] flat_params: Tuple[torch.Tensor, ...] = tuple(pytree.tree_leaves(params_container)) # type: ignore[arg-type] - optimizer_state = self.impl.init(flat_params) # type: ignore[arg-type] + optimizer_state = self.impl.init(flat_params) self.param_containers_groups.append(params_container) self.state_groups.append(optimizer_state) From ba43fc1efddaf62e3ddd70860183ca72d6b155d8 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 28 Oct 2022 21:37:07 +0800 Subject: [PATCH 08/11] chore: update type hint --- torchopt/diff/implicit/decorator.py | 149 +++++++++++++----------- torchopt/diff/implicit/nn/module.py | 37 +++--- torchopt/distributed/autograd.py | 35 ++++-- torchopt/linear_solve/normal_cg.py | 10 +- torchopt/optim/base.py | 5 +- torchopt/optim/meta/base.py | 9 +- torchopt/pytree.py | 27 ++++- torchopt/transform/scale_by_adam.py | 6 +- torchopt/transform/scale_by_schedule.py | 6 +- torchopt/typing.py | 24 +++- torchopt/visual.py | 5 +- 11 files changed, 199 insertions(+), 114 deletions(-) diff --git a/torchopt/diff/implicit/decorator.py b/torchopt/diff/implicit/decorator.py index d6a5477c..6f19c196 100644 --- a/torchopt/diff/implicit/decorator.py +++ b/torchopt/diff/implicit/decorator.py @@ -18,13 +18,14 @@ import functools import inspect -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union import functorch import torch from torch.autograd import Function from torchopt import linear_solve, pytree +from torchopt.typing import ListOfTensors, TensorOrTensors, TupleOfTensors __all__ = ['custom_root'] @@ -37,11 +38,11 @@ class MaskedOptimalityFn: # pylint: disable=missing-class-docstring,too-few-public-methods def __init__( self, - optimality_fn: Callable, - solution: Any, + optimality_fn: Callable[..., TensorOrTensors], + solution: TensorOrTensors, output_is_tensor: bool, argnums: Tuple[int, ...], - *args, + *args: Any, ) -> None: self.optimality_fn = optimality_fn self.solution = solution @@ -59,7 +60,7 @@ def __init__( self.pre_filled = tuple(pre_filled) self.post_filled = tuple(post_filled) - def __call__(self, *args) -> Any: + def __call__(self, *args: Any) -> TensorOrTensors: true_args = [] pre_filled_counter = 0 for idx in range(self.len_args): @@ 
-76,64 +77,66 @@ def __call__(self, *args) -> Any: # pylint: disable-next=too-many-arguments,too-many-locals,too-many-branches def _root_vjp( - optimality_fn: Callable, - solution: Any, + optimality_fn: Callable[..., TensorOrTensors], + solution: TupleOfTensors, args: Args, - grad_outputs: Any, + grad_outputs: TupleOfTensors, output_is_tensor: bool, argnums: Tuple[int, ...], - solve: Callable = linear_solve.solve_normal_cg(), -) -> Tuple[Any, ...]: + solve: Callable[..., TensorOrTensors] = linear_solve.solve_normal_cg(), +) -> Tuple[Optional[torch.Tensor], ...]: if output_is_tensor: - def optimality_cond(solution): + def optimality_cond(solution: TupleOfTensors) -> TensorOrTensors: return optimality_fn(solution[0], *args) else: - def optimality_cond(solution): + def optimality_cond(solution: TupleOfTensors) -> TensorOrTensors: return optimality_fn(solution, *args) - _, vjp_optimality_cond, *_ = functorch.vjp(optimality_cond, solution) + _, optimality_cond_vjp_fn, *_ = functorch.vjp(optimality_cond, solution) # Compute the multiplication A^T u = (u^T A)^T. if output_is_tensor: - def matvec(u): - return vjp_optimality_cond(u[0])[0] + def matvec(u: TupleOfTensors) -> TupleOfTensors: + return optimality_cond_vjp_fn(u[0])[0] else: - def matvec(u): - return vjp_optimality_cond(u)[0] + def matvec(u: TupleOfTensors) -> TupleOfTensors: + return optimality_cond_vjp_fn(u)[0] # The solution of A^T u = v, where # A = jacobian(optimality_fn, argnums=0) # v = -grad_outputs. - v = pytree.tree_map(torch.neg, grad_outputs) - u = solve(matvec, v) + v: TupleOfTensors = pytree.tree_map(torch.neg, grad_outputs) # type: ignore[arg-type,assignment] + u: TupleOfTensors = solve(matvec, v) # type: ignore[assignment] masked_optimality_fn = MaskedOptimalityFn( optimality_fn, solution, output_is_tensor, argnums, *args ) - if getattr(solve, 'is_sdp', False): + output: TupleOfTensors + if getattr(solve, 'is_spd', False): if output_is_tensor: - output = u[0] - else: output = u + else: + output = (u,) # type: ignore[assignment] else: - _, vjp_optimality_fn, *_ = functorch.vjp( + _, optimality_vjp_fn, *_ = functorch.vjp( masked_optimality_fn, *masked_optimality_fn.post_filled ) if output_is_tensor: - output = vjp_optimality_fn(u[0]) + output = optimality_vjp_fn(u[0]) else: - output = vjp_optimality_fn(u) + output = optimality_vjp_fn(u) - true_output = [None] + # Prepend None as the vjp for init_params. + true_output: List[Optional[torch.Tensor]] = [None] for idx in range(masked_optimality_fn.len_args): if idx + 1 in argnums: # plus 1 because we exclude the first argument true_output.append(output[idx]) @@ -150,14 +153,14 @@ def _extract_kwargs(kwarg_keys: Sequence[str], flat_args: Tuple[Any, ...]) -> Tu return args, kwargs -def _signature_bind(signature: inspect.Signature, *args, **kwargs) -> Tuple[Args, KwArgs]: +def _signature_bind(signature: inspect.Signature, *args: Any, **kwargs: Any) -> Tuple[Args, KwArgs]: bound = signature.bind(*args, **kwargs) bound.apply_defaults() return bound.args, bound.kwargs def _signature_bind_and_match( - signature: inspect.Signature, *args, **kwargs + signature: inspect.Signature, *args: Any, **kwargs: Any ) -> Tuple[Args, KwArgs, Callable[[Args], Tuple[Args, KwArgs]]]: # We want to bind *args and **kwargs based on the provided signature, but also to associate the # resulting positional arguments back. 
To achieve this, we lift arguments to a triple: @@ -192,10 +195,10 @@ def map_args_back(out_args): def _split_tensor_and_others( mixed_tuple: Tuple[Any, ...], -) -> Tuple[pytree.PyTreeSpec, Tuple[bool, ...], Tuple[torch.Tensor, ...], Tuple[Any, ...]]: +) -> Tuple[pytree.PyTreeSpec, Tuple[bool, ...], TupleOfTensors, Tuple[Any, ...]]: flattened: List[Any] flattened, treespec = pytree.tree_flatten(mixed_tuple, none_is_leaf=True) # type: ignore[arg-type] - tensors: List[torch.Tensor] = [] + tensors: ListOfTensors = [] non_tensors: List[Any] = [] is_tensor_mask: List[bool] = [] for item in flattened: @@ -211,9 +214,9 @@ def _split_tensor_and_others( def _merge_tensor_and_others( treespec: pytree.PyTreeSpec, is_tensor_mask: Tuple[bool, ...], - tensors: Tuple[torch.Tensor, ...], + tensors: TupleOfTensors, non_tensors: Tuple[Any, ...], -) -> Any: +) -> Tuple[Any, ...]: tensor_counter = 0 non_tensor_counter = 0 results = [] @@ -224,18 +227,18 @@ def _merge_tensor_and_others( else: results.append(non_tensors[non_tensor_counter]) non_tensor_counter += 1 - return pytree.tree_unflatten(treespec, results) + return pytree.tree_unflatten(treespec, results) # type: ignore[return-value] # pylint: disable-next=too-many-arguments,too-many-statements def _custom_root( - solver_fn: Callable, - optimality_fn: Callable, - solve: Callable, + solver_fn: Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]], + optimality_fn: Callable[..., TensorOrTensors], + solve: Callable[..., TensorOrTensors], argnums: Tuple[int, ...], has_aux: bool, reference_signature: Optional[Union[inspect.Signature, Callable]] = None, -) -> Callable: +) -> Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]]: solver_fn_signature = inspect.signature(solver_fn) if reference_signature is None: @@ -246,19 +249,25 @@ def _custom_root( fn = getattr(reference_signature, 'subfn', reference_signature) reference_signature = inspect.signature(fn) - def make_custom_vjp_solver_fn(solver_fn, kwarg_keys, args_sign): + def make_custom_vjp_solver_fn( + solver_fn: Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]], + kwarg_keys: Sequence[str], + args_signs: Tuple[Tuple[int, Optional[Union[Type[tuple], Type[list]]]], ...], + ) -> Type[Function]: # pylint: disable-next=missing-class-docstring,abstract-method class ImplicitMetaGradient(Function): @staticmethod - def forward(ctx, *flat_args): # pylint: disable=arguments-differ + def forward( # type: ignore[override] # pylint: disable=arguments-differ + ctx, *flat_args: Any + ) -> Tuple[Any, ...]: output, aux, output_is_tensor = None, None, False args = [] - for idx, (start_point, is_tuple) in enumerate(args_sign): - if is_tuple: - args.append(tuple(flat_args[start_point : args_sign[idx + 1][0]])) + for idx, (offset, arg_seq_type) in enumerate(args_signs): + if arg_seq_type is not None: + args.append(arg_seq_type(flat_args[offset : args_signs[idx + 1][0]])) else: - args.append(flat_args[start_point]) + args.append(flat_args[offset]) args = tuple(args) args, kwargs = _extract_kwargs(kwarg_keys, args) @@ -295,8 +304,10 @@ def forward(ctx, *flat_args): # pylint: disable=arguments-differ return (*output, aux, output_is_tensor, type(output)) @staticmethod - def backward(ctx, *grad_outputs): # pylint: disable=too-many-locals - grad_outputs = grad_outputs[:-3] + def backward( # pylint: disable=too-many-locals + ctx, *grad_outputs: Any + ) -> TupleOfTensors: + grad_outputs: TupleOfTensors = grad_outputs[:-3] saved_tensors = ctx.saved_tensors output = saved_tensors[: 
len(grad_outputs)] @@ -332,50 +343,53 @@ def backward(ctx, *grad_outputs): # pylint: disable=too-many-locals argnums=argnums, solve=solve, ) - # Prepend None as the vjp for init_params. args_vjps, kwargs_vjps = map_args_back(vjps) ordered_vjps = tuple(args_vjps) + tuple(kwargs_vjps[k] for k in kwargs.keys()) true_vjps = [] - for (_, is_tuple), vjp in zip(args_sign, ordered_vjps): - if is_tuple: - for item in vjp: - true_vjps.append(item) + for (_, arg_seq_type), vjp in zip(args_signs, ordered_vjps): + if arg_seq_type is not None: + true_vjps.extend(vjp) else: true_vjps.append(vjp) return tuple(true_vjps) return ImplicitMetaGradient - def wrapped_solver_fn(*args, **kwargs): + def wrapped_solver_fn( + *args: Any, **kwargs: Any + ) -> Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]: args, kwargs = _signature_bind(solver_fn_signature, *args, **kwargs) keys, vals = list(kwargs.keys()), list(kwargs.values()) - args_sign = [] - flat_args = [] + args_signs: List[Tuple[int, Optional[Union[Type[tuple], Type[list]]]]] = [] + flat_args: List[Any] = [] args_counter = 0 for idx, arg in enumerate(args): if idx in argnums: if isinstance(arg, torch.Tensor): - args_sign.append((args_counter, False)) # start position, is_tuple + args_signs.append((args_counter, None)) # start position, None flat_args.append(arg) args_counter += 1 - elif isinstance(arg, tuple): - args_sign.append((args_counter, True)) # start position, is_tuple - for arg_item in arg: - flat_args.append(arg_item) + elif isinstance(arg, (tuple, list)) and all(map(torch.is_tensor, arg)): + args_signs.append((args_counter, type(arg))) # start position, sequence type + flat_args.extend(arg) args_counter += len(arg) else: - raise RuntimeError('must be tensor or tensor tuple') + raise RuntimeError( + 'custom_root(optimality_fn)(solver_fn)(*args): argument of function ' + 'solver_fn specified with `argnums` should be a torch.Tensor or a tuple of ' + 'torch.Tensor' + ) else: - args_sign.append((args_counter, False)) # start position, is_tuple + args_signs.append((args_counter, None)) # start position, None flat_args.append(arg) args_counter += 1 - args_sign = tuple(args_sign) + args_signs = tuple(args_signs) flat_args = tuple(flat_args) - result = make_custom_vjp_solver_fn(solver_fn, keys, args_sign).apply(*flat_args, *vals) + result = make_custom_vjp_solver_fn(solver_fn, keys, args_signs).apply(*flat_args, *vals) *output, aux, output_is_tensor, output_type = result if output_is_tensor: output = output[0] @@ -389,11 +403,14 @@ def wrapped_solver_fn(*args, **kwargs): def custom_root( - optimality_fn: Callable, + optimality_fn: Callable[..., TensorOrTensors], argnums: Union[int, Tuple[int, ...]], has_aux: bool = False, - solve: Callable = linear_solve.solve_normal_cg(), -) -> Callable[[Callable], Callable]: + solve: Callable[..., TensorOrTensors] = linear_solve.solve_normal_cg(), +) -> Callable[ + [Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]]], + Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]], +]: """Decorator for adding implicit differentiation to a root solver. This wrapper should be used as a decorator: @@ -409,6 +426,8 @@ def solver_fn(params, arg1, arg2, ...): ... return optimal_params + optimal_params = solver_fn(init_params, ...) + The first argument to ``optimality_fn`` and ``solver_fn`` is preserved as the parameter input. The ``argnums`` argument refers to the indices of the variables in ``solver_fn``'s signature. 
For example, setting ``argnums=(1, 2)`` will compute the gradient of ``optimal_params`` with diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py index 755e9b5c..5d9f4a97 100644 --- a/torchopt/diff/implicit/nn/module.py +++ b/torchopt/diff/implicit/nn/module.py @@ -16,7 +16,7 @@ import functools import itertools -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple import functorch import torch @@ -24,7 +24,7 @@ import torchopt.nn from torchopt import pytree from torchopt.diff.implicit.decorator import custom_root -from torchopt.typing import TensorTree # pylint: disable=unused-import +from torchopt.typing import TensorTree, TupleOfTensors # pylint: disable=unused-import from torchopt.utils import extract_module_containers @@ -33,17 +33,18 @@ def make_residual_from_objective( objective: Callable[..., torch.Tensor] -) -> Callable[..., Tuple[torch.Tensor, ...]]: +) -> Callable[..., TupleOfTensors]: """Make a function that computes the optimality residual of the objective function.""" # pylint: disable-next=redefined-builtin - def residual(self: 'ImplicitMetaGradientModule', *input, **kwargs) -> Tuple[torch.Tensor, ...]: + def residual(self: 'ImplicitMetaGradientModule', *input, **kwargs) -> TupleOfTensors: params_containers = extract_module_containers(self, with_buffers=False)[0] params_containers_backups = [container.copy() for container in params_containers] - flat_params: List[torch.Tensor] - flat_params, params_containers_treespec = pytree.tree_flatten(params_containers) # type: ignore[arg-type] + flat_params: TupleOfTensors + # pylint: disable-next=line-too-long + flat_params, params_containers_treespec = pytree.tree_flatten_as_tuple(params_containers) # type: ignore[arg-type] # pylint: disable-next=redefined-builtin - def objective_fn(flat_params: Tuple[torch.Tensor, ...], *input, **kwargs) -> torch.Tensor: + def objective_fn(flat_params: TupleOfTensors, *input, **kwargs) -> torch.Tensor: flat_grad_tracking_params = flat_params grad_tracking_params_containers: Tuple[ Dict[str, Optional[torch.Tensor]], ... @@ -95,21 +96,21 @@ def wrapped( # pylint: disable=too-many-locals container.copy() for container in meta_params_containers ) - flat_params: List[torch.Tensor] - flat_meta_params: List[torch.Tensor] - flat_params, params_containers_treespec = pytree.tree_flatten( + flat_params: TupleOfTensors + flat_meta_params: TupleOfTensors + flat_params, params_containers_treespec = pytree.tree_flatten_as_tuple( params_containers # type: ignore[arg-type] ) - flat_meta_params, meta_params_containers_treespec = pytree.tree_flatten( + flat_meta_params, meta_params_containers_treespec = pytree.tree_flatten_as_tuple( meta_params_containers # type: ignore[arg-type] ) def optimality_fn( - flat_params: Tuple[torch.Tensor, ...], - flat_meta_params: Tuple[torch.Tensor, ...], + flat_params: TupleOfTensors, + flat_meta_params: TupleOfTensors, *input, # pylint: disable=redefined-builtin **kwargs, - ) -> Tuple[torch.Tensor, ...]: + ) -> TupleOfTensors: flat_grad_tracking_params = flat_params grad_tracking_params_containers: Tuple[ Dict[str, Optional[torch.Tensor]], ... 
@@ -140,13 +141,13 @@ def optimality_fn( @custom_root(optimality_fn, argnums=1) def solve_fn( - flat_params: Tuple[torch.Tensor, ...], # pylint: disable=unused-argument - flat_meta_params: Tuple[torch.Tensor, ...], # pylint: disable=unused-argument + flat_params: TupleOfTensors, # pylint: disable=unused-argument + flat_meta_params: TupleOfTensors, # pylint: disable=unused-argument *input, # pylint: disable=redefined-builtin **kwargs, - ) -> Tuple[torch.Tensor, ...]: + ) -> TupleOfTensors: solve(self, *input, **kwargs) - return tuple(filter(torch.is_tensor, pytree.tree_leaves(params_containers))) # type: ignore[arg-type] + return tuple(pytree.tree_leaves(params_containers)) # type: ignore[arg-type] # pylint: disable-next=unused-variable flat_optimal_params = solve_fn(flat_params, flat_meta_params, *input, **kwargs) diff --git a/torchopt/distributed/autograd.py b/torchopt/distributed/autograd.py index 9425b4a5..41b6b461 100644 --- a/torchopt/distributed/autograd.py +++ b/torchopt/distributed/autograd.py @@ -15,12 +15,14 @@ """Distributed Autograd.""" from threading import Lock -from typing import Optional, Sequence, Tuple, Union +from typing import Optional, overload import torch import torch.distributed.autograd as autograd from torch.distributed.autograd import context +from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors, TupleOfTensors + __all__ = ['is_available', 'context'] @@ -34,14 +36,14 @@ def is_available(): if is_available(): - # pylint: disable-next=unused-import + # pylint: disable-next=unused-import,ungrouped-imports from torch.distributed.autograd import DistAutogradContext, get_gradients def backward( autograd_ctx_id: int, - tensors: Union[torch.Tensor, Sequence[torch.Tensor]], + tensors: TensorOrTensors, retain_graph: bool = False, - inputs: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None, + inputs: Optional[TensorOrTensors] = None, ) -> None: """Perform distributed backward pass for local parameters. @@ -83,13 +85,32 @@ def backward( else: p.grad = g + @overload + def grad( + autograd_ctx_id: int, + outputs: TensorOrTensors, + inputs: TensorOrTensors, + retain_graph: bool = False, + ) -> TupleOfTensors: + ... + + @overload + def grad( + autograd_ctx_id: int, + outputs: TensorOrTensors, + inputs: TensorOrTensors, + retain_graph: bool = False, + allow_unused: bool = False, + ) -> TupleOfOptionalTensors: + ... + def grad( autograd_ctx_id: int, - outputs: Union[torch.Tensor, Sequence[torch.Tensor]], - inputs: Union[torch.Tensor, Sequence[torch.Tensor]], + outputs: TensorOrTensors, + inputs: TensorOrTensors, retain_graph: bool = False, allow_unused: bool = False, - ) -> Tuple[torch.Tensor, ...]: + ) -> TupleOfOptionalTensors: """Computes and returns the sum of gradients of outputs with respect to the inputs. Args: diff --git a/torchopt/linear_solve/normal_cg.py b/torchopt/linear_solve/normal_cg.py index 63e74e80..78ca75f2 100644 --- a/torchopt/linear_solve/normal_cg.py +++ b/torchopt/linear_solve/normal_cg.py @@ -47,7 +47,7 @@ def _solve_normal_cg( matvec: Callable[[TensorTree], TensorTree], # (x) -> A @ x b: TensorTree, - is_sdp: bool = False, + is_spd: bool = False, ridge: Optional[float] = None, init: Optional[TensorTree] = None, **kwargs, @@ -60,7 +60,7 @@ def _solve_normal_cg( Args: matvec: A function that returns the product between ``A`` and a vector. b: A tree of tensors for the right hand side of the equation. - is_sdp: Whether to assume matrix ``A`` is symmetric definite positive to speedup computation. 
+ is_spd: Whether to assume matrix ``A`` is symmetric definite positive to speedup computation. ridge: Optional ridge regularization. Solves the equation for ``(A.T @ A + ridge * I) @ x = A.T @ b``. init: Optional initialization to be used by normal conjugate gradient. **kwargs: Additional keyword arguments for the conjugate gradient solver. @@ -73,9 +73,9 @@ def _solve_normal_cg( else: example_x = init - if is_sdp: + if is_spd: if ridge is not None: - raise ValueError('ridge must be specified with `is_sdp=False`.') + raise ValueError('ridge must be specified with `is_spd=False`.') # Returns solution for `A @ x = b`. return linalg.cg(matvec, b, x0=init, **kwargs) @@ -96,5 +96,5 @@ def _solve_normal_cg( def solve_normal_cg(**kwargs): """Wrapper for :func:`solve_normal_cg`.""" partial_fn = functools.partial(_solve_normal_cg, **kwargs) - setattr(partial_fn, 'is_sdp', kwargs.get('is_sdp', False)) + setattr(partial_fn, 'is_spd', kwargs.get('is_spd', False)) return partial_fn diff --git a/torchopt/optim/base.py b/torchopt/optim/base.py index 0f666cc1..a87920b4 100644 --- a/torchopt/optim/base.py +++ b/torchopt/optim/base.py @@ -23,6 +23,7 @@ GradientTransformation, OptState, Params, + TupleOfTensors, ) from torchopt.update import apply_updates @@ -49,7 +50,7 @@ def __init__(self, params: Iterable[torch.Tensor], impl: GradientTransformation) raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation') self.impl: GradientTransformation = impl - self.param_groups: List[Tuple[torch.Tensor]] = [] + self.param_groups: List[TupleOfTensors] = [] self.param_treespecs: List[pytree.PyTreeSpec] = [] self.state_groups: List[OptState] = [] @@ -123,7 +124,7 @@ def f(p): def add_param_group(self, params: 'Params') -> None: """Add a param group to the optimizer's :attr:`param_groups`.""" flat_params, params_treespec = pytree.tree_flatten(params) - flat_params: Tuple[torch.Tensor] = tuple(flat_params) # type: ignore[assignment] + flat_params: TupleOfTensors = tuple(flat_params) self.param_groups.append(flat_params) self.param_treespecs.append(params_treespec) self.state_groups.append(self.impl.init(flat_params)) diff --git a/torchopt/optim/meta/base.py b/torchopt/optim/meta/base.py index 5dd300a7..ef6db66c 100644 --- a/torchopt/optim/meta/base.py +++ b/torchopt/optim/meta/base.py @@ -20,7 +20,7 @@ import torch.nn as nn from torchopt import pytree -from torchopt.typing import GradientTransformation, OptState +from torchopt.typing import GradientTransformation, OptState, TupleOfTensors from torchopt.update import apply_updates from torchopt.utils import extract_module_containers @@ -68,9 +68,8 @@ def step(self, loss: torch.Tensor): # pylint: disable=too-many-locals for i, (param_container, new_state) in enumerate( zip(self.param_containers_groups, self.state_groups) ): - flat_params: Tuple[torch.Tensor, ...] - flat_params, container_treespec = pytree.tree_flatten(param_container) # type: ignore[arg-type,assignment] - flat_params = tuple(flat_params) + flat_params: TupleOfTensors + flat_params, container_treespec = pytree.tree_flatten_as_tuple(param_container) # type: ignore[arg-type] grads = torch.autograd.grad( loss, flat_params, @@ -95,7 +94,7 @@ def step(self, loss: torch.Tensor): # pylint: disable=too-many-locals def add_param_group(self, net: nn.Module) -> None: """Add a param group to the optimizer's :attr:`state_groups`.""" params_container = extract_module_containers(net, with_buffers=False)[0] - flat_params: Tuple[torch.Tensor, ...] 
= tuple(pytree.tree_leaves(params_container)) # type: ignore[arg-type] + flat_params: TupleOfTensors = tuple(pytree.tree_leaves(params_container)) # type: ignore[arg-type] optimizer_state = self.impl.init(flat_params) self.param_containers_groups.append(params_container) self.state_groups.append(optimizer_state) diff --git a/torchopt/pytree.py b/torchopt/pytree.py index 3d4dd3c1..65aba6a2 100644 --- a/torchopt/pytree.py +++ b/torchopt/pytree.py @@ -14,6 +14,8 @@ # ============================================================================== """The PyTree utilities.""" +from typing import Callable, List, Optional, Tuple + import optree import optree.typing as typing # pylint: disable=unused-import import torch.distributed.rpc as rpc @@ -22,7 +24,28 @@ from torchopt.typing import Future, PyTree, RRef, T -__all__ = [*optree.__all__, 'tree_wait'] +__all__ = [*optree.__all__, 'tree_flatten_as_tuple', 'tree_wait'] + + +def tree_flatten_as_tuple( + tree: PyTree[T], + is_leaf: Optional[Callable[[T], bool]] = None, + *, + none_is_leaf: bool = False, +) -> Tuple[Tuple[T, ...], PyTreeSpec]: + """Flatten a pytree to a tuple of leaves and a PyTreeSpec. + + Args: + tree: The pytree to flatten. + is_leaf: A function that returns True if a given node is a leaf. + none_is_leaf: If :data:`True`, None is considered a leaf rather than a internal node with no + children. + + Returns: + A tuple of (leaves, treespec). + """ + leaves, treespec = tree_flatten(tree, is_leaf, none_is_leaf=none_is_leaf) + return tuple(leaves), treespec def tree_wait(future_tree: PyTree[Future[T]]) -> PyTree[T]: @@ -59,4 +82,4 @@ def tree_local_value(rref_tree: 'PyTree[RRef[T]]'): __all__.extend(['tree_as_rref', 'tree_to_here']) -del optree, rpc, PyTree, T, RRef +del Callable, List, Optional, Tuple, optree, rpc, PyTree, T, RRef diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py index d618ba91..4ccbdbeb 100644 --- a/torchopt/transform/scale_by_adam.py +++ b/torchopt/transform/scale_by_adam.py @@ -33,14 +33,14 @@ # pylint: disable=invalid-name -from typing import NamedTuple, Sequence +from typing import NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import inc_count, tree_map_flat, update_moment -from torchopt.typing import Updates +from torchopt.typing import SequenceOfTensors, Updates __all__ = ['scale_by_adam', 'scale_by_accelerated_adam'] @@ -54,7 +54,7 @@ class ScaleByAdamState(NamedTuple): mu: Updates nu: Updates - count: Sequence[torch.Tensor] # type: ignore + count: SequenceOfTensors # type: ignore def _bias_correction(moment, decay, count, *, already_flattened=False): diff --git a/torchopt/transform/scale_by_schedule.py b/torchopt/transform/scale_by_schedule.py index 5ec85899..49b6abb7 100644 --- a/torchopt/transform/scale_by_schedule.py +++ b/torchopt/transform/scale_by_schedule.py @@ -31,14 +31,14 @@ # ============================================================================== """Preset transformation for scaling updates by learning rate schedules.""" -from typing import NamedTuple, Sequence +from typing import NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import inc_count, tree_map_flat -from torchopt.typing import Schedule +from torchopt.typing import Schedule, SequenceOfTensors __all__ = ['scale_by_schedule'] @@ -47,7 +47,7 @@ class ScaleByScheduleState(NamedTuple): """Maintains count for scale scheduling.""" - count: 
Sequence[torch.Tensor] # type: ignore + count: SequenceOfTensors # type: ignore def scale_by_schedule(step_size_fn: Schedule) -> GradientTransformation: diff --git a/torchopt/typing.py b/torchopt/typing.py index 12751f96..f1dcd1dd 100644 --- a/torchopt/typing.py +++ b/torchopt/typing.py @@ -14,7 +14,7 @@ # ============================================================================== """Typing utilities.""" -from typing import Callable, Optional, TypeVar, Union +from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, Union from typing_extensions import TypeAlias # Python 3.10+ import torch.distributed.rpc as rpc @@ -38,7 +38,16 @@ 'ScalarOrSchedule', 'PyTree', 'Tensor', + 'OptionalTensor', + 'ListOfTensors', + 'TupleOfTensors', + 'SequenceOfTensors', + 'TensorOrTensors', 'TensorTree', + 'ListOfOptionalTensors', + 'TupleOfOptionalTensors', + 'SequenceOfOptionalTensors', + 'OptionalTensorOrOptionalTensors', 'OptionalTensorTree', 'Future', ] @@ -51,8 +60,19 @@ Schedule: TypeAlias = Callable[[Numeric], Numeric] ScalarOrSchedule: TypeAlias = Union[float, Schedule] +OptionalTensor = Optional[Tensor] + +ListOfTensors = List[Tensor] +TupleOfTensors = Tuple[Tensor, ...] +SequenceOfTensors = Sequence[Tensor] +TensorOrTensors = Union[Tensor, SequenceOfTensors] TensorTree: TypeAlias = PyTreeTypeVar('TensorTree', Tensor) # type: ignore[valid-type] -OptionalTensorTree: TypeAlias = PyTreeTypeVar('OptionalTensorTree', Optional[Tensor]) # type: ignore[valid-type] + +ListOfOptionalTensors = List[OptionalTensor] +TupleOfOptionalTensors = Tuple[OptionalTensor, ...] +SequenceOfOptionalTensors = Sequence[OptionalTensor] +OptionalTensorOrOptionalTensors = Union[OptionalTensor, SequenceOfOptionalTensors] +OptionalTensorTree: TypeAlias = PyTreeTypeVar('OptionalTensorTree', OptionalTensor) # type: ignore[valid-type] # Parameters are arbitrary nests of `torch.Tensor`. 
Params: TypeAlias = TensorTree diff --git a/torchopt/visual.py b/torchopt/visual.py index 90f161b5..25a66ada 100644 --- a/torchopt/visual.py +++ b/torchopt/visual.py @@ -19,12 +19,13 @@ import warnings from collections import namedtuple -from typing import Generator, Iterable, Mapping, Optional, Sequence, Union, cast +from typing import Generator, Iterable, Mapping, Optional, Union, cast import torch from graphviz import Digraph from pkg_resources import parse_version +from torchopt.typing import TensorOrTensors from torchopt.utils import ModuleState @@ -71,7 +72,7 @@ def truncate(s): # pylint: disable=invalid-name # pylint: disable-next=too-many-branches,too-many-statements,too-many-locals def make_dot( - var: Union[torch.Tensor, Sequence[torch.Tensor]], + var: TensorOrTensors, params: Optional[ Union[ Mapping[str, torch.Tensor], From c8d71f58f7be894d7db746cd54275a5110b15c27 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sun, 30 Oct 2022 00:32:13 +0800 Subject: [PATCH 09/11] chore: suppress unresolved forwardrefs type hint warnings --- docs/source/conf.py | 22 +++++++++++++++++++++- torchopt/alias/adamw.py | 5 ++--- torchopt/diff/implicit/nn/module.py | 4 ++-- torchopt/optim/adamw.py | 4 ++-- torchopt/optim/base.py | 13 ++++--------- torchopt/optim/func/base.py | 6 +++--- torchopt/optim/meta/adamw.py | 4 ++-- torchopt/optim/meta/base.py | 4 ++-- torchopt/update.py | 4 ++-- torchopt/utils.py | 14 +++++++------- 10 files changed, 47 insertions(+), 33 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 694086fe..b213cfaa 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -25,6 +25,7 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. 
# +import logging import os import pathlib import sys @@ -43,6 +44,24 @@ def get_version() -> str: return version.__version__ +try: + import sphinx_autodoc_typehints +except ImportError: + pass +else: + + class RecursiveForwardRefFilter(logging.Filter): + def filter(self, record): + if ( + "name 'TensorTree' is not defined" in record.getMessage() + or "name 'OptionalTensorTree' is not defined" in record.getMessage() + ): + return False + return super().filter(record) + + sphinx_autodoc_typehints._LOGGER.logger.addFilter(RecursiveForwardRefFilter()) + + # -- Project information ----------------------------------------------------- project = 'TorchOpt' @@ -75,7 +94,7 @@ def get_version() -> str: 'sphinxcontrib.bibtex', 'sphinxcontrib.katex', 'sphinx_autodoc_typehints', - 'myst_nb', # This is used for the .ipynb notebooks + 'myst_nb', # this is used for the .ipynb notebooks ] if not os.getenv('READTHEDOCS', None): @@ -120,6 +139,7 @@ def get_version() -> str: 'exclude-members': '__module__, __dict__, __repr__, __str__, __weakref__', } autoclass_content = 'both' +simplify_optional_unions = False # -- Options for bibtex ----------------------------------------------------- diff --git a/torchopt/alias/adamw.py b/torchopt/alias/adamw.py index 4c8d96b4..b088be60 100644 --- a/torchopt/alias/adamw.py +++ b/torchopt/alias/adamw.py @@ -36,8 +36,7 @@ from torchopt.alias.utils import flip_sign_and_add_weight_decay, scale_by_neg_lr from torchopt.combine import chain_flat from torchopt.transform import add_decayed_weights, scale_by_accelerated_adam, scale_by_adam -from torchopt.typing import Params # pylint: disable=unused-import -from torchopt.typing import GradientTransformation, ScalarOrSchedule +from torchopt.typing import GradientTransformation, Params, ScalarOrSchedule __all__ = ['adamw'] @@ -51,7 +50,7 @@ def adamw( weight_decay: float = 1e-2, *, eps_root: float = 0.0, - mask: Optional[Union[Any, Callable[['Params'], Any]]] = None, + mask: Optional[Union[Any, Callable[[Params], Any]]] = None, moment_requires_grad: bool = False, maximize: bool = False, use_accelerated_op: bool = False, diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py index 5d9f4a97..a176ab2b 100644 --- a/torchopt/diff/implicit/nn/module.py +++ b/torchopt/diff/implicit/nn/module.py @@ -24,7 +24,7 @@ import torchopt.nn from torchopt import pytree from torchopt.diff.implicit.decorator import custom_root -from torchopt.typing import TensorTree, TupleOfTensors # pylint: disable=unused-import +from torchopt.typing import TensorTree, TupleOfTensors from torchopt.utils import extract_module_containers @@ -228,7 +228,7 @@ def solve(self, batch, labels): raise NotImplementedError # update parameters # pylint: disable-next=redefined-builtin - def residual(self, *input, **kwargs) -> 'TensorTree': + def residual(self, *input, **kwargs) -> TensorTree: r"""Computes the optimality residual. 
This method stands for the residual to the optimal parameters after solving the inner diff --git a/torchopt/optim/adamw.py b/torchopt/optim/adamw.py index 9613cab4..24362d59 100644 --- a/torchopt/optim/adamw.py +++ b/torchopt/optim/adamw.py @@ -20,7 +20,7 @@ from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import Params, ScalarOrSchedule # pylint: disable=unused-import +from torchopt.typing import Params, ScalarOrSchedule __all__ = ['AdamW'] @@ -44,7 +44,7 @@ def __init__( weight_decay: float = 1e-2, *, eps_root: float = 0.0, - mask: Optional[Union[Any, Callable[['Params'], Any]]] = None, + mask: Optional[Union[Any, Callable[[Params], Any]]] = None, maximize: bool = False, use_accelerated_op: bool = False, ) -> None: diff --git a/torchopt/optim/base.py b/torchopt/optim/base.py index a87920b4..e19f1187 100644 --- a/torchopt/optim/base.py +++ b/torchopt/optim/base.py @@ -19,12 +19,7 @@ import torch from torchopt import pytree -from torchopt.typing import ( # pylint: disable=unused-import - GradientTransformation, - OptState, - Params, - TupleOfTensors, -) +from torchopt.typing import GradientTransformation, OptState, Params, TupleOfTensors from torchopt.update import apply_updates @@ -84,11 +79,11 @@ def f(p): pytree.tree_map(f, self.param_groups) # type: ignore[arg-type] - def state_dict(self) -> Tuple['OptState', ...]: + def state_dict(self) -> Tuple[OptState, ...]: """Returns the state of the optimizer.""" return tuple(self.state_groups) - def load_state_dict(self, state_dict: Sequence['OptState']) -> None: + def load_state_dict(self, state_dict: Sequence[OptState]) -> None: """Loads the optimizer state. Args: @@ -121,7 +116,7 @@ def f(p): return loss - def add_param_group(self, params: 'Params') -> None: + def add_param_group(self, params: Params) -> None: """Add a param group to the optimizer's :attr:`param_groups`.""" flat_params, params_treespec = pytree.tree_flatten(params) flat_params: TupleOfTensors = tuple(flat_params) diff --git a/torchopt/optim/func/base.py b/torchopt/optim/func/base.py index f1d0b684..bb7ebf8b 100644 --- a/torchopt/optim/func/base.py +++ b/torchopt/optim/func/base.py @@ -19,7 +19,7 @@ import torch from torchopt.base import GradientTransformation -from torchopt.typing import OptState, Params # pylint: disable=unused-import +from torchopt.typing import OptState, Params from torchopt.update import apply_updates @@ -61,9 +61,9 @@ def __init__(self, impl: GradientTransformation, *, inplace: bool = False) -> No def step( self, loss: torch.Tensor, - params: 'Params', + params: Params, inplace: Optional[bool] = None, - ) -> 'Params': + ) -> Params: r"""Compute the gradients of loss to the network parameters and update network parameters. 
Graph of the derivative will be constructed, allowing to compute higher order derivative diff --git a/torchopt/optim/meta/adamw.py b/torchopt/optim/meta/adamw.py index 65e154d3..cb91f38f 100644 --- a/torchopt/optim/meta/adamw.py +++ b/torchopt/optim/meta/adamw.py @@ -20,7 +20,7 @@ from torchopt import alias from torchopt.optim.meta.base import MetaOptimizer -from torchopt.typing import Params, ScalarOrSchedule # pylint: disable=unused-import +from torchopt.typing import Params, ScalarOrSchedule __all__ = ['MetaAdamW'] @@ -44,7 +44,7 @@ def __init__( weight_decay: float = 1e-2, *, eps_root: float = 0.0, - mask: Optional[Union[Any, Callable[['Params'], Any]]] = None, + mask: Optional[Union[Any, Callable[[Params], Any]]] = None, moment_requires_grad: bool = False, maximize: bool = False, use_accelerated_op: bool = False, diff --git a/torchopt/optim/meta/base.py b/torchopt/optim/meta/base.py index ef6db66c..b83d3fd2 100644 --- a/torchopt/optim/meta/base.py +++ b/torchopt/optim/meta/base.py @@ -99,7 +99,7 @@ def add_param_group(self, net: nn.Module) -> None: self.param_containers_groups.append(params_container) self.state_groups.append(optimizer_state) - def state_dict(self) -> Tuple['OptState', ...]: + def state_dict(self) -> Tuple[OptState, ...]: """Extract the references of the optimizer states. Note that the states are references, so any in-place operations will change the states @@ -107,6 +107,6 @@ def state_dict(self) -> Tuple['OptState', ...]: """ return tuple(self.state_groups) - def load_state_dict(self, state_dict: Sequence['OptState']) -> None: + def load_state_dict(self, state_dict: Sequence[OptState]) -> None: """Load the references of the optimizer states.""" self.state_groups[:] = list(state_dict) diff --git a/torchopt/update.py b/torchopt/update.py index bdcdc301..85e93673 100644 --- a/torchopt/update.py +++ b/torchopt/update.py @@ -32,13 +32,13 @@ """Helper functions for applying updates.""" from torchopt import pytree -from torchopt.typing import Params, Updates # pylint: disable=unused-import +from torchopt.typing import Params, Updates __all__ = ['apply_updates'] -def apply_updates(params: 'Params', updates: 'Updates', *, inplace: bool = True) -> 'Params': +def apply_updates(params: Params, updates: Updates, *, inplace: bool = True) -> Params: """Applies an update to the corresponding parameters. This is a utility functions that applies an update to a set of parameters, and then returns the diff --git a/torchopt/utils.py b/torchopt/utils.py index 35214a39..3301f92c 100644 --- a/torchopt/utils.py +++ b/torchopt/utils.py @@ -35,7 +35,7 @@ import torch.nn as nn from torchopt import pytree -from torchopt.typing import OptState, TensorTree # pylint: disable=unused-import +from torchopt.typing import OptState, TensorTree if TYPE_CHECKING: @@ -64,7 +64,7 @@ class ModuleState(NamedTuple): CopyMode: TypeAlias = Literal['reference', 'copy', 'deepcopy', 'ref', 'clone', 'deepclone'] -def stop_gradient(target: Union['TensorTree', ModuleState, nn.Module, 'MetaOptimizer']) -> None: +def stop_gradient(target: Union[TensorTree, ModuleState, nn.Module, 'MetaOptimizer']) -> None: """Stop the gradient for the input object. Since a tensor use :attr:`grad_fn` to connect itself with the previous computation graph, the @@ -123,7 +123,7 @@ def extract_state_dict( with_buffers: bool = True, enable_visual: bool = False, visual_prefix: str = '', -) -> Tuple['OptState', ...]: +) -> Tuple[OptState, ...]: ... 
@@ -137,7 +137,7 @@ def extract_state_dict( detach_buffers: bool = False, enable_visual: bool = False, visual_prefix: str = '', -) -> Union[ModuleState, Tuple['OptState', ...]]: +) -> Union[ModuleState, Tuple[OptState, ...]]: """Extract target state. Since a tensor use :attr:`grad_fn` to connect itself with the previous computation graph, the @@ -312,7 +312,7 @@ def update_container(container, items): def recover_state_dict( target: Union[nn.Module, 'MetaOptimizer'], - state: Union[ModuleState, Sequence['OptState']], + state: Union[ModuleState, Sequence[OptState]], ) -> None: """Recover state. @@ -478,8 +478,8 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor: def module_detach_( - target: Union['TensorTree', ModuleState, nn.Module, 'MetaOptimizer'] -) -> Union['TensorTree', ModuleState, nn.Module, 'MetaOptimizer']: + target: Union[TensorTree, ModuleState, nn.Module, 'MetaOptimizer'] +) -> Union[TensorTree, ModuleState, nn.Module, 'MetaOptimizer']: """Detach a module from the computation graph. Args: From 429e06feef4de6a7c75c156f7bd045c56d530517 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 31 Oct 2022 18:53:55 +0800 Subject: [PATCH 10/11] chore: add optree to conda recipes --- conda-recipe.yaml | 3 ++- docs/conda-recipe.yaml | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/conda-recipe.yaml b/conda-recipe.yaml index edc0767b..97a9dcca 100644 --- a/conda-recipe.yaml +++ b/conda-recipe.yaml @@ -54,10 +54,11 @@ dependencies: - gxx = 10 - nvidia/label/cuda-11.7.1::cuda-nvcc - nvidia/label/cuda-11.7.1::cuda-cudart-dev - - patchelf >= 0.9 + - patchelf >= 0.14 - pybind11 # Misc + - optree >= 0.3.0 - typing-extensions >= 4.0.0 - numpy - matplotlib-base diff --git a/docs/conda-recipe.yaml b/docs/conda-recipe.yaml index 4c6dba21..ad28ac50 100644 --- a/docs/conda-recipe.yaml +++ b/docs/conda-recipe.yaml @@ -33,6 +33,7 @@ dependencies: # Learning - pytorch::pytorch >= 1.13 # sync with project.dependencies - pytorch::cpuonly + - pytorch::pytorch-mutex = *=*cpu* - pip: - torchviz - sphinxcontrib-katex # for documentation @@ -47,6 +48,7 @@ dependencies: - pybind11 # Misc + - optree >= 0.3.0 - typing-extensions >= 4.0.0 - numpy - matplotlib-base From 2dc79c66f382c6b06bbff6904ae1e5d5de13fa8b Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 31 Oct 2022 22:45:55 +0800 Subject: [PATCH 11/11] refactor: rename variables --- torchopt/diff/implicit/nn/module.py | 58 ++++++++++++++--------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py index a176ab2b..3b08118a 100644 --- a/torchopt/diff/implicit/nn/module.py +++ b/torchopt/diff/implicit/nn/module.py @@ -31,12 +31,12 @@ __all__ = ['ImplicitMetaGradientModule'] -def make_residual_from_objective( +def make_optimality_from_objective( objective: Callable[..., torch.Tensor] ) -> Callable[..., TupleOfTensors]: - """Make a function that computes the optimality residual of the objective function.""" + """Make a function that computes the optimality function of the objective function.""" # pylint: disable-next=redefined-builtin - def residual(self: 'ImplicitMetaGradientModule', *input, **kwargs) -> TupleOfTensors: + def optimality(self: 'ImplicitMetaGradientModule', *input, **kwargs) -> TupleOfTensors: params_containers = extract_module_containers(self, with_buffers=False)[0] params_containers_backups = [container.copy() for container in params_containers] flat_params: TupleOfTensors @@ -69,7 +69,7 @@ def objective_fn(flat_params: 
TupleOfTensors, *input, **kwargs) -> torch.Tensor: flat_grads = objective_grad_fn(flat_params, *input, **kwargs) return flat_grads - return residual + return optimality def enable_implicit_gradients( @@ -131,7 +131,7 @@ def optimality_fn( ): container.update(grad_tracking_container) - return self.residual(*input, **kwargs) + return self.optimality(*input, **kwargs) finally: for container, container_backup in itertools.chain( zip(params_containers, params_containers_backups), @@ -160,28 +160,28 @@ def solve_fn( class ImplicitMetaGradientModule(torchopt.nn.MetaGradientModule): """The base class for differentiable implicit meta-gradient models.""" - _custom_residual: bool + _custom_optimality: bool _custom_objective: bool def __init_subclass__(cls) -> None: """Initialize the subclass.""" super().__init_subclass__() - residual = getattr(cls, 'residual', ImplicitMetaGradientModule.residual) + optimality = getattr(cls, 'optimality', ImplicitMetaGradientModule.optimality) objective = getattr(cls, 'objective', ImplicitMetaGradientModule.objective) - cls._custom_residual = residual is not ImplicitMetaGradientModule.residual + cls._custom_optimality = optimality is not ImplicitMetaGradientModule.optimality cls._custom_objective = objective is not ImplicitMetaGradientModule.objective - if cls._custom_residual: - if isinstance(residual, staticmethod): - raise TypeError('residual() must not be a staticmethod.') - if isinstance(residual, classmethod): - raise TypeError('residual() must not be a classmethod.') - if not callable(residual): - raise TypeError('residual() must be callable.') + if cls._custom_optimality: + if isinstance(optimality, staticmethod): + raise TypeError('optimality() must not be a staticmethod.') + if isinstance(optimality, classmethod): + raise TypeError('optimality() must not be a classmethod.') + if not callable(optimality): + raise TypeError('optimality() must be callable.') elif not cls._custom_objective: raise TypeError( - 'ImplicitMetaGradientModule requires either an residual() or an objective() function' + 'ImplicitMetaGradientModule requires either an optimality() or an objective() function' ) else: if isinstance(objective, staticmethod): @@ -191,7 +191,7 @@ def __init_subclass__(cls) -> None: if not callable(objective): raise TypeError('objective() must be callable.') - cls.residual = make_residual_from_objective(objective) # type: ignore[assignment] + cls.optimality = make_optimality_from_objective(objective) # type: ignore[assignment] cls.solve = enable_implicit_gradients(cls.solve) # type: ignore[assignment] @@ -228,25 +228,25 @@ def solve(self, batch, labels): raise NotImplementedError # update parameters # pylint: disable-next=redefined-builtin - def residual(self, *input, **kwargs) -> TensorTree: + def optimality(self, *input, **kwargs) -> TensorTree: r"""Computes the optimality residual. - This method stands for the residual to the optimal parameters after solving the inner - optimization problem (:meth:`solve`), i.e.: + This method stands for the optimality residual to the optimal parameters after solving the + inner optimization problem (:meth:`solve`), i.e.: .. code-block:: python module.solve(*input, **kwargs) - module.residual(*input, **kwargs) # -> 0 + module.optimality(*input, **kwargs) # -> 0 - 1. For gradient-based optimization, the :meth:`residual` function is the KKT condition, + 1. 
For gradient-based optimization, the :meth:`optimality` function is the KKT condition, usually it is the gradients of the :meth:`objective` function with respect to the module parameters (not the meta-parameters). If this method is not implemented, it will be automatically derived from the gradient of the :meth:`objective` function. .. math:: \text{optimality residual} = \nabla_{\boldsymbol{x}} f (\boldsymbol{x}, \boldsymbol{\theta}) \to \boldsymbol{0} where :math:`\boldsymbol{x}` is the joint vector of the module parameters and :math:`\boldsymbol{\theta}` is the joint vector of the meta-parameters. @@ -254,27 +254,27 @@ def residual(self, *input, **kwargs) -> TensorTree: References: - Karush-Kuhn-Tucker (KKT) conditions: https://en.wikipedia.org/wiki/Karush-Kuhn-Tucker_conditions - 2. For fixed point iteration, the :meth:`residual` function can be the residual of the + 2. For fixed point iteration, the :meth:`optimality` function can be the residual of the parameters between iterations, i.e.: .. math:: - \text{residual} = f (\boldsymbol{x}, \boldsymbol{\theta}) - \boldsymbol{x} \to \boldsymbol{0} + \text{optimality residual} = f (\boldsymbol{x}, \boldsymbol{\theta}) - \boldsymbol{x} \to \boldsymbol{0} where :math:`\boldsymbol{x}` is the joint vector of the module parameters and :math:`\boldsymbol{\theta}` is the joint vector of the meta-parameters. Returns: - A tree of tensors, the residual to the optimal parameters after solving the inner - optimization problem. - """ + A tree of tensors, the optimality residual to the optimal parameters after solving the + inner optimization problem. + """ # pylint: disable=line-too-long raise NotImplementedError # pylint: disable-next=redefined-builtin def objective(self, *input, **kwargs) -> torch.Tensor: """Computes the objective function value. This method is used to calculate the :meth:`optimality` if it is not implemented. Otherwise, this method is optional. Returns:
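
The decorator patch above implements the vector-Jacobian product of a root solver via the implicit function theorem. The following is a minimal, self-contained sketch of that underlying idea using only stock PyTorch; the cubic optimality function, the Newton solver, and all names below are made up for illustration and are not part of the TorchOpt API changed in these patches.

import torch


def optimality(x, theta):
    # F(x, theta) = x**3 + x - theta; the solution x*(theta) satisfies F(x*, theta) = 0.
    return x**3 + x - theta


def solve(theta, iters=50):
    # Inner solver (plain Newton iterations); gradients are not tracked through the loop.
    x = torch.zeros_like(theta)
    for _ in range(iters):
        x = x - optimality(x, theta) / (3 * x**2 + 1)
    return x


theta = torch.tensor(2.0)
x_star = solve(theta)  # -> 1.0, since 1**3 + 1 == 2

# Implicit function theorem: dx*/dtheta = -(dF/dx)^{-1} * dF/dtheta.
# Here dF/dx = 3 * x**2 + 1 and dF/dtheta = -1, so the gradient is 1 / (3 * x**2 + 1).
implicit_grad = 1.0 / (3 * x_star**2 + 1)

# Cross-check against a finite-difference estimate of the solver itself.
eps = 1e-4
finite_diff = (solve(theta + eps) - solve(theta - eps)) / (2 * eps)
print(implicit_grad.item(), finite_diff.item())  # both are approximately 0.25

In `custom_root`/`_root_vjp`, the same relation is applied in vector-Jacobian form: the backward pass solves A^T u = -v with A = jacobian(optimality_fn, argnums=0) via the configured linear solver, instead of forming the inverse explicitly as in this scalar sketch.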