fix: fix transpose empty iterable with zip(*nested) in transformations #145

Merged · 10 commits · Mar 7, 2023
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -22,7 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html)

### Fixed

-
- Fix transpose empty iterable with `zip(*nested)` in transformations by [@XuehaiPan](https://github.com/XuehaiPan) in [#145](https://github.com/metaopt/torchopt/pull/145).

### Removed

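For context on the entry above: transposing a nested sequence with zip(*nested) silently collapses to nothing when the outer sequence is empty, which is the bug this PR fixes in the transformation utilities. Below is a minimal sketch of the pitfall and one possible guard; transpose and inner_size are illustrative names, not torchopt's actual implementation.

# zip(*rows) transposes a list of rows, but an empty outer sequence
# yields an empty iterator, so any expected inner structure is lost.
rows = [(1, 2, 3), (4, 5, 6)]
assert list(zip(*rows)) == [(1, 4), (2, 5), (3, 6)]  # ordinary transpose

empty_rows = []
assert list(zip(*empty_rows)) == []  # the inner arity is silently dropped


def transpose(nested, inner_size=0):
    """Transpose with an explicit guard for the empty case (hypothetical helper)."""
    if not nested:
        # Preserve a caller-specified inner arity instead of dropping it.
        return [() for _ in range(inner_size)]
    return list(zip(*nested))


assert transpose([], inner_size=3) == [(), (), ()]
assert transpose(rows) == [(1, 4), (2, 5), (3, 6)]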
3 changes: 3 additions & 0 deletions codecov.yml
@@ -1,9 +1,12 @@
coverage:
  precision: 2
  round: nearest
  status:
    project:
      default:
        target: auto
        threshold: 0.05%
    patch:
      default:
        target: 100%
        informational: true
9 changes: 9 additions & 0 deletions tests/.coveragerc
@@ -6,3 +6,12 @@ omit =
    ../docs/*
    ../examples/*
    ../tutorials/*

[report]
exclude_lines =
    pragma: no cover
    raise NotImplementedError
    class .*\bProtocol\):
    @(abc\.)?abstractmethod
    if __name__ == ('__main__'|"__main__"):
    if TYPE_CHECKING:
49 changes: 49 additions & 0 deletions tests/test_alias.py
@@ -24,7 +24,56 @@

import helpers
import torchopt
from torchopt import pytree
from torchopt.alias.utils import _set_use_chain_flat
from torchopt.typing import TensorTree


@helpers.parametrize(
    optimizer=[
        torchopt.sgd,
        torchopt.adam,
        torchopt.adamw,
        torchopt.rmsprop,
    ],
    tensortree=[
        {},
        (),
        [],
        (None,),
        {'a': (), 'b': {'c': []}, 'd': None},
    ],
    maximize=[False, True],
    inplace=[True, False],
    use_chain_flat=[True, False],
)
def test_empty(
    optimizer: Callable,
    tensortree: TensorTree,
    maximize: bool,
    inplace: bool,
    use_chain_flat: bool,
) -> None:
    _set_use_chain_flat(use_chain_flat)

    params = pytree.tree_map(lambda x: x, tensortree)
    grads = pytree.tree_map(lambda x: x, tensortree)

    optim = optimizer(1e-3, maximize=maximize)
    optim_state = optim.init(params)
    updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
    _ = torchopt.apply_updates(params, updates)

    try:
        optim = optimizer(1e-3, maximize=maximize, use_accelerated_op=True)
    except TypeError:
        pass
    else:
        optim_state = optim.init(params)
        updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
        _ = torchopt.apply_updates(params, updates)

    _set_use_chain_flat(True)


@helpers.parametrize(
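The test above drives the functional optimizer API end to end on empty pytrees. For contrast, here is a hedged sketch of the same init/update/apply cycle on a non-empty pytree; shapes and names are ours, but the call pattern mirrors the test body.

import torch
import torchopt

params = {'w': torch.zeros(3, requires_grad=True)}
grads = {'w': torch.ones(3)}

optim = torchopt.sgd(1e-3)  # same alias family the test parametrizes over
state = optim.init(params)  # per-parameter optimizer state
updates, state = optim.update(grads, state, params=params, inplace=False)
params = torchopt.apply_updates(params, updates, inplace=False)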
182 changes: 182 additions & 0 deletions tests/test_implicit.py
@@ -16,6 +16,7 @@
from __future__ import annotations

import copy
import re
from collections import OrderedDict
from types import FunctionType

@@ -690,3 +691,184 @@ def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq):

    l2reg_jax_as_tensor = torch.tensor(np.asarray(l2reg_jax), dtype=dtype)
    helpers.assert_all_close(l2reg_torch, l2reg_jax_as_tensor)


def test_module_empty_parameters() -> None:
    class EmptyParameters(ImplicitMetaGradientModule):
        def __init__(self, x):
            super().__init__()
            self.x = x

        def objective(self):
            return self.x.mean()

        def solve(self):
            pass

    model = EmptyParameters(torch.zeros(8))
    with pytest.raises(RuntimeError, match='The module has no parameters.'):
        model.solve()

    model = EmptyParameters(torch.zeros(8))
    model.register_parameter('y', torch.zeros(8, requires_grad=True))
    with pytest.raises(RuntimeError, match='The module has no meta-parameters.'):
        model.solve()

    model = EmptyParameters(torch.zeros(8, requires_grad=True))
    with pytest.raises(RuntimeError, match='The module has no parameters.'):
        model.solve()

    model = EmptyParameters(torch.zeros(8, requires_grad=True))
    with pytest.raises(RuntimeError, match='The module has no parameters.'):
        model.optimality()

    model = EmptyParameters(torch.zeros(8))
    model.register_parameter('y', torch.zeros(8, requires_grad=True))
    with pytest.raises(RuntimeError, match='The module has no meta-parameters.'):
        model.optimality()

    model = EmptyParameters(torch.zeros(8, requires_grad=True))
    model.register_parameter('y', torch.zeros(8, requires_grad=True))
    model.solve()

    model = EmptyParameters(nn.Linear(8, 8).eval())
    with pytest.raises(RuntimeError, match='The module has no meta-parameters.'):
        model.solve()

    model = EmptyParameters(nn.Linear(8, 8))
    model.register_parameter('y', torch.zeros(8, requires_grad=True))
    model.solve()


def test_module_enable_implicit_gradients_twice() -> None:
    class MyModule1(torchopt.nn.ImplicitMetaGradientModule):
        def objective(self):
            return torch.tensor(0.0)

        def solve(self):
            pass

    from torchopt.diff.implicit.nn.module import (
        enable_implicit_gradients,
        make_optimality_from_objective,
    )

    with pytest.raises(
        TypeError,
        match='Implicit gradients are already enabled for the `solve` method.',
    ):
        enable_implicit_gradients(MyModule1)

    class MyModule2(torchopt.nn.ImplicitMetaGradientModule):
        def optimality(self):
            return torch.tensor(0.0)

        def solve(self):
            pass

    with pytest.raises(
        TypeError,
        match='The objective function is not defined.',
    ):
        make_optimality_from_objective(MyModule2)


def test_module_abstract_methods() -> None:
    class MyModule1(torchopt.nn.ImplicitMetaGradientModule):
        def objective(self):
            return torch.tensor(0.0)

    with pytest.raises(TypeError, match="Can't instantiate abstract class"):
        MyModule1()

    with pytest.raises(
        TypeError,
        match=re.escape(
            'ImplicitMetaGradientModule requires either an optimality() method or an objective() method'
        ),
    ):

        class MyModule2(torchopt.nn.ImplicitMetaGradientModule):
            def solve(self):
                pass

    class MyModule3(torchopt.nn.ImplicitMetaGradientModule):
        def optimality(self):
            return ()

        def solve(self):
            pass

    with pytest.raises(
        TypeError,
        match=re.escape('method optimality() must not be a staticmethod.'),
    ):

        class MyModule4(torchopt.nn.ImplicitMetaGradientModule):
            @staticmethod
            def optimality():
                return ()

            def solve(self):
                pass

    with pytest.raises(
        TypeError,
        match=re.escape('method optimality() must not be a classmethod.'),
    ):

        class MyModule5(torchopt.nn.ImplicitMetaGradientModule):
            @classmethod
            def optimality(self):
                return ()

            def solve(self):
                pass

    with pytest.raises(
        TypeError,
        match=re.escape('method optimality() must be callable.'),
    ):

        class MyModule6(torchopt.nn.ImplicitMetaGradientModule):
            optimality = 0

            def solve(self):
                pass

    with pytest.raises(
        TypeError,
        match=re.escape('method objective() must not be a staticmethod.'),
    ):

        class MyModule7(torchopt.nn.ImplicitMetaGradientModule):
            @staticmethod
            def objective():
                return ()

            def solve(self):
                pass

    with pytest.raises(
        TypeError,
        match=re.escape('method objective() must not be a classmethod.'),
    ):

        class MyModule8(torchopt.nn.ImplicitMetaGradientModule):
            @classmethod
            def objective(self):
                return ()

            def solve(self):
                pass

    with pytest.raises(
        TypeError,
        match=re.escape('method objective() must be callable.'),
    ):

        class MyModule9(torchopt.nn.ImplicitMetaGradientModule):
            objective = 0

            def solve(self):
                pass
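To contrast with the failing definitions above, here is a minimal module satisfying the constraints these tests assert: an objective() (or optimality()), a solve(), at least one parameter, and at least one meta-parameter. This sketch is assembled from the passing paths in test_module_empty_parameters; the class and attribute names are ours.

import torch
import torch.nn as nn
import torchopt


class TrivialInner(torchopt.nn.ImplicitMetaGradientModule):
    def __init__(self, meta):
        super().__init__()
        self.meta = meta  # supplies the meta-parameters, as in the test's nn.Linear case

    def objective(self):
        return (self.y ** 2).sum()  # differentiable inner objective over self.y

    def solve(self):
        pass  # a real solver would minimize objective() w.r.t. self.y here


model = TrivialInner(nn.Linear(8, 8))
model.register_parameter('y', torch.zeros(8, requires_grad=True))
model.solve()  # passes both the parameter and meta-parameter checks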
69 changes: 68 additions & 1 deletion tests/test_nn.py
@@ -69,7 +69,13 @@ def test_register_tensors() -> None:

    assert m._meta_parameters['x'] is x
    assert m._parameters['y'] is y
    assert hasattr(m, 'z') and m.z is z and 'z' not in m._buffers
    assert (
        hasattr(m, 'z')
        and m.z is z
        and 'z' not in m._meta_parameters
        and 'z' not in m._parameters
        and 'z' not in m._buffers
    )

    del m.x
    object.__setattr__(m, 'x', x)
@@ -82,6 +88,67 @@ def test_register_tensors() -> None:
    m.b = b
    assert m.b is b and 'b' in m._buffers

    m = torchopt.nn.MetaGradientModule(x, b)

    with pytest.raises(
        TypeError,
        match=re.escape('parameter name should be a string. Got bytes'),
    ):
        m.register_meta_parameter(b'x', x)

    with pytest.raises(
        KeyError,
        match=re.escape("parameter name can't contain '.'"),
    ):
        m.register_meta_parameter('x.x', x)

    with pytest.raises(
        KeyError,
        match=re.escape("parameter name can't be empty string ''"),
    ):
        m.register_meta_parameter('', x)

    m.register_buffer('z', None)
    with pytest.raises(
        KeyError,
        match=re.escape("attribute 'z' already exists"),
    ):
        m.register_meta_parameter('z', x)

    with pytest.raises(
        ValueError,
        match=re.escape(
            "cannot assign Tensor that is a meta-parameter to parameter 'x'. "
            'Use self.register_meta_parameter() instead.'
        ),
    ):
        m.register_parameter('x', x)

    m.x = x
    with pytest.raises(
        KeyError,
        match=re.escape("attribute 'x' already exists"),
    ):
        m.register_parameter('x', x)

    with pytest.raises(
        TypeError,
        match=re.escape('parameter name should be a string. Got bytes'),
    ):
        m.register_parameter(b'y', y)

    with pytest.raises(
        KeyError,
        match=re.escape("parameter name can't contain '.'"),
    ):
        m.register_parameter('y.x', y)

    with pytest.raises(
        KeyError,
        match=re.escape("parameter name can't be empty string ''"),
    ):
        m.register_parameter('', y)


def test_no_super_init() -> None:
    class NoSuper1(torchopt.nn.MetaGradientModule):
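Finally, the registration rules pinned down by the test_nn.py additions, shown from the passing side. A sketch under the constraints the tests assert (names must be strings, non-empty, dot-free, and not already taken; a tensor that is a meta-parameter must go through register_meta_parameter()); the tensor shapes and names here are ours.

import torch
import torchopt

x = torch.zeros(4, requires_grad=True)
m = torchopt.nn.MetaGradientModule(x)  # tensors passed to the constructor become meta-parameters

m.register_meta_parameter('theta', x)  # ok: valid name, and x is a meta-parameter
m.register_parameter('w', torch.zeros(4, requires_grad=True))  # ok: fresh non-meta tensor
m.register_buffer('stat', torch.zeros(4))  # ok: buffers may hold plain tensors

# m.register_parameter('theta2', x)  # would raise: x is a meta-parameter
# m.register_parameter('a.b', ...)   # would raise: names cannot contain '.'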