diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a884496..ea37106e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,7 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Fix transpose empty iterable with `zip(*nested)` in transformations by [@XuehaiPan](https://github.com/XuehaiPan) in [#145](https://github.com/metaopt/torchopt/pull/145). ### Removed diff --git a/codecov.yml b/codecov.yml index 65b70e6e..e1d3aab2 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,9 +1,12 @@ coverage: + precision: 2 round: nearest status: project: default: + target: auto threshold: 0.05% patch: default: + target: 100% informational: true diff --git a/tests/.coveragerc b/tests/.coveragerc index 462c4c3a..4238e71d 100644 --- a/tests/.coveragerc +++ b/tests/.coveragerc @@ -6,3 +6,12 @@ omit = ../docs/* ../examples/* ../tutorials/* + +[report] +exclude_lines = + pragma: no cover + raise NotImplementedError + class .*\bProtocol\): + @(abc\.)?abstractmethod + if __name__ == ('__main__'|"__main__"): + if TYPE_CHECKING: diff --git a/tests/test_alias.py b/tests/test_alias.py index b609cf58..c5cb5e90 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -24,7 +24,56 @@ import helpers import torchopt +from torchopt import pytree from torchopt.alias.utils import _set_use_chain_flat +from torchopt.typing import TensorTree + + +@helpers.parametrize( + optimizer=[ + torchopt.sgd, + torchopt.adam, + torchopt.adamw, + torchopt.rmsprop, + ], + tensortree=[ + {}, + (), + [], + (None,), + {'a': (), 'b': {'c': []}, 'd': None}, + ], + maximize=[False, True], + inplace=[True, False], + use_chain_flat=[True, False], +) +def test_empty( + optimizer: Callable, + tensortree: TensorTree, + maximize: bool, + inplace: bool, + use_chain_flat: bool, +) -> None: + _set_use_chain_flat(use_chain_flat) + + params = pytree.tree_map(lambda x: x, tensortree) + grads = pytree.tree_map(lambda x: x, tensortree) + + optim = optimizer(1e-3, maximize=maximize) + optim_state = optim.init(params) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + _ = torchopt.apply_updates(params, updates) + + try: + optim = optimizer(1e-3, maximize=maximize, use_accelerated_op=True) + except TypeError: + pass + else: + optim_state = optim.init(params) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + _ = torchopt.apply_updates(params, updates) + + _set_use_chain_flat(True) @helpers.parametrize( diff --git a/tests/test_implicit.py b/tests/test_implicit.py index 8672c588..6002eb31 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -16,6 +16,7 @@ from __future__ import annotations import copy +import re from collections import OrderedDict from types import FunctionType @@ -690,3 +691,184 @@ def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq): l2reg_jax_as_tensor = torch.tensor(np.asarray(l2reg_jax), dtype=dtype) helpers.assert_all_close(l2reg_torch, l2reg_jax_as_tensor) + + +def test_module_empty_parameters() -> None: + class EmptyParameters(ImplicitMetaGradientModule): + def __init__(self, x): + super().__init__() + self.x = x + + def objective(self): + return self.x.mean() + + def solve(self): + pass + + model = EmptyParameters(torch.zeros(8)) + with pytest.raises(RuntimeError, match='The module has no parameters.'): + model.solve() + + model = EmptyParameters(torch.zeros(8)) + model.register_parameter('y', torch.zeros(8, requires_grad=True)) + with pytest.raises(RuntimeError, match='The module has no 
meta-parameters.'): + model.solve() + + model = EmptyParameters(torch.zeros(8, requires_grad=True)) + with pytest.raises(RuntimeError, match='The module has no parameters.'): + model.solve() + + model = EmptyParameters(torch.zeros(8, requires_grad=True)) + with pytest.raises(RuntimeError, match='The module has no parameters.'): + model.optimality() + + model = EmptyParameters(torch.zeros(8)) + model.register_parameter('y', torch.zeros(8, requires_grad=True)) + with pytest.raises(RuntimeError, match='The module has no meta-parameters.'): + model.optimality() + + model = EmptyParameters(torch.zeros(8, requires_grad=True)) + model.register_parameter('y', torch.zeros(8, requires_grad=True)) + model.solve() + + model = EmptyParameters(nn.Linear(8, 8).eval()) + with pytest.raises(RuntimeError, match='The module has no meta-parameters.'): + model.solve() + + model = EmptyParameters(nn.Linear(8, 8)) + model.register_parameter('y', torch.zeros(8, requires_grad=True)) + model.solve() + + +def test_module_enable_implicit_gradients_twice() -> None: + class MyModule1(torchopt.nn.ImplicitMetaGradientModule): + def objective(self): + return torch.tensor(0.0) + + def solve(self): + pass + + from torchopt.diff.implicit.nn.module import ( + enable_implicit_gradients, + make_optimality_from_objective, + ) + + with pytest.raises( + TypeError, + match='Implicit gradients are already enabled for the `solve` method.', + ): + enable_implicit_gradients(MyModule1) + + class MyModule2(torchopt.nn.ImplicitMetaGradientModule): + def optimality(self): + return torch.tensor(0.0) + + def solve(self): + pass + + with pytest.raises( + TypeError, + match='The objective function is not defined.', + ): + make_optimality_from_objective(MyModule2) + + +def test_module_abstract_methods() -> None: + class MyModule1(torchopt.nn.ImplicitMetaGradientModule): + def objective(self): + return torch.tensor(0.0) + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + MyModule1() + + with pytest.raises( + TypeError, + match=re.escape( + 'ImplicitMetaGradientModule requires either an optimality() method or an objective() method' + ), + ): + + class MyModule2(torchopt.nn.ImplicitMetaGradientModule): + def solve(self): + pass + + class MyModule3(torchopt.nn.ImplicitMetaGradientModule): + def optimality(self): + return () + + def solve(self): + pass + + with pytest.raises( + TypeError, + match=re.escape('method optimality() must not be a staticmethod.'), + ): + + class MyModule4(torchopt.nn.ImplicitMetaGradientModule): + @staticmethod + def optimality(): + return () + + def solve(self): + pass + + with pytest.raises( + TypeError, + match=re.escape('method optimality() must not be a classmethod.'), + ): + + class MyModule5(torchopt.nn.ImplicitMetaGradientModule): + @classmethod + def optimality(self): + return () + + def solve(self): + pass + + with pytest.raises( + TypeError, + match=re.escape('method optimality() must be callable.'), + ): + + class MyModule6(torchopt.nn.ImplicitMetaGradientModule): + optimality = 0 + + def solve(self): + pass + + with pytest.raises( + TypeError, + match=re.escape('method objective() must not be a staticmethod.'), + ): + + class MyModule7(torchopt.nn.ImplicitMetaGradientModule): + @staticmethod + def objective(): + return () + + def solve(self): + pass + + with pytest.raises( + TypeError, + match=re.escape('method objective() must not be a classmethod.'), + ): + + class MyModule8(torchopt.nn.ImplicitMetaGradientModule): + @classmethod + def objective(self): + return () + + def 
solve(self): + pass + + with pytest.raises( + TypeError, + match=re.escape('method objective() must be callable.'), + ): + + class MyModule9(torchopt.nn.ImplicitMetaGradientModule): + objective = 0 + + def solve(self): + pass diff --git a/tests/test_nn.py b/tests/test_nn.py index 1b48c06b..1e524ba5 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -69,7 +69,13 @@ def test_register_tensors() -> None: assert m._meta_parameters['x'] is x assert m._parameters['y'] is y - assert hasattr(m, 'z') and m.z is z and 'z' not in m._buffers + assert ( + hasattr(m, 'z') + and m.z is z + and 'z' not in m._meta_parameters + and 'z' not in m._parameters + and 'z' not in m._buffers + ) del m.x object.__setattr__(m, 'x', x) @@ -82,6 +88,67 @@ def test_register_tensors() -> None: m.b = b assert m.b is b and 'b' in m._buffers + m = torchopt.nn.MetaGradientModule(x, b) + + with pytest.raises( + TypeError, + match=re.escape('parameter name should be a string. Got bytes'), + ): + m.register_meta_parameter(b'x', x) + + with pytest.raises( + KeyError, + match=re.escape("parameter name can't contain '.'"), + ): + m.register_meta_parameter('x.x', x) + + with pytest.raises( + KeyError, + match=re.escape("parameter name can't be empty string ''"), + ): + m.register_meta_parameter('', x) + + m.register_buffer('z', None) + with pytest.raises( + KeyError, + match=re.escape("attribute 'z' already exists"), + ): + m.register_meta_parameter('z', x) + + with pytest.raises( + ValueError, + match=re.escape( + "cannot assign Tensor that is a meta-parameter to parameter 'x'. " + 'Use self.register_meta_parameter() instead.' + ), + ): + m.register_parameter('x', x) + + m.x = x + with pytest.raises( + KeyError, + match=re.escape("attribute 'x' already exists"), + ): + m.register_parameter('x', x) + + with pytest.raises( + TypeError, + match=re.escape('parameter name should be a string. 
Got bytes'), + ): + m.register_parameter(b'y', y) + + with pytest.raises( + KeyError, + match=re.escape("parameter name can't contain '.'"), + ): + m.register_parameter('y.x', y) + + with pytest.raises( + KeyError, + match=re.escape("parameter name can't be empty string ''"), + ): + m.register_parameter('', y) + def test_no_super_init() -> None: class NoSuper1(torchopt.nn.MetaGradientModule): diff --git a/tests/test_zero_order.py b/tests/test_zero_order.py index ac7ae840..5455f3af 100644 --- a/tests/test_zero_order.py +++ b/tests/test_zero_order.py @@ -14,6 +14,7 @@ # ============================================================================== import functorch +import pytest import torch import torch.nn as nn import torch.nn.functional as F @@ -117,3 +118,49 @@ def sample(self, sample_shape=torch.Size()): optimizer.zero_grad() loss.backward() # compute gradients optimizer.step() # update network parameters + + +def test_module_enable_zero_order_gradients_twice() -> None: + class MyModule(torchopt.nn.ZeroOrderGradientModule): + def forward(self): + return torch.tensor(0.0) + + def sample(self, sample_shape): + return torch.tensor(0.0) + + from torchopt.diff.zero_order.nn.module import enable_zero_order_gradients + + with pytest.raises( + TypeError, + match='Zero-order gradient estimation is already enabled for the `forward` method.', + ): + enable_zero_order_gradients(MyModule) + + +def test_module_empty_parameters() -> None: + class MyModule(torchopt.nn.ZeroOrderGradientModule): + def forward(self): + return torch.tensor(0.0) + + def sample(self, sample_shape): + return torch.tensor(0.0) + + m = MyModule() + with pytest.raises(RuntimeError, match='The module has no parameters.'): + m() + + +def test_module_abstract_methods() -> None: + class MyModule1(torchopt.nn.ZeroOrderGradientModule): + def forward(self): + return torch.tensor(0.0) + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + MyModule1() + + class MyModule2(torchopt.nn.ZeroOrderGradientModule): + def sample(self, sample_shape): + return torch.tensor(0.0) + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + MyModule2() diff --git a/torchopt/base.py b/torchopt/base.py index 7678d543..b5f35a1b 100644 --- a/torchopt/base.py +++ b/torchopt/base.py @@ -38,7 +38,7 @@ from typing import TYPE_CHECKING, Callable, NamedTuple, Protocol -if TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: from torchopt.typing import OptState, Params, Updates diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py index 2aceb656..cd7ba126 100644 --- a/torchopt/diff/implicit/nn/module.py +++ b/torchopt/diff/implicit/nn/module.py @@ -20,6 +20,7 @@ import abc import functools +import inspect import itertools from typing import Any, Iterable @@ -36,38 +37,40 @@ def _stateless_objective_fn( - __flat_params: TupleOfTensors, - __flat_meta_params: TupleOfTensors, - __params_names: Iterable[str], - __meta_params_names: Iterable[str], + flat_params: TupleOfTensors, + flat_meta_params: TupleOfTensors, + params_names: Iterable[str], + meta_params_names: Iterable[str], self: ImplicitMetaGradientModule, + /, *input: Any, **kwargs: Any, ) -> torch.Tensor: with reparametrize( self, itertools.chain( - zip(__params_names, __flat_params), - zip(__meta_params_names, __flat_meta_params), + zip(params_names, flat_params), + zip(meta_params_names, flat_meta_params), ), ): return self.objective(*input, **kwargs) def _stateless_optimality_fn( - __flat_params: TupleOfTensors, - 
__flat_meta_params: TupleOfTensors, - __params_names: Iterable[str], - __meta_params_names: Iterable[str], + flat_params: TupleOfTensors, + flat_meta_params: TupleOfTensors, + params_names: Iterable[str], + meta_params_names: Iterable[str], self: ImplicitMetaGradientModule, + /, *input: Any, **kwargs: Any, ) -> TupleOfTensors: with reparametrize( self, itertools.chain( - zip(__params_names, __flat_params), - zip(__meta_params_names, __flat_meta_params), + zip(params_names, flat_params), + zip(meta_params_names, flat_meta_params), ), ): return self.optimality(*input, **kwargs) @@ -77,15 +80,20 @@ def make_optimality_from_objective( cls: type[ImplicitMetaGradientModule], ) -> type[ImplicitMetaGradientModule]: """Derives the optimality function of the objective function.""" - if ( - getattr(cls, 'objective', ImplicitMetaGradientModule.objective) - is ImplicitMetaGradientModule.objective - ): + static_super_objective = inspect.getattr_static(ImplicitMetaGradientModule, 'objective') + static_cls_objective = inspect.getattr_static(cls, 'objective', static_super_objective) + if static_cls_objective is static_super_objective: raise TypeError('The objective function is not defined.') def optimality(self: ImplicitMetaGradientModule, *input: Any, **kwargs: Any) -> TupleOfTensors: - params_names, flat_params = tuple(zip(*self.named_parameters())) - meta_params_names, flat_meta_params = tuple(zip(*self.named_meta_parameters())) + named_params = tuple(self.named_parameters()) + named_meta_params = tuple(self.named_meta_parameters()) + if len(named_params) == 0: + raise RuntimeError('The module has no parameters.') + if len(named_meta_params) == 0: + raise RuntimeError('The module has no meta-parameters.') + params_names, flat_params = tuple(zip(*named_params)) + meta_params_names, flat_meta_params = tuple(zip(*named_meta_params)) objective_grad_fn = functorch.grad(_stateless_objective_fn, argnums=0) return objective_grad_fn( @@ -98,7 +106,7 @@ def optimality(self: ImplicitMetaGradientModule, *input: Any, **kwargs: Any) -> **kwargs, ) - cls.optimality = optimality # type: ignore[assignment] + cls.optimality = optimality # type: ignore[method-assign] return cls @@ -115,12 +123,13 @@ def enable_implicit_gradients( @custom_root(_stateless_optimality_fn, argnums=1, has_aux=True, **solve_kwargs) def stateless_solver_fn( # pylint: disable=unused-argument - __flat_params: TupleOfTensors, - __flat_meta_params: TupleOfTensors, - __params_names: Iterable[str], - __meta_params_names: Iterable[str], + flat_params: TupleOfTensors, + flat_meta_params: TupleOfTensors, + params_names: Iterable[str], + meta_params_names: Iterable[str], # pylint: enable=unused-argument self: ImplicitMetaGradientModule, + /, *input: Any, **kwargs: Any, ) -> tuple[TupleOfTensors, Any]: @@ -132,8 +141,14 @@ def stateless_solver_fn( @functools.wraps(cls_solve) def wrapped(self: ImplicitMetaGradientModule, *input: Any, **kwargs: Any) -> Any: """Solve the optimization problem.""" - params_names, flat_params = tuple(zip(*self.named_parameters())) - meta_params_names, flat_meta_params = tuple(zip(*self.named_meta_parameters())) + named_params = tuple(self.named_parameters()) + named_meta_params = tuple(self.named_meta_parameters()) + if len(named_params) == 0: + raise RuntimeError('The module has no parameters.') + if len(named_meta_params) == 0: + raise RuntimeError('The module has no meta-parameters.') + params_names, flat_params = tuple(zip(*named_params)) + meta_params_names, flat_meta_params = tuple(zip(*named_meta_params)) 
flat_optimal_params, output = stateless_solver_fn( flat_params, @@ -148,11 +163,11 @@ def wrapped(self: ImplicitMetaGradientModule, *input: Any, **kwargs: Any) -> Any return output wrapped.__implicit_gradients_enabled__ = True # type: ignore[attr-defined] - cls.solve = wrapped # type: ignore[assignment] + cls.solve = wrapped # type: ignore[method-assign] return cls -class ImplicitMetaGradientModule(MetaGradientModule): +class ImplicitMetaGradientModule(MetaGradientModule, metaclass=abc.ABCMeta): """The base class for differentiable implicit meta-gradient models.""" _custom_optimality: bool @@ -164,28 +179,30 @@ def __init_subclass__(cls, linear_solve: LinearSolver | None = None) -> None: super().__init_subclass__() cls.linear_solve = linear_solve - optimality = getattr(cls, 'optimality', ImplicitMetaGradientModule.optimality) - objective = getattr(cls, 'objective', ImplicitMetaGradientModule.objective) - cls._custom_optimality = optimality is not ImplicitMetaGradientModule.optimality - cls._custom_objective = objective is not ImplicitMetaGradientModule.objective + static_super_optimality = inspect.getattr_static(ImplicitMetaGradientModule, 'optimality') + static_super_objective = inspect.getattr_static(ImplicitMetaGradientModule, 'objective') + static_cls_optimality = inspect.getattr_static(cls, 'optimality') + static_cls_objective = inspect.getattr_static(cls, 'objective') + cls._custom_optimality = static_cls_optimality is not static_super_optimality + cls._custom_objective = static_cls_objective is not static_super_objective if cls._custom_optimality: - if isinstance(optimality, staticmethod): + if isinstance(static_cls_optimality, staticmethod): raise TypeError('method optimality() must not be a staticmethod.') - if isinstance(optimality, classmethod): + if isinstance(static_cls_optimality, classmethod): raise TypeError('method optimality() must not be a classmethod.') - if not callable(optimality): + if not callable(static_cls_optimality): raise TypeError('method optimality() must be callable.') elif not cls._custom_objective: raise TypeError( 'ImplicitMetaGradientModule requires either an optimality() method or an objective() method' ) else: - if isinstance(objective, staticmethod): + if isinstance(static_cls_objective, staticmethod): raise TypeError('method objective() must not be a staticmethod.') - if isinstance(objective, classmethod): + if isinstance(static_cls_objective, classmethod): raise TypeError('method objective() must not be a classmethod.') - if not callable(objective): + if not callable(static_cls_objective): raise TypeError('method objective() must be callable.') make_optimality_from_objective(cls) diff --git a/torchopt/diff/zero_order/nn/module.py b/torchopt/diff/zero_order/nn/module.py index aa75890c..6e031300 100644 --- a/torchopt/diff/zero_order/nn/module.py +++ b/torchopt/diff/zero_order/nn/module.py @@ -49,7 +49,10 @@ def enable_zero_order_gradients( @functools.wraps(cls_forward) def wrapped(self: ZeroOrderGradientModule, *input: Any, **kwargs: Any) -> torch.Tensor: """Do the forward pass calculation.""" - params_names, flat_params = tuple(zip(*self.named_parameters())) + named_params = tuple(self.named_parameters()) + if len(named_params) == 0: + raise RuntimeError('The module has no parameters.') + params_names, flat_params = tuple(zip(*named_params)) @zero_order(self.sample, argnums=0, method=method, num_samples=num_samples, sigma=sigma) def forward_fn( @@ -63,7 +66,7 @@ def forward_fn( return forward_fn(flat_params, *input, **kwargs) 
wrapped.__zero_order_gradients_enabled__ = True # type: ignore[attr-defined] - cls.forward = wrapped # type: ignore[assignment] + cls.forward = wrapped # type: ignore[method-assign] return cls diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py index c1f6274a..9b7159ae 100644 --- a/torchopt/transform/scale_by_adam.py +++ b/torchopt/transform/scale_by_adam.py @@ -301,9 +301,40 @@ def update_fn( count_inc = inc_count.impl(updates, state.count, already_flattened=True) # type: ignore[attr-defined] op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace) - out = tree_map_flat(op, state.mu, state.nu, updates, count_inc) - new_mu, new_nu, new_updates = tuple(zip(*out)) # transpose + def op_fn( + mu: torch.Tensor | None, + nu: torch.Tensor | None, + update: torch.Tensor | None, + count: torch.Tensor | None, + ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: + if mu is None: + return (None, None, None) + return op(mu, nu, update, count) # type: ignore[arg-type] + + out = tree_map_flat( + op_fn, + state.mu, + state.nu, + updates, + count_inc, + none_is_leaf=True, + ) + + if len(out) == 0: + new_mu, new_nu, new_updates = (), (), () + else: + new_mu, new_nu, new_updates = tuple(zip(*out)) # transpose + + new_mu, new_nu, new_updates = ( + new if type(new) is type(old) else type(old)(new) + for new, old in ( + (new_mu, state.mu), + (new_nu, state.nu), + (new_updates, updates), + ) + ) + return new_updates, ScaleByAdamState(mu=new_mu, nu=new_nu, count=count_inc) else: @@ -318,15 +349,43 @@ def update_fn( ) -> tuple[Updates, OptState]: count_inc = inc_count.impl(updates, state.count, already_flattened=False) # type: ignore[attr-defined] - treespec = pytree.tree_structure(updates, none_is_leaf=True) - - op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace) - out = pytree.tree_map(op, state.mu, state.nu, updates, count_inc) - new_mu: Updates new_nu: Updates new_updates: Updates - new_mu, new_nu, new_updates = pytree.tree_transpose(treespec, TRIPLE_PYTREE_SPEC, out) # type: ignore[misc] + + treespec = pytree.tree_structure(updates, none_is_leaf=True) + if treespec.num_leaves > 0: + op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace) + + def op_fn( + mu: torch.Tensor | None, + nu: torch.Tensor | None, + update: torch.Tensor | None, + count: torch.Tensor | None, + ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: + if mu is None: + return (None, None, None) + return op(mu, nu, update, count) # type: ignore[arg-type] + + out = pytree.tree_map( + op_fn, + state.mu, + state.nu, + updates, + count_inc, + none_is_leaf=True, + ) + + new_mu, new_nu, new_updates = pytree.tree_transpose( # type: ignore[misc] + treespec, + TRIPLE_PYTREE_SPEC, + out, + ) + else: + new_mu = pytree.tree_unflatten(treespec, ()) + new_nu = pytree.tree_unflatten(treespec, ()) + new_updates = pytree.tree_unflatten(treespec, ()) + return new_updates, ScaleByAdamState(mu=new_mu, nu=new_nu, count=count_inc) def init_fn(params: Params) -> OptState: diff --git a/torchopt/typing.py b/torchopt/typing.py index 6cd6cf67..510cb693 100644 --- a/torchopt/typing.py +++ b/torchopt/typing.py @@ -141,7 +141,7 @@ def sample( ) -> Union[Tensor, Sequence[Numeric]]: # pylint: disable-next=line-too-long """Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.""" - raise NotImplementedError # pragma: no cover + raise NotImplementedError 
Samplable.register(Distribution) diff --git a/torchopt/utils.py b/torchopt/utils.py index b56231c6..0d227049 100644 --- a/torchopt/utils.py +++ b/torchopt/utils.py @@ -28,7 +28,7 @@ from torchopt.typing import Device, ModuleTensorContainers, OptState, TensorContainer, TensorTree -if TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: from torchopt.optim.meta.base import MetaOptimizer
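For context, a minimal standalone sketch (not taken from the patch) of the `zip(*nested)` transpose pitfall that the changes in `torchopt/transform/scale_by_adam.py` guard against: transposing an empty list of triples with `zip(*out)` produces nothing to unpack into three names, so the update function falls back to empty containers when the pytree has no leaves. The variable names below mirror the patch but the snippet is illustrative only.

# Sketch only: demonstrates the empty-transpose guard; not library code.
out = []  # e.g. tree_map over an empty pytree yields no per-leaf triples
if len(out) == 0:
    new_mu, new_nu, new_updates = (), (), ()  # nothing to transpose
else:
    new_mu, new_nu, new_updates = tuple(zip(*out))  # transpose list of triples

assert (new_mu, new_nu, new_updates) == ((), (), ())

# With a non-empty list of triples, the unguarded transpose works as expected.
out = [('mu1', 'nu1', 'u1'), ('mu2', 'nu2', 'u2')]
new_mu, new_nu, new_updates = tuple(zip(*out))
assert new_mu == ('mu1', 'mu2') and new_updates == ('u1', 'u2')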