Commit fd1642c

test: add test for empty parameters

1 parent e0c41cf · commit fd1642c

File tree: 3 files changed, +198 -17 lines

tests/test_implicit.py

Lines changed: 128 additions & 2 deletions

@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 import copy
+import re
 from collections import OrderedDict
 from types import FunctionType
 
@@ -698,8 +699,8 @@ def __init__(self, x):
             super().__init__()
             self.x = x
 
-        def optimality(self):
-            return (self.x,)
+        def objective(self):
+            return self.x
 
         def solve(self):
             pass
@@ -717,6 +718,10 @@ def solve(self):
     with pytest.raises(RuntimeError, match='The module has no parameters.'):
         model.solve()
 
+    model = EmptyParameters(torch.zeros(8, requires_grad=True))
+    with pytest.raises(RuntimeError, match='The module has no parameters.'):
+        model.optimality()
+
     model = EmptyParameters(torch.zeros(8, requires_grad=True))
     model.register_parameter('y', torch.zeros(8, requires_grad=True))
     model.solve()
@@ -728,3 +733,124 @@ def solve(self):
     model = EmptyParameters(nn.Linear(8, 8))
     model.register_parameter('y', torch.zeros(8, requires_grad=True))
     model.solve()
+
+
+def test_module_enable_implicit_gradients_twice() -> None:
+    class MyModule(torchopt.nn.ImplicitMetaGradientModule):
+        def objective(self):
+            return torch.tensor(0.0)
+
+        def solve(self):
+            pass
+
+    from torchopt.diff.implicit.nn.module import enable_implicit_gradients
+
+    with pytest.raises(
+        TypeError,
+        match='Implicit gradients are already enabled for the `solve` method.',
+    ):
+        enable_implicit_gradients(MyModule)
+
+
+def test_module_abstract_methods() -> None:
+    class MyModule1(torchopt.nn.ImplicitMetaGradientModule):
+        def objective(self):
+            return torch.tensor(0.0)
+
+    with pytest.raises(
+        TypeError,
+        match="Can't instantiate abstract class",
+    ):
+        MyModule1()
+
+    with pytest.raises(
+        TypeError,
+        match=re.escape(
+            'ImplicitMetaGradientModule requires either an optimality() method or an objective() method'
+        ),
+    ):
+
+        class MyModule2(torchopt.nn.ImplicitMetaGradientModule):
+            def solve(self):
+                pass
+
+    class MyModule3(torchopt.nn.ImplicitMetaGradientModule):
+        def optimality(self):
+            return ()
+
+        def solve(self):
+            pass
+
+    with pytest.raises(
+        TypeError,
+        match=re.escape('method optimality() must not be a staticmethod.'),
+    ):
+
+        class MyModule4(torchopt.nn.ImplicitMetaGradientModule):
+            @staticmethod
+            def optimality():
+                return ()
+
+            def solve(self):
+                pass
+
+    with pytest.raises(
+        TypeError,
+        match=re.escape('method optimality() must not be a classmethod.'),
+    ):
+
+        class MyModule5(torchopt.nn.ImplicitMetaGradientModule):
+            @classmethod
+            def optimality(cls):
+                return ()
+
+            def solve(self):
+                pass
+
+    with pytest.raises(
+        TypeError,
+        match=re.escape('method optimality() must be callable.'),
+    ):
+
+        class MyModule6(torchopt.nn.ImplicitMetaGradientModule):
+            optimality = 0
+
+            def solve(self):
+                pass
+
+    with pytest.raises(
+        TypeError,
+        match=re.escape('method objective() must not be a staticmethod.'),
+    ):
+
+        class MyModule7(torchopt.nn.ImplicitMetaGradientModule):
+            @staticmethod
+            def objective():
+                return ()
+
+            def solve(self):
+                pass
+
+    with pytest.raises(
+        TypeError,
+        match=re.escape('method objective() must not be a classmethod.'),
+    ):
+
+        class MyModule8(torchopt.nn.ImplicitMetaGradientModule):
+            @classmethod
+            def objective(cls):
+                return ()
+
+            def solve(self):
+                pass
+
+    with pytest.raises(
+        TypeError,
+        match=re.escape('method objective() must be callable.'),
    ):
+
+        class MyModule9(torchopt.nn.ImplicitMetaGradientModule):
+            objective = 0
+
+            def solve(self):
+                pass

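Note: `EmptyParameters` now defines `objective()` instead of `optimality()`, so these tests exercise the path where torchopt derives the optimality condition from the objective (see `make_optimality_from_objective` in the module diff below). Conceptually, the derived optimality residual is the gradient of the objective with respect to the module's parameters, which vanishes at the inner solution. Below is a minimal standalone sketch of that relationship; `derived_optimality` is a hypothetical helper for illustration, not torchopt's internals:

import torch


def derived_optimality(params, objective_fn):
    # Stationarity residual: the gradient of the inner objective with
    # respect to the inner parameters; it is zero at the inner solution.
    loss = objective_fn(params)
    return torch.autograd.grad(loss, params, create_graph=True)


# The objective (x - 3)^2 yields the residual 2 * (x - 3), zero at x = 3.
x = torch.tensor(1.0, requires_grad=True)
(residual,) = derived_optimality((x,), lambda p: (p[0] - 3.0) ** 2)
print(residual)  # tensor(-4., grad_fn=<MulBackward0>)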
tests/test_zero_order.py

Lines changed: 53 additions & 0 deletions

@@ -14,6 +14,7 @@
 # ==============================================================================
 
 import functorch
+import pytest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -117,3 +118,55 @@ def sample(self, sample_shape=torch.Size()):
         optimizer.zero_grad()
         loss.backward()  # compute gradients
         optimizer.step()  # update network parameters
+
+
+def test_module_enable_zero_order_gradients_twice() -> None:
+    class MyModule(torchopt.nn.ZeroOrderGradientModule):
+        def forward(self):
+            return torch.tensor(0.0)
+
+        def sample(self, sample_shape):
+            return torch.tensor(0.0)
+
+    from torchopt.diff.zero_order.nn.module import enable_zero_order_gradients
+
+    with pytest.raises(
+        TypeError,
+        match='Zero-order gradient estimation is already enabled for the `forward` method.',
+    ):
+        enable_zero_order_gradients(MyModule)
+
+
+def test_module_empty_parameters() -> None:
+    class MyModule(torchopt.nn.ZeroOrderGradientModule):
+        def forward(self):
+            return torch.tensor(0.0)
+
+        def sample(self, sample_shape):
+            return torch.tensor(0.0)
+
+    m = MyModule()
+    with pytest.raises(RuntimeError, match='The module has no parameters.'):
+        m()
+
+
+def test_module_abstract_methods() -> None:
+    class MyModule1(torchopt.nn.ZeroOrderGradientModule):
+        def forward(self):
+            return torch.tensor(0.0)
+
+    with pytest.raises(
+        TypeError,
+        match="Can't instantiate abstract class",
+    ):
+        MyModule1()
+
+    class MyModule2(torchopt.nn.ZeroOrderGradientModule):
+        def sample(self, sample_shape):
+            return torch.tensor(0.0)
+
+    with pytest.raises(
+        TypeError,
+        match="Can't instantiate abstract class",
+    ):
+        MyModule2()

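For contrast with `test_module_empty_parameters` above, a subclass that actually owns parameters passes the no-parameter check. A minimal sketch assuming only the `forward()`/`sample()` interface these tests exercise; the class name, the `nn.Linear` submodule, and the Gaussian perturbation are illustrative assumptions:

import torch
import torch.nn as nn
import torchopt


class NoisyLinear(torchopt.nn.ZeroOrderGradientModule):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 1)  # gives the module parameters

    def forward(self, x):
        # Black-box objective whose gradient is estimated by sampling.
        return self.fc(x).mean()

    def sample(self, sample_shape=torch.Size()):
        # Perturbation distribution used by the zero-order estimator.
        return torch.randn(sample_shape)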
torchopt/diff/implicit/nn/module.py

Lines changed: 17 additions & 15 deletions

@@ -20,6 +20,7 @@
 
 import abc
 import functools
+import inspect
 import itertools
 from typing import Any, Iterable
 
@@ -79,10 +80,9 @@ def make_optimality_from_objective(
     cls: type[ImplicitMetaGradientModule],
 ) -> type[ImplicitMetaGradientModule]:
     """Derives the optimality function of the objective function."""
-    if (
-        getattr(cls, 'objective', ImplicitMetaGradientModule.objective)
-        is ImplicitMetaGradientModule.objective
-    ):
+    static_super_objective = inspect.getattr_static(ImplicitMetaGradientModule, 'objective')
+    static_cls_objective = inspect.getattr_static(cls, 'objective', static_super_objective)
+    if static_cls_objective is static_super_objective:
         raise TypeError('The objective function is not defined.')
 
     def optimality(self: ImplicitMetaGradientModule, *input: Any, **kwargs: Any) -> TupleOfTensors:
@@ -167,7 +167,7 @@ def wrapped(self: ImplicitMetaGradientModule, *input: Any, **kwargs: Any) -> Any
     return cls
 
 
-class ImplicitMetaGradientModule(MetaGradientModule):
+class ImplicitMetaGradientModule(MetaGradientModule, metaclass=abc.ABCMeta):
     """The base class for differentiable implicit meta-gradient models."""
 
     _custom_optimality: bool
@@ -179,28 +179,30 @@ def __init_subclass__(cls, linear_solve: LinearSolver | None = None) -> None:
         super().__init_subclass__()
         cls.linear_solve = linear_solve
 
-        optimality = getattr(cls, 'optimality', ImplicitMetaGradientModule.optimality)
-        objective = getattr(cls, 'objective', ImplicitMetaGradientModule.objective)
-        cls._custom_optimality = optimality is not ImplicitMetaGradientModule.optimality
-        cls._custom_objective = objective is not ImplicitMetaGradientModule.objective
+        static_super_optimality = inspect.getattr_static(ImplicitMetaGradientModule, 'optimality')
+        static_super_objective = inspect.getattr_static(ImplicitMetaGradientModule, 'objective')
+        static_cls_optimality = inspect.getattr_static(cls, 'optimality')
+        static_cls_objective = inspect.getattr_static(cls, 'objective')
+        cls._custom_optimality = static_cls_optimality is not static_super_optimality
+        cls._custom_objective = static_cls_objective is not static_super_objective
 
         if cls._custom_optimality:
-            if isinstance(optimality, staticmethod):
+            if isinstance(static_cls_optimality, staticmethod):
                 raise TypeError('method optimality() must not be a staticmethod.')
-            if isinstance(optimality, classmethod):
+            if isinstance(static_cls_optimality, classmethod):
                 raise TypeError('method optimality() must not be a classmethod.')
-            if not callable(optimality):
+            if not callable(static_cls_optimality):
                 raise TypeError('method optimality() must be callable.')
         elif not cls._custom_objective:
            raise TypeError(
                'ImplicitMetaGradientModule requires either an optimality() method or an objective() method'
            )
        else:
-            if isinstance(objective, staticmethod):
+            if isinstance(static_cls_objective, staticmethod):
                 raise TypeError('method objective() must not be a staticmethod.')
-            if isinstance(objective, classmethod):
+            if isinstance(static_cls_objective, classmethod):
                 raise TypeError('method objective() must not be a classmethod.')
-            if not callable(objective):
+            if not callable(static_cls_objective):
                 raise TypeError('method objective() must be callable.')
 
             make_optimality_from_objective(cls)

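The switch from `getattr` to `inspect.getattr_static` is the substance of this change: plain attribute access on a class invokes the descriptor protocol, so a staticmethod is retrieved as a bare function and `isinstance(..., staticmethod)` was always false, leaving the staticmethod/classmethod `TypeError` branches unreachable. `inspect.getattr_static` returns the raw entry as stored on the class instead. A standalone illustration, independent of torchopt:

import inspect


class Example:
    @staticmethod
    def optimality():
        return ()


# getattr triggers the descriptor protocol and returns the bare function:
print(isinstance(getattr(Example, 'optimality'), staticmethod))  # False
# getattr_static returns the raw attribute as stored in the class dict:
print(isinstance(inspect.getattr_static(Example, 'optimality'), staticmethod))  # True

This is what lets the new `test_module_abstract_methods` cases in tests/test_implicit.py observe those errors.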