From 6e1fa56f6307fb1435de2e014b4794aabf749a1b Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 3 Mar 2023 10:20:04 +0000 Subject: [PATCH 01/10] fix: fix transpose empty iterable with `zip(*nested)` --- torchopt/diff/implicit/nn/module.py | 20 +++++-- torchopt/diff/zero_order/nn/module.py | 5 +- torchopt/transform/scale_by_adam.py | 75 ++++++++++++++++++++++++--- 3 files changed, 87 insertions(+), 13 deletions(-) diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py index 2aceb656..b66be992 100644 --- a/torchopt/diff/implicit/nn/module.py +++ b/torchopt/diff/implicit/nn/module.py @@ -84,8 +84,14 @@ def make_optimality_from_objective( raise TypeError('The objective function is not defined.') def optimality(self: ImplicitMetaGradientModule, *input: Any, **kwargs: Any) -> TupleOfTensors: - params_names, flat_params = tuple(zip(*self.named_parameters())) - meta_params_names, flat_meta_params = tuple(zip(*self.named_meta_parameters())) + named_params = tuple(self.named_parameters()) + named_meta_params = tuple(self.named_meta_parameters()) + if len(named_params) == 0: + raise RuntimeError('The module has no parameters.') + if len(named_meta_params) == 0: + raise RuntimeError('The module has no meta-parameters.') + params_names, flat_params = tuple(zip(*named_params)) + meta_params_names, flat_meta_params = tuple(zip(*named_meta_params)) objective_grad_fn = functorch.grad(_stateless_objective_fn, argnums=0) return objective_grad_fn( @@ -132,8 +138,14 @@ def stateless_solver_fn( @functools.wraps(cls_solve) def wrapped(self: ImplicitMetaGradientModule, *input: Any, **kwargs: Any) -> Any: """Solve the optimization problem.""" - params_names, flat_params = tuple(zip(*self.named_parameters())) - meta_params_names, flat_meta_params = tuple(zip(*self.named_meta_parameters())) + named_params = tuple(self.named_parameters()) + named_meta_params = tuple(self.named_meta_parameters()) + if len(named_params) == 0: + raise RuntimeError('The module has no parameters.') + if len(named_meta_params) == 0: + raise RuntimeError('The module has no meta-parameters.') + params_names, flat_params = tuple(zip(*named_params)) + meta_params_names, flat_meta_params = tuple(zip(*named_meta_params)) flat_optimal_params, output = stateless_solver_fn( flat_params, diff --git a/torchopt/diff/zero_order/nn/module.py b/torchopt/diff/zero_order/nn/module.py index aa75890c..b1d7f6fd 100644 --- a/torchopt/diff/zero_order/nn/module.py +++ b/torchopt/diff/zero_order/nn/module.py @@ -49,7 +49,10 @@ def enable_zero_order_gradients( @functools.wraps(cls_forward) def wrapped(self: ZeroOrderGradientModule, *input: Any, **kwargs: Any) -> torch.Tensor: """Do the forward pass calculation.""" - params_names, flat_params = tuple(zip(*self.named_parameters())) + named_params = tuple(self.named_parameters()) + if len(named_params) == 0: + raise RuntimeError('The module has no parameters.') + params_names, flat_params = tuple(zip(*named_params)) @zero_order(self.sample, argnums=0, method=method, num_samples=num_samples, sigma=sigma) def forward_fn( diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py index c1f6274a..9b7159ae 100644 --- a/torchopt/transform/scale_by_adam.py +++ b/torchopt/transform/scale_by_adam.py @@ -301,9 +301,40 @@ def update_fn( count_inc = inc_count.impl(updates, state.count, already_flattened=True) # type: ignore[attr-defined] op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace) - out = tree_map_flat(op, state.mu, state.nu, updates, 
count_inc) - new_mu, new_nu, new_updates = tuple(zip(*out)) # transpose + def op_fn( + mu: torch.Tensor | None, + nu: torch.Tensor | None, + update: torch.Tensor | None, + count: torch.Tensor | None, + ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: + if mu is None: + return (None, None, None) + return op(mu, nu, update, count) # type: ignore[arg-type] + + out = tree_map_flat( + op_fn, + state.mu, + state.nu, + updates, + count_inc, + none_is_leaf=True, + ) + + if len(out) == 0: + new_mu, new_nu, new_updates = (), (), () + else: + new_mu, new_nu, new_updates = tuple(zip(*out)) # transpose + + new_mu, new_nu, new_updates = ( + new if type(new) is type(old) else type(old)(new) + for new, old in ( + (new_mu, state.mu), + (new_nu, state.nu), + (new_updates, updates), + ) + ) + return new_updates, ScaleByAdamState(mu=new_mu, nu=new_nu, count=count_inc) else: @@ -318,15 +349,43 @@ def update_fn( ) -> tuple[Updates, OptState]: count_inc = inc_count.impl(updates, state.count, already_flattened=False) # type: ignore[attr-defined] - treespec = pytree.tree_structure(updates, none_is_leaf=True) - - op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace) - out = pytree.tree_map(op, state.mu, state.nu, updates, count_inc) - new_mu: Updates new_nu: Updates new_updates: Updates - new_mu, new_nu, new_updates = pytree.tree_transpose(treespec, TRIPLE_PYTREE_SPEC, out) # type: ignore[misc] + + treespec = pytree.tree_structure(updates, none_is_leaf=True) + if treespec.num_leaves > 0: + op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace) + + def op_fn( + mu: torch.Tensor | None, + nu: torch.Tensor | None, + update: torch.Tensor | None, + count: torch.Tensor | None, + ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: + if mu is None: + return (None, None, None) + return op(mu, nu, update, count) # type: ignore[arg-type] + + out = pytree.tree_map( + op_fn, + state.mu, + state.nu, + updates, + count_inc, + none_is_leaf=True, + ) + + new_mu, new_nu, new_updates = pytree.tree_transpose( # type: ignore[misc] + treespec, + TRIPLE_PYTREE_SPEC, + out, + ) + else: + new_mu = pytree.tree_unflatten(treespec, ()) + new_nu = pytree.tree_unflatten(treespec, ()) + new_updates = pytree.tree_unflatten(treespec, ()) + return new_updates, ScaleByAdamState(mu=new_mu, nu=new_nu, count=count_inc) def init_fn(params: Params) -> OptState: From fc02c540e398b1c6e5366e312c6282ef09a0eb89 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 3 Mar 2023 14:08:13 +0000 Subject: [PATCH 02/10] test: add test for empty parameters --- tests/test_alias.py | 49 ++++++++++++++++++++++++++++++ tests/test_implicit.py | 38 +++++++++++++++++++++++ tests/test_nn.py | 69 +++++++++++++++++++++++++++++++++++++++++- 3 files changed, 155 insertions(+), 1 deletion(-) diff --git a/tests/test_alias.py b/tests/test_alias.py index b609cf58..c5cb5e90 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -24,7 +24,56 @@ import helpers import torchopt +from torchopt import pytree from torchopt.alias.utils import _set_use_chain_flat +from torchopt.typing import TensorTree + + +@helpers.parametrize( + optimizer=[ + torchopt.sgd, + torchopt.adam, + torchopt.adamw, + torchopt.rmsprop, + ], + tensortree=[ + {}, + (), + [], + (None,), + {'a': (), 'b': {'c': []}, 'd': None}, + ], + maximize=[False, True], + inplace=[True, False], + use_chain_flat=[True, False], +) +def test_empty( + optimizer: Callable, + tensortree: TensorTree, + maximize: bool, + inplace: bool, + use_chain_flat: 
bool, +) -> None: + _set_use_chain_flat(use_chain_flat) + + params = pytree.tree_map(lambda x: x, tensortree) + grads = pytree.tree_map(lambda x: x, tensortree) + + optim = optimizer(1e-3, maximize=maximize) + optim_state = optim.init(params) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + _ = torchopt.apply_updates(params, updates) + + try: + optim = optimizer(1e-3, maximize=maximize, use_accelerated_op=True) + except TypeError: + pass + else: + optim_state = optim.init(params) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + _ = torchopt.apply_updates(params, updates) + + _set_use_chain_flat(True) @helpers.parametrize( diff --git a/tests/test_implicit.py b/tests/test_implicit.py index 8672c588..fa3b6f86 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -690,3 +690,41 @@ def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq): l2reg_jax_as_tensor = torch.tensor(np.asarray(l2reg_jax), dtype=dtype) helpers.assert_all_close(l2reg_torch, l2reg_jax_as_tensor) + + +def test_module_empty_parameters() -> None: + class EmptyParameters(ImplicitMetaGradientModule): + def __init__(self, x): + super().__init__() + self.x = x + + def optimality(self): + return (self.x,) + + def solve(self): + pass + + model = EmptyParameters(torch.zeros(8)) + with pytest.raises(RuntimeError, match='The module has no parameters.'): + model.solve() + + model = EmptyParameters(torch.zeros(8)) + model.register_parameter('y', torch.zeros(8, requires_grad=True)) + with pytest.raises(RuntimeError, match='The module has no meta-parameters.'): + model.solve() + + model = EmptyParameters(torch.zeros(8, requires_grad=True)) + with pytest.raises(RuntimeError, match='The module has no parameters.'): + model.solve() + + model = EmptyParameters(torch.zeros(8, requires_grad=True)) + model.register_parameter('y', torch.zeros(8, requires_grad=True)) + model.solve() + + model = EmptyParameters(nn.Linear(8, 8).eval()) + with pytest.raises(RuntimeError, match='The module has no meta-parameters.'): + model.solve() + + model = EmptyParameters(nn.Linear(8, 8)) + model.register_parameter('y', torch.zeros(8, requires_grad=True)) + model.solve() diff --git a/tests/test_nn.py b/tests/test_nn.py index 1b48c06b..1e524ba5 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -69,7 +69,13 @@ def test_register_tensors() -> None: assert m._meta_parameters['x'] is x assert m._parameters['y'] is y - assert hasattr(m, 'z') and m.z is z and 'z' not in m._buffers + assert ( + hasattr(m, 'z') + and m.z is z + and 'z' not in m._meta_parameters + and 'z' not in m._parameters + and 'z' not in m._buffers + ) del m.x object.__setattr__(m, 'x', x) @@ -82,6 +88,67 @@ def test_register_tensors() -> None: m.b = b assert m.b is b and 'b' in m._buffers + m = torchopt.nn.MetaGradientModule(x, b) + + with pytest.raises( + TypeError, + match=re.escape('parameter name should be a string. 
Got bytes'), + ): + m.register_meta_parameter(b'x', x) + + with pytest.raises( + KeyError, + match=re.escape("parameter name can't contain '.'"), + ): + m.register_meta_parameter('x.x', x) + + with pytest.raises( + KeyError, + match=re.escape("parameter name can't be empty string ''"), + ): + m.register_meta_parameter('', x) + + m.register_buffer('z', None) + with pytest.raises( + KeyError, + match=re.escape("attribute 'z' already exists"), + ): + m.register_meta_parameter('z', x) + + with pytest.raises( + ValueError, + match=re.escape( + "cannot assign Tensor that is a meta-parameter to parameter 'x'. " + 'Use self.register_meta_parameter() instead.' + ), + ): + m.register_parameter('x', x) + + m.x = x + with pytest.raises( + KeyError, + match=re.escape("attribute 'x' already exists"), + ): + m.register_parameter('x', x) + + with pytest.raises( + TypeError, + match=re.escape('parameter name should be a string. Got bytes'), + ): + m.register_parameter(b'y', y) + + with pytest.raises( + KeyError, + match=re.escape("parameter name can't contain '.'"), + ): + m.register_parameter('y.x', y) + + with pytest.raises( + KeyError, + match=re.escape("parameter name can't be empty string ''"), + ): + m.register_parameter('', y) + def test_no_super_init() -> None: class NoSuper1(torchopt.nn.MetaGradientModule): From 95af6b1e079630cdae8ce14fe19a95d4afcffa89 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 3 Mar 2023 14:09:11 +0000 Subject: [PATCH 03/10] style: use postional-only arguments --- torchopt/diff/implicit/nn/module.py | 35 ++++++++++++++++------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py index b66be992..86408a54 100644 --- a/torchopt/diff/implicit/nn/module.py +++ b/torchopt/diff/implicit/nn/module.py @@ -36,38 +36,40 @@ def _stateless_objective_fn( - __flat_params: TupleOfTensors, - __flat_meta_params: TupleOfTensors, - __params_names: Iterable[str], - __meta_params_names: Iterable[str], + flat_params: TupleOfTensors, + flat_meta_params: TupleOfTensors, + params_names: Iterable[str], + meta_params_names: Iterable[str], self: ImplicitMetaGradientModule, + /, *input: Any, **kwargs: Any, ) -> torch.Tensor: with reparametrize( self, itertools.chain( - zip(__params_names, __flat_params), - zip(__meta_params_names, __flat_meta_params), + zip(params_names, flat_params), + zip(meta_params_names, flat_meta_params), ), ): return self.objective(*input, **kwargs) def _stateless_optimality_fn( - __flat_params: TupleOfTensors, - __flat_meta_params: TupleOfTensors, - __params_names: Iterable[str], - __meta_params_names: Iterable[str], + flat_params: TupleOfTensors, + flat_meta_params: TupleOfTensors, + params_names: Iterable[str], + meta_params_names: Iterable[str], self: ImplicitMetaGradientModule, + /, *input: Any, **kwargs: Any, ) -> TupleOfTensors: with reparametrize( self, itertools.chain( - zip(__params_names, __flat_params), - zip(__meta_params_names, __flat_meta_params), + zip(params_names, flat_params), + zip(meta_params_names, flat_meta_params), ), ): return self.optimality(*input, **kwargs) @@ -121,12 +123,13 @@ def enable_implicit_gradients( @custom_root(_stateless_optimality_fn, argnums=1, has_aux=True, **solve_kwargs) def stateless_solver_fn( # pylint: disable=unused-argument - __flat_params: TupleOfTensors, - __flat_meta_params: TupleOfTensors, - __params_names: Iterable[str], - __meta_params_names: Iterable[str], + flat_params: TupleOfTensors, + flat_meta_params: TupleOfTensors, + 
params_names: Iterable[str], + meta_params_names: Iterable[str], # pylint: enable=unused-argument self: ImplicitMetaGradientModule, + /, *input: Any, **kwargs: Any, ) -> tuple[TupleOfTensors, Any]: From e0c41cfe975d879d1b0434279961b673ab085266 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 3 Mar 2023 14:17:56 +0000 Subject: [PATCH 04/10] docs(CHANGELOG): update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a884496..ea37106e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,7 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Fix transpose empty iterable with `zip(*nested)` in transformations by [@XuehaiPan](https://github.com/XuehaiPan) in [#145](https://github.com/metaopt/torchopt/pull/145). ### Removed From fd1642c478ef906581f3d1a418bd2a2bb55ddb81 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 3 Mar 2023 15:35:35 +0000 Subject: [PATCH 05/10] test: add test for empty parameters --- tests/test_implicit.py | 130 +++++++++++++++++++++++++++- tests/test_zero_order.py | 53 ++++++++++++ torchopt/diff/implicit/nn/module.py | 32 +++---- 3 files changed, 198 insertions(+), 17 deletions(-) diff --git a/tests/test_implicit.py b/tests/test_implicit.py index fa3b6f86..5ed08561 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -16,6 +16,7 @@ from __future__ import annotations import copy +import re from collections import OrderedDict from types import FunctionType @@ -698,8 +699,8 @@ def __init__(self, x): super().__init__() self.x = x - def optimality(self): - return (self.x,) + def objective(self): + return self.x def solve(self): pass @@ -717,6 +718,10 @@ def solve(self): with pytest.raises(RuntimeError, match='The module has no parameters.'): model.solve() + model = EmptyParameters(torch.zeros(8, requires_grad=True)) + with pytest.raises(RuntimeError, match='The module has no parameters.'): + model.optimality() + model = EmptyParameters(torch.zeros(8, requires_grad=True)) model.register_parameter('y', torch.zeros(8, requires_grad=True)) model.solve() @@ -728,3 +733,124 @@ def solve(self): model = EmptyParameters(nn.Linear(8, 8)) model.register_parameter('y', torch.zeros(8, requires_grad=True)) model.solve() + + +def test_module_enable_implicit_gradients_twice() -> None: + class MyModule(torchopt.nn.ImplicitMetaGradientModule): + def objective(self): + return torch.tensor(0.0) + + def solve(self): + pass + + from torchopt.diff.implicit.nn.module import enable_implicit_gradients + + with pytest.raises( + TypeError, + match='Implicit gradients are already enabled for the `solve` method.', + ): + enable_implicit_gradients(MyModule) + + +def test_module_abstract_methods() -> None: + class MyModule1(torchopt.nn.ImplicitMetaGradientModule): + def objective(self): + return torch.tensor(0.0) + + with pytest.raises( + TypeError, + match="Can't instantiate abstract class", + ): + MyModule1() + + with pytest.raises( + TypeError, + match=re.escape( + 'ImplicitMetaGradientModule requires either an optimality() method or an objective() method' + ), + ): + + class MyModule2(torchopt.nn.ImplicitMetaGradientModule): + def solve(self): + pass + + class MyModule3(torchopt.nn.ImplicitMetaGradientModule): + def optimality(self): + return () + + def solve(self): + pass + + with pytest.raises( + TypeError, + match=re.escape('method optimality() must not be a staticmethod.'), + ): + + class MyModule4(torchopt.nn.ImplicitMetaGradientModule): + 
@staticmethod + def optimality(): + return () + + def solve(self): + pass + + with pytest.raises( + TypeError, + match=re.escape('method optimality() must not be a classmethod.'), + ): + + class MyModule5(torchopt.nn.ImplicitMetaGradientModule): + @classmethod + def optimality(self): + return () + + def solve(self): + pass + + with pytest.raises( + TypeError, + match=re.escape('method optimality() must be callable.'), + ): + + class MyModule6(torchopt.nn.ImplicitMetaGradientModule): + optimality = 0 + + def solve(self): + pass + + with pytest.raises( + TypeError, + match=re.escape('method optimality() must not be a staticmethod.'), + ): + + class MyModule7(torchopt.nn.ImplicitMetaGradientModule): + @staticmethod + def optimality(): + return () + + def solve(self): + pass + + with pytest.raises( + TypeError, + match=re.escape('method objective() must not be a classmethod.'), + ): + + class MyModule8(torchopt.nn.ImplicitMetaGradientModule): + @classmethod + def objective(self): + return () + + def solve(self): + pass + + with pytest.raises( + TypeError, + match=re.escape('method objective() must be callable.'), + ): + + class MyModule9(torchopt.nn.ImplicitMetaGradientModule): + objective = 0 + + def solve(self): + pass diff --git a/tests/test_zero_order.py b/tests/test_zero_order.py index ac7ae840..3a87d6bc 100644 --- a/tests/test_zero_order.py +++ b/tests/test_zero_order.py @@ -14,6 +14,7 @@ # ============================================================================== import functorch +import pytest import torch import torch.nn as nn import torch.nn.functional as F @@ -117,3 +118,55 @@ def sample(self, sample_shape=torch.Size()): optimizer.zero_grad() loss.backward() # compute gradients optimizer.step() # update network parameters + + +def test_module_enable_zero_order_gradients_twice() -> None: + class MyModule(torchopt.nn.ZeroOrderGradientModule): + def forward(self): + return torch.tensor(0.0) + + def sample(self, sample_shape): + return torch.tensor(0.0) + + from torchopt.diff.zero_order.nn.module import enable_zero_order_gradients + + with pytest.raises( + TypeError, + match='Zero-order gradient estimation is already enabled for the `forward` method.', + ): + enable_zero_order_gradients(MyModule) + + +def test_module_empty_parameters() -> None: + class MyModule(torchopt.nn.ZeroOrderGradientModule): + def forward(self): + return torch.tensor(0.0) + + def sample(self, sample_shape): + return torch.tensor(0.0) + + m = MyModule() + with pytest.raises(RuntimeError, match='The module has no parameters.'): + m() + + +def test_module_abstract_methods() -> None: + class MyModule1(torchopt.nn.ZeroOrderGradientModule): + def forward(self): + return torch.tensor(0.0) + + with pytest.raises( + TypeError, + match="Can't instantiate abstract class", + ): + MyModule1() + + class MyModule2(torchopt.nn.ZeroOrderGradientModule): + def sample(self, sample_shape): + return torch.tensor(0.0) + + with pytest.raises( + TypeError, + match="Can't instantiate abstract class", + ): + MyModule2() diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py index 86408a54..9d6eb35b 100644 --- a/torchopt/diff/implicit/nn/module.py +++ b/torchopt/diff/implicit/nn/module.py @@ -20,6 +20,7 @@ import abc import functools +import inspect import itertools from typing import Any, Iterable @@ -79,10 +80,9 @@ def make_optimality_from_objective( cls: type[ImplicitMetaGradientModule], ) -> type[ImplicitMetaGradientModule]: """Derives the optimality function of the objective function.""" - if ( - 
getattr(cls, 'objective', ImplicitMetaGradientModule.objective) - is ImplicitMetaGradientModule.objective - ): + static_super_objective = inspect.getattr_static(ImplicitMetaGradientModule, 'objective') + static_cls_optimality = inspect.getattr_static(cls, 'optimality', static_super_objective) + if static_cls_optimality is static_super_objective: raise TypeError('The objective function is not defined.') def optimality(self: ImplicitMetaGradientModule, *input: Any, **kwargs: Any) -> TupleOfTensors: @@ -167,7 +167,7 @@ def wrapped(self: ImplicitMetaGradientModule, *input: Any, **kwargs: Any) -> Any return cls -class ImplicitMetaGradientModule(MetaGradientModule): +class ImplicitMetaGradientModule(MetaGradientModule, metaclass=abc.ABCMeta): """The base class for differentiable implicit meta-gradient models.""" _custom_optimality: bool @@ -179,28 +179,30 @@ def __init_subclass__(cls, linear_solve: LinearSolver | None = None) -> None: super().__init_subclass__() cls.linear_solve = linear_solve - optimality = getattr(cls, 'optimality', ImplicitMetaGradientModule.optimality) - objective = getattr(cls, 'objective', ImplicitMetaGradientModule.objective) - cls._custom_optimality = optimality is not ImplicitMetaGradientModule.optimality - cls._custom_objective = objective is not ImplicitMetaGradientModule.objective + static_super_optimality = inspect.getattr_static(ImplicitMetaGradientModule, 'optimality') + static_super_objective = inspect.getattr_static(ImplicitMetaGradientModule, 'objective') + static_cls_optimality = inspect.getattr_static(cls, 'optimality') + static_cls_objective = inspect.getattr_static(cls, 'objective') + cls._custom_optimality = static_cls_optimality is not static_super_optimality + cls._custom_objective = static_cls_objective is not static_super_objective if cls._custom_optimality: - if isinstance(optimality, staticmethod): + if isinstance(static_cls_optimality, staticmethod): raise TypeError('method optimality() must not be a staticmethod.') - if isinstance(optimality, classmethod): + if isinstance(static_cls_optimality, classmethod): raise TypeError('method optimality() must not be a classmethod.') - if not callable(optimality): + if not callable(static_cls_optimality): raise TypeError('method optimality() must be callable.') elif not cls._custom_objective: raise TypeError( 'ImplicitMetaGradientModule requires either an optimality() method or an objective() method' ) else: - if isinstance(objective, staticmethod): + if isinstance(static_cls_objective, staticmethod): raise TypeError('method objective() must not be a staticmethod.') - if isinstance(objective, classmethod): + if isinstance(static_cls_objective, classmethod): raise TypeError('method objective() must not be a classmethod.') - if not callable(objective): + if not callable(static_cls_objective): raise TypeError('method objective() must be callable.') make_optimality_from_objective(cls) From f3d86ee15ab807f10a6400497837bb2488573070 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 4 Mar 2023 01:34:27 +0800 Subject: [PATCH 06/10] test: add more tests --- tests/.coveragerc | 9 +++++++++ tests/test_implicit.py | 16 +++++++++------- tests/test_zero_order.py | 10 ++-------- torchopt/base.py | 2 +- torchopt/typing.py | 2 +- torchopt/utils.py | 2 +- 6 files changed, 23 insertions(+), 18 deletions(-) diff --git a/tests/.coveragerc b/tests/.coveragerc index 462c4c3a..4238e71d 100644 --- a/tests/.coveragerc +++ b/tests/.coveragerc @@ -6,3 +6,12 @@ omit = ../docs/* ../examples/* ../tutorials/* + +[report] +exclude_lines = 
+ pragma: no cover + raise NotImplementedError + class .*\bProtocol\): + @(abc\.)?abstractmethod + if __name__ == ('__main__'|"__main__"): + if TYPE_CHECKING: diff --git a/tests/test_implicit.py b/tests/test_implicit.py index 5ed08561..0cf00136 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -700,7 +700,7 @@ def __init__(self, x): self.x = x def objective(self): - return self.x + return self.x.mean() def solve(self): pass @@ -722,6 +722,11 @@ def solve(self): with pytest.raises(RuntimeError, match='The module has no parameters.'): model.optimality() + model = EmptyParameters(torch.zeros(8)) + model.register_parameter('y', torch.zeros(8, requires_grad=True)) + with pytest.raises(RuntimeError, match='The module has no meta-parameters.'): + model.optimality() + model = EmptyParameters(torch.zeros(8, requires_grad=True)) model.register_parameter('y', torch.zeros(8, requires_grad=True)) model.solve() @@ -757,10 +762,7 @@ class MyModule1(torchopt.nn.ImplicitMetaGradientModule): def objective(self): return torch.tensor(0.0) - with pytest.raises( - TypeError, - match="Can't instantiate abstract class", - ): + with pytest.raises(TypeError, match="Can't instantiate abstract class"): MyModule1() with pytest.raises( @@ -820,12 +822,12 @@ def solve(self): with pytest.raises( TypeError, - match=re.escape('method optimality() must not be a staticmethod.'), + match=re.escape('method objective(() must not be a staticmethod.'), ): class MyModule7(torchopt.nn.ImplicitMetaGradientModule): @staticmethod - def optimality(): + def objective(): return () def solve(self): diff --git a/tests/test_zero_order.py b/tests/test_zero_order.py index 3a87d6bc..5455f3af 100644 --- a/tests/test_zero_order.py +++ b/tests/test_zero_order.py @@ -155,18 +155,12 @@ class MyModule1(torchopt.nn.ZeroOrderGradientModule): def forward(self): return torch.tensor(0.0) - with pytest.raises( - TypeError, - match="Can't instantiate abstract class", - ): + with pytest.raises(TypeError, match="Can't instantiate abstract class"): MyModule1() class MyModule2(torchopt.nn.ZeroOrderGradientModule): def sample(self, sample_shape): return torch.tensor(0.0) - with pytest.raises( - TypeError, - match="Can't instantiate abstract class", - ): + with pytest.raises(TypeError, match="Can't instantiate abstract class"): MyModule2() diff --git a/torchopt/base.py b/torchopt/base.py index 7678d543..b5f35a1b 100644 --- a/torchopt/base.py +++ b/torchopt/base.py @@ -38,7 +38,7 @@ from typing import TYPE_CHECKING, Callable, NamedTuple, Protocol -if TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: from torchopt.typing import OptState, Params, Updates diff --git a/torchopt/typing.py b/torchopt/typing.py index 6cd6cf67..510cb693 100644 --- a/torchopt/typing.py +++ b/torchopt/typing.py @@ -141,7 +141,7 @@ def sample( ) -> Union[Tensor, Sequence[Numeric]]: # pylint: disable-next=line-too-long """Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.""" - raise NotImplementedError # pragma: no cover + raise NotImplementedError Samplable.register(Distribution) diff --git a/torchopt/utils.py b/torchopt/utils.py index b56231c6..0d227049 100644 --- a/torchopt/utils.py +++ b/torchopt/utils.py @@ -28,7 +28,7 @@ from torchopt.typing import Device, ModuleTensorContainers, OptState, TensorContainer, TensorTree -if TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: from torchopt.optim.meta.base import MetaOptimizer From f450638765e216c24d07a862db771e3f41122e20 Mon Sep 17 00:00:00 2001 
From: Xuehai Pan Date: Sat, 4 Mar 2023 01:44:01 +0800 Subject: [PATCH 07/10] test: add more tests --- tests/test_implicit.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/tests/test_implicit.py b/tests/test_implicit.py index 0cf00136..5ab45995 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -741,20 +741,36 @@ def solve(self): def test_module_enable_implicit_gradients_twice() -> None: - class MyModule(torchopt.nn.ImplicitMetaGradientModule): + class MyModule1(torchopt.nn.ImplicitMetaGradientModule): def objective(self): return torch.tensor(0.0) def solve(self): pass - from torchopt.diff.implicit.nn.module import enable_implicit_gradients + from torchopt.diff.implicit.nn.module import ( + enable_implicit_gradients, + make_optimality_from_objective, + ) with pytest.raises( TypeError, match='Implicit gradients are already enabled for the `solve` method.', ): - enable_implicit_gradients(MyModule) + enable_implicit_gradients(MyModule1) + + class MyModule2(torchopt.nn.ImplicitMetaGradientModule): + def optimality(self): + return torch.tensor(0.0) + + def solve(self): + pass + + with pytest.raises( + TypeError, + match='The objective function is not defined.', + ): + make_optimality_from_objective(MyModule2) def test_module_abstract_methods() -> None: From 5ae44b5ea6eb365b978c94333a29003b4bff5545 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 4 Mar 2023 01:52:38 +0800 Subject: [PATCH 08/10] chore: update codecov.yaml --- codecov.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/codecov.yml b/codecov.yml index 65b70e6e..e1d3aab2 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,9 +1,12 @@ coverage: + precision: 2 round: nearest status: project: default: + target: auto threshold: 0.05% patch: default: + target: 100% informational: true From c014379edf0978c701bbee04a97d395c01e24e80 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 4 Mar 2023 16:52:57 +0800 Subject: [PATCH 09/10] test: fix tests --- tests/test_implicit.py | 2 +- torchopt/diff/implicit/nn/module.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_implicit.py b/tests/test_implicit.py index 5ab45995..6002eb31 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -838,7 +838,7 @@ def solve(self): with pytest.raises( TypeError, - match=re.escape('method objective(() must not be a staticmethod.'), + match=re.escape('method objective() must not be a staticmethod.'), ): class MyModule7(torchopt.nn.ImplicitMetaGradientModule): diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py index 9d6eb35b..ab2705fe 100644 --- a/torchopt/diff/implicit/nn/module.py +++ b/torchopt/diff/implicit/nn/module.py @@ -81,8 +81,8 @@ def make_optimality_from_objective( ) -> type[ImplicitMetaGradientModule]: """Derives the optimality function of the objective function.""" static_super_objective = inspect.getattr_static(ImplicitMetaGradientModule, 'objective') - static_cls_optimality = inspect.getattr_static(cls, 'optimality', static_super_objective) - if static_cls_optimality is static_super_objective: + static_cls_objective = inspect.getattr_static(cls, 'objective', static_super_objective) + if static_cls_objective is static_super_objective: raise TypeError('The objective function is not defined.') def optimality(self: ImplicitMetaGradientModule, *input: Any, **kwargs: Any) -> TupleOfTensors: From de6b176ae0689fbd7b186c728cbb765edb0002e5 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 7 Mar 2023 10:32:50 
+0000 Subject: [PATCH 10/10] lint: appease linters

---
 torchopt/diff/implicit/nn/module.py   | 4 ++--
 torchopt/diff/zero_order/nn/module.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py
index ab2705fe..cd7ba126 100644
--- a/torchopt/diff/implicit/nn/module.py
+++ b/torchopt/diff/implicit/nn/module.py
@@ -106,7 +106,7 @@ def optimality(self: ImplicitMetaGradientModule, *input: Any, **kwargs: Any) ->
         **kwargs,
     )
 
-    cls.optimality = optimality  # type: ignore[assignment]
+    cls.optimality = optimality  # type: ignore[method-assign]
     return cls
 
 
@@ -163,7 +163,7 @@ def wrapped(self: ImplicitMetaGradientModule, *input: Any, **kwargs: Any) -> Any
         return output
 
     wrapped.__implicit_gradients_enabled__ = True  # type: ignore[attr-defined]
-    cls.solve = wrapped  # type: ignore[assignment]
+    cls.solve = wrapped  # type: ignore[method-assign]
     return cls
 
 
diff --git a/torchopt/diff/zero_order/nn/module.py b/torchopt/diff/zero_order/nn/module.py
index b1d7f6fd..6e031300 100644
--- a/torchopt/diff/zero_order/nn/module.py
+++ b/torchopt/diff/zero_order/nn/module.py
@@ -66,7 +66,7 @@ def forward_fn(
         return forward_fn(flat_params, *input, **kwargs)
 
     wrapped.__zero_order_gradients_enabled__ = True  # type: ignore[attr-defined]
-    cls.forward = wrapped  # type: ignore[assignment]
+    cls.forward = wrapped  # type: ignore[method-assign]
     return cls
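
A minimal standalone sketch of the `zip(*nested)` pitfall that the first patch in this series ([PATCH 01/10]) guards against (illustrative only; `transpose_named` and the sample pairs below are not part of the torchopt code base). Transposing an empty list of `(name, tensor)` pairs yields a single empty tuple, so unpacking it into separate name/tensor tuples raises `ValueError: not enough values to unpack` -- hence the explicit emptiness checks added before each transpose:

    # Plain Python, no dependencies; the helper name and sample data are hypothetical.
    def transpose_named(named_params):
        # Mirror of the guard added in [PATCH 01/10]: fail loudly on an empty
        # iterable instead of hitting the unpack error at the transpose below.
        named_params = tuple(named_params)
        if len(named_params) == 0:
            raise RuntimeError('The module has no parameters.')
        names, tensors = tuple(zip(*named_params))  # transpose [(name, tensor), ...]
        return names, tensors

    print(tuple(zip(*[])))                                     # () -- the empty transpose
    print(transpose_named([('weight', 1.0), ('bias', 2.0)]))   # (('weight', 'bias'), (1.0, 2.0))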
