Skip to content

fix: fix transpose empty iterable with zip(*nested) in transformations #145

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
test: add more tests
  • Loading branch information
XuehaiPan committed Mar 3, 2023
commit f3d86ee15ab807f10a6400497837bb2488573070
9 changes: 9 additions & 0 deletions tests/.coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,12 @@ omit =
../docs/*
../examples/*
../tutorials/*

[report]
exclude_lines =
pragma: no cover
raise NotImplementedError
class .*\bProtocol\):
@(abc\.)?abstractmethod
if __name__ == ('__main__'|"__main__"):
if TYPE_CHECKING:
16 changes: 9 additions & 7 deletions tests/test_implicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ def __init__(self, x):
self.x = x

def objective(self):
return self.x
return self.x.mean()

def solve(self):
pass
Expand All @@ -722,6 +722,11 @@ def solve(self):
with pytest.raises(RuntimeError, match='The module has no parameters.'):
model.optimality()

model = EmptyParameters(torch.zeros(8))
model.register_parameter('y', torch.zeros(8, requires_grad=True))
with pytest.raises(RuntimeError, match='The module has no meta-parameters.'):
model.optimality()

model = EmptyParameters(torch.zeros(8, requires_grad=True))
model.register_parameter('y', torch.zeros(8, requires_grad=True))
model.solve()
Expand Down Expand Up @@ -757,10 +762,7 @@ class MyModule1(torchopt.nn.ImplicitMetaGradientModule):
def objective(self):
return torch.tensor(0.0)

with pytest.raises(
TypeError,
match="Can't instantiate abstract class",
):
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
MyModule1()

with pytest.raises(
Expand Down Expand Up @@ -820,12 +822,12 @@ def solve(self):

with pytest.raises(
TypeError,
match=re.escape('method optimality() must not be a staticmethod.'),
match=re.escape('method objective() must not be a staticmethod.'),
):

class MyModule7(torchopt.nn.ImplicitMetaGradientModule):
@staticmethod
def optimality():
def objective():
return ()

def solve(self):
Expand Down
10 changes: 2 additions & 8 deletions tests/test_zero_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,12 @@ class MyModule1(torchopt.nn.ZeroOrderGradientModule):
def forward(self):
return torch.tensor(0.0)

with pytest.raises(
TypeError,
match="Can't instantiate abstract class",
):
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
MyModule1()

class MyModule2(torchopt.nn.ZeroOrderGradientModule):
def sample(self, sample_shape):
return torch.tensor(0.0)

with pytest.raises(
TypeError,
match="Can't instantiate abstract class",
):
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
MyModule2()
2 changes: 1 addition & 1 deletion torchopt/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from typing import TYPE_CHECKING, Callable, NamedTuple, Protocol


if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
from torchopt.typing import OptState, Params, Updates


Expand Down
2 changes: 1 addition & 1 deletion torchopt/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def sample(
) -> Union[Tensor, Sequence[Numeric]]:
# pylint: disable-next=line-too-long
"""Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched."""
raise NotImplementedError # pragma: no cover
raise NotImplementedError


Samplable.register(Distribution)
2 changes: 1 addition & 1 deletion torchopt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from torchopt.typing import Device, ModuleTensorContainers, OptState, TensorContainer, TensorTree


if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
from torchopt.optim.meta.base import MetaOptimizer


Expand Down
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy