Skip to content

Commit 181770a

Browse files
authored
fix: fix transpose empty iterable with zip(*nested) in transformations (#145)
1 parent 2ad69ca commit 181770a

File tree

13 files changed

+488
-52
lines changed

13 files changed

+488
-52
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2222

2323
### Fixed
2424

25-
-
25+
- Fix transpose empty iterable with `zip(*nested)` in transformations by [@XuehaiPan](https://github.com/XuehaiPan) in [#145](https://github.com/metaopt/torchopt/pull/145).
2626

2727
### Removed
2828

codecov.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
coverage:
2+
precision: 2
23
round: nearest
34
status:
45
project:
56
default:
7+
target: auto
68
threshold: 0.05%
79
patch:
810
default:
11+
target: 100%
912
informational: true

tests/.coveragerc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,12 @@ omit =
66
../docs/*
77
../examples/*
88
../tutorials/*
9+
10+
[report]
11+
exclude_lines =
12+
pragma: no cover
13+
raise NotImplementedError
14+
class .*\bProtocol\):
15+
@(abc\.)?abstractmethod
16+
if __name__ == ('__main__'|"__main__"):
17+
if TYPE_CHECKING:

tests/test_alias.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,56 @@
2424

2525
import helpers
2626
import torchopt
27+
from torchopt import pytree
2728
from torchopt.alias.utils import _set_use_chain_flat
29+
from torchopt.typing import TensorTree
30+
31+
32+
@helpers.parametrize(
33+
optimizer=[
34+
torchopt.sgd,
35+
torchopt.adam,
36+
torchopt.adamw,
37+
torchopt.rmsprop,
38+
],
39+
tensortree=[
40+
{},
41+
(),
42+
[],
43+
(None,),
44+
{'a': (), 'b': {'c': []}, 'd': None},
45+
],
46+
maximize=[False, True],
47+
inplace=[True, False],
48+
use_chain_flat=[True, False],
49+
)
50+
def test_empty(
51+
optimizer: Callable,
52+
tensortree: TensorTree,
53+
maximize: bool,
54+
inplace: bool,
55+
use_chain_flat: bool,
56+
) -> None:
57+
_set_use_chain_flat(use_chain_flat)
58+
59+
params = pytree.tree_map(lambda x: x, tensortree)
60+
grads = pytree.tree_map(lambda x: x, tensortree)
61+
62+
optim = optimizer(1e-3, maximize=maximize)
63+
optim_state = optim.init(params)
64+
updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
65+
_ = torchopt.apply_updates(params, updates)
66+
67+
try:
68+
optim = optimizer(1e-3, maximize=maximize, use_accelerated_op=True)
69+
except TypeError:
70+
pass
71+
else:
72+
optim_state = optim.init(params)
73+
updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
74+
_ = torchopt.apply_updates(params, updates)
75+
76+
_set_use_chain_flat(True)
2877

2978

3079
@helpers.parametrize(

tests/test_implicit.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from __future__ import annotations
1717

1818
import copy
19+
import re
1920
from collections import OrderedDict
2021
from types import FunctionType
2122

@@ -690,3 +691,184 @@ def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq):
690691

691692
l2reg_jax_as_tensor = torch.tensor(np.asarray(l2reg_jax), dtype=dtype)
692693
helpers.assert_all_close(l2reg_torch, l2reg_jax_as_tensor)
694+
695+
696+
def test_module_empty_parameters() -> None:
697+
class EmptyParameters(ImplicitMetaGradientModule):
698+
def __init__(self, x):
699+
super().__init__()
700+
self.x = x
701+
702+
def objective(self):
703+
return self.x.mean()
704+
705+
def solve(self):
706+
pass
707+
708+
model = EmptyParameters(torch.zeros(8))
709+
with pytest.raises(RuntimeError, match='The module has no parameters.'):
710+
model.solve()
711+
712+
model = EmptyParameters(torch.zeros(8))
713+
model.register_parameter('y', torch.zeros(8, requires_grad=True))
714+
with pytest.raises(RuntimeError, match='The module has no meta-parameters.'):
715+
model.solve()
716+
717+
model = EmptyParameters(torch.zeros(8, requires_grad=True))
718+
with pytest.raises(RuntimeError, match='The module has no parameters.'):
719+
model.solve()
720+
721+
model = EmptyParameters(torch.zeros(8, requires_grad=True))
722+
with pytest.raises(RuntimeError, match='The module has no parameters.'):
723+
model.optimality()
724+
725+
model = EmptyParameters(torch.zeros(8))
726+
model.register_parameter('y', torch.zeros(8, requires_grad=True))
727+
with pytest.raises(RuntimeError, match='The module has no meta-parameters.'):
728+
model.optimality()
729+
730+
model = EmptyParameters(torch.zeros(8, requires_grad=True))
731+
model.register_parameter('y', torch.zeros(8, requires_grad=True))
732+
model.solve()
733+
734+
model = EmptyParameters(nn.Linear(8, 8).eval())
735+
with pytest.raises(RuntimeError, match='The module has no meta-parameters.'):
736+
model.solve()
737+
738+
model = EmptyParameters(nn.Linear(8, 8))
739+
model.register_parameter('y', torch.zeros(8, requires_grad=True))
740+
model.solve()
741+
742+
743+
def test_module_enable_implicit_gradients_twice() -> None:
744+
class MyModule1(torchopt.nn.ImplicitMetaGradientModule):
745+
def objective(self):
746+
return torch.tensor(0.0)
747+
748+
def solve(self):
749+
pass
750+
751+
from torchopt.diff.implicit.nn.module import (
752+
enable_implicit_gradients,
753+
make_optimality_from_objective,
754+
)
755+
756+
with pytest.raises(
757+
TypeError,
758+
match='Implicit gradients are already enabled for the `solve` method.',
759+
):
760+
enable_implicit_gradients(MyModule1)
761+
762+
class MyModule2(torchopt.nn.ImplicitMetaGradientModule):
763+
def optimality(self):
764+
return torch.tensor(0.0)
765+
766+
def solve(self):
767+
pass
768+
769+
with pytest.raises(
770+
TypeError,
771+
match='The objective function is not defined.',
772+
):
773+
make_optimality_from_objective(MyModule2)
774+
775+
776+
def test_module_abstract_methods() -> None:
777+
class MyModule1(torchopt.nn.ImplicitMetaGradientModule):
778+
def objective(self):
779+
return torch.tensor(0.0)
780+
781+
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
782+
MyModule1()
783+
784+
with pytest.raises(
785+
TypeError,
786+
match=re.escape(
787+
'ImplicitMetaGradientModule requires either an optimality() method or an objective() method'
788+
),
789+
):
790+
791+
class MyModule2(torchopt.nn.ImplicitMetaGradientModule):
792+
def solve(self):
793+
pass
794+
795+
class MyModule3(torchopt.nn.ImplicitMetaGradientModule):
796+
def optimality(self):
797+
return ()
798+
799+
def solve(self):
800+
pass
801+
802+
with pytest.raises(
803+
TypeError,
804+
match=re.escape('method optimality() must not be a staticmethod.'),
805+
):
806+
807+
class MyModule4(torchopt.nn.ImplicitMetaGradientModule):
808+
@staticmethod
809+
def optimality():
810+
return ()
811+
812+
def solve(self):
813+
pass
814+
815+
with pytest.raises(
816+
TypeError,
817+
match=re.escape('method optimality() must not be a classmethod.'),
818+
):
819+
820+
class MyModule5(torchopt.nn.ImplicitMetaGradientModule):
821+
@classmethod
822+
def optimality(self):
823+
return ()
824+
825+
def solve(self):
826+
pass
827+
828+
with pytest.raises(
829+
TypeError,
830+
match=re.escape('method optimality() must be callable.'),
831+
):
832+
833+
class MyModule6(torchopt.nn.ImplicitMetaGradientModule):
834+
optimality = 0
835+
836+
def solve(self):
837+
pass
838+
839+
with pytest.raises(
840+
TypeError,
841+
match=re.escape('method objective() must not be a staticmethod.'),
842+
):
843+
844+
class MyModule7(torchopt.nn.ImplicitMetaGradientModule):
845+
@staticmethod
846+
def objective():
847+
return ()
848+
849+
def solve(self):
850+
pass
851+
852+
with pytest.raises(
853+
TypeError,
854+
match=re.escape('method objective() must not be a classmethod.'),
855+
):
856+
857+
class MyModule8(torchopt.nn.ImplicitMetaGradientModule):
858+
@classmethod
859+
def objective(self):
860+
return ()
861+
862+
def solve(self):
863+
pass
864+
865+
with pytest.raises(
866+
TypeError,
867+
match=re.escape('method objective() must be callable.'),
868+
):
869+
870+
class MyModule9(torchopt.nn.ImplicitMetaGradientModule):
871+
objective = 0
872+
873+
def solve(self):
874+
pass

tests/test_nn.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,13 @@ def test_register_tensors() -> None:
6969

7070
assert m._meta_parameters['x'] is x
7171
assert m._parameters['y'] is y
72-
assert hasattr(m, 'z') and m.z is z and 'z' not in m._buffers
72+
assert (
73+
hasattr(m, 'z')
74+
and m.z is z
75+
and 'z' not in m._meta_parameters
76+
and 'z' not in m._parameters
77+
and 'z' not in m._buffers
78+
)
7379

7480
del m.x
7581
object.__setattr__(m, 'x', x)
@@ -82,6 +88,67 @@ def test_register_tensors() -> None:
8288
m.b = b
8389
assert m.b is b and 'b' in m._buffers
8490

91+
m = torchopt.nn.MetaGradientModule(x, b)
92+
93+
with pytest.raises(
94+
TypeError,
95+
match=re.escape('parameter name should be a string. Got bytes'),
96+
):
97+
m.register_meta_parameter(b'x', x)
98+
99+
with pytest.raises(
100+
KeyError,
101+
match=re.escape("parameter name can't contain '.'"),
102+
):
103+
m.register_meta_parameter('x.x', x)
104+
105+
with pytest.raises(
106+
KeyError,
107+
match=re.escape("parameter name can't be empty string ''"),
108+
):
109+
m.register_meta_parameter('', x)
110+
111+
m.register_buffer('z', None)
112+
with pytest.raises(
113+
KeyError,
114+
match=re.escape("attribute 'z' already exists"),
115+
):
116+
m.register_meta_parameter('z', x)
117+
118+
with pytest.raises(
119+
ValueError,
120+
match=re.escape(
121+
"cannot assign Tensor that is a meta-parameter to parameter 'x'. "
122+
'Use self.register_meta_parameter() instead.'
123+
),
124+
):
125+
m.register_parameter('x', x)
126+
127+
m.x = x
128+
with pytest.raises(
129+
KeyError,
130+
match=re.escape("attribute 'x' already exists"),
131+
):
132+
m.register_parameter('x', x)
133+
134+
with pytest.raises(
135+
TypeError,
136+
match=re.escape('parameter name should be a string. Got bytes'),
137+
):
138+
m.register_parameter(b'y', y)
139+
140+
with pytest.raises(
141+
KeyError,
142+
match=re.escape("parameter name can't contain '.'"),
143+
):
144+
m.register_parameter('y.x', y)
145+
146+
with pytest.raises(
147+
KeyError,
148+
match=re.escape("parameter name can't be empty string ''"),
149+
):
150+
m.register_parameter('', y)
151+
85152

86153
def test_no_super_init() -> None:
87154
class NoSuper1(torchopt.nn.MetaGradientModule):

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy