Skip to content

Commit 844324a

Browse files
committed
test: reorganize tests
1 parent a15f4e2 commit 844324a

File tree

6 files changed

+100
-147
lines changed

6 files changed

+100
-147
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ addlicense-install: go-install
8282
# Tests
8383

8484
pytest: pytest-install
85-
cd tests && $(PYTHON) -m pytest unit --cov $(PROJECT_PATH) --durations 0 -v --cov-report term-missing --color=yes
85+
cd tests && $(PYTHON) -m pytest . --cov $(PROJECT_PATH) --durations 0 -v --cov-report term-missing --color=yes
8686

8787
test: pytest
8888

tests/helpers.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright 2022 MetaOPT Team. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
import copy
17+
import itertools
18+
import os
19+
import random
20+
from typing import Optional, Tuple, Union
21+
22+
import numpy as np
23+
import pytest
24+
import torch
25+
import torch.nn as nn
26+
from torch.utils import data
27+
from torchvision import models
28+
29+
30+
def seed_everything(seed: int) -> None:
31+
os.environ['PYTHONHASHSEED'] = str(seed)
32+
33+
random.seed(seed)
34+
np.random.seed(seed)
35+
36+
torch.manual_seed(seed)
37+
torch.cuda.manual_seed(seed)
38+
torch.cuda.manual_seed_all(seed)
39+
try:
40+
torch.use_deterministic_algorithms(True)
41+
except AttributeError:
42+
pass
43+
44+
45+
def parametrize(**argvalues) -> pytest.mark.parametrize:
46+
arguments = tuple(argvalues)
47+
argvalues = tuple(itertools.product(*tuple(map(argvalues.get, arguments))))
48+
ids = tuple(
49+
'-'.join(f'{arg}({val})' for arg, val in zip(arguments, values)) for values in argvalues
50+
)
51+
52+
return pytest.mark.parametrize(arguments, argvalues, ids=ids)
53+
54+
55+
def get_models(
56+
device: Optional[Union[str, torch.device]] = None, dtype: torch.dtype = torch.float32
57+
) -> Tuple[nn.Module, nn.Module, data.DataLoader]:
58+
seed_everything(seed=42)
59+
60+
model = models.resnet18().to(dtype=dtype)
61+
model_ref = copy.deepcopy(model)
62+
if device is not None:
63+
model = model.to(device=torch.device(device))
64+
model_ref = model_ref.to(device=torch.device(device))
65+
66+
batch_size = 8
67+
dataset = data.TensorDataset(
68+
torch.randn(batch_size * 2, 3, 224, 224), torch.randint(0, 1000, (batch_size * 2,))
69+
)
70+
loader = data.DataLoader(dataset, batch_size, shuffle=False)
71+
72+
return model, model_ref, loader

tests/unit/high_level/test_high_level_inplace.py renamed to tests/high_level/test_high_level_inplace.py

Lines changed: 12 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -13,63 +13,24 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16-
import copy
17-
import itertools
18-
import random
19-
from typing import Optional, Tuple, Union
16+
from typing import Tuple
2017

21-
import numpy as np
2218
import pytest
2319
import torch
24-
import torch.nn as nn
2520
import torch.nn.functional as F
26-
from torch.utils import data
27-
from torchvision import models
2821

22+
import helpers
2923
import torchopt
3024

3125

32-
def parametrize(**argvalues) -> pytest.mark.parametrize:
33-
arguments = tuple(argvalues)
34-
argvalues = tuple(itertools.product(*tuple(map(argvalues.get, arguments))))
35-
ids = tuple(
36-
'-'.join(f'{arg}({val})' for arg, val in zip(arguments, values)) for values in argvalues
37-
)
38-
39-
return pytest.mark.parametrize(arguments, argvalues, ids=ids)
40-
41-
42-
def get_models(
43-
device: Optional[Union[str, torch.device]] = None, dtype: torch.dtype = torch.float32
44-
) -> Tuple[nn.Module, nn.Module, data.DataLoader]:
45-
random.seed(0)
46-
np.random.seed(0)
47-
torch.manual_seed(0)
48-
torch.cuda.manual_seed_all(0)
49-
50-
model = models.resnet18().to(dtype=dtype)
51-
model_ref = copy.deepcopy(model)
52-
if device is not None:
53-
model = model.to(device=torch.device(device))
54-
model_ref = model_ref.to(device=torch.device(device))
55-
56-
batch_size = 8
57-
dataset = data.TensorDataset(
58-
torch.randn(batch_size * 2, 3, 224, 224), torch.randint(0, 1000, (batch_size * 2,))
59-
)
60-
loader = data.DataLoader(dataset, batch_size, shuffle=False)
61-
62-
return model, model_ref, loader
63-
64-
65-
@parametrize(
26+
@helpers.parametrize(
6627
dtype=[torch.float32, torch.float64],
6728
lr=[1e-3, 1e-4, 1e-5],
6829
momentum=[0.0, 0.1, 0.2],
6930
nesterov=[False, True],
7031
)
7132
def test_sgd(dtype: torch.dtype, lr: float, momentum: float, nesterov: bool) -> None:
72-
model, model_ref, loader = get_models(device='cpu', dtype=dtype)
33+
model, model_ref, loader = helpers.get_models(device='cpu', dtype=dtype)
7334

7435
optim = torchopt.SGD(
7536
model.parameters(), lr, momentum=(momentum if momentum != 0.0 else None), nesterov=nesterov
@@ -102,14 +63,14 @@ def test_sgd(dtype: torch.dtype, lr: float, momentum: float, nesterov: bool) ->
10263
assert torch.allclose(b, b_ref), f'{b!r} != {b_ref!r}'
10364

10465

105-
@parametrize(
66+
@helpers.parametrize(
10667
dtype=[torch.float32, torch.float64],
10768
lr=[1e-3, 1e-4, 1e-5],
10869
betas=[(0.9, 0.999), (0.95, 0.9995)],
10970
eps=[1e-8, 1e-6],
11071
)
11172
def test_adam(dtype: torch.dtype, lr: float, betas: Tuple[float, float], eps: float) -> None:
112-
model, model_ref, loader = get_models(device='cpu', dtype=dtype)
73+
model, model_ref, loader = helpers.get_models(device='cpu', dtype=dtype)
11374

11475
optim = torchopt.Adam(model.parameters(), lr, b1=betas[0], b2=betas[1], eps=eps, eps_root=0.0)
11576
optim_ref = torch.optim.Adam(
@@ -140,7 +101,7 @@ def test_adam(dtype: torch.dtype, lr: float, betas: Tuple[float, float], eps: fl
140101
assert torch.allclose(b, b_ref), f'{b!r} != {b_ref!r}'
141102

142103

143-
@parametrize(
104+
@helpers.parametrize(
144105
dtype=[torch.float32, torch.float64],
145106
lr=[1e-3, 1e-4, 1e-5],
146107
betas=[(0.9, 0.999), (0.95, 0.9995)],
@@ -149,7 +110,7 @@ def test_adam(dtype: torch.dtype, lr: float, betas: Tuple[float, float], eps: fl
149110
def test_accelerated_adam_cpu(
150111
dtype: torch.dtype, lr: float, betas: Tuple[float, float], eps: float
151112
) -> None:
152-
model, model_ref, loader = get_models(device='cpu', dtype=dtype)
113+
model, model_ref, loader = helpers.get_models(device='cpu', dtype=dtype)
153114

154115
optim = torchopt.Adam(
155116
model.parameters(),
@@ -189,7 +150,7 @@ def test_accelerated_adam_cpu(
189150

190151

191152
@pytest.mark.skipif(not torch.cuda.is_available(), reason='No CUDA device available.')
192-
@parametrize(
153+
@helpers.parametrize(
193154
dtype=[torch.float32, torch.float64],
194155
lr=[1e-3, 1e-4, 1e-5],
195156
betas=[(0.9, 0.999), (0.95, 0.9995)],
@@ -199,7 +160,7 @@ def test_accelerated_adam_cuda(
199160
dtype: torch.dtype, lr: float, betas: Tuple[float, float], eps: float
200161
) -> None:
201162
device = 'cuda'
202-
model, model_ref, loader = get_models(device=device, dtype=dtype)
163+
model, model_ref, loader = helpers.get_models(device=device, dtype=dtype)
203164

204165
optim = torchopt.Adam(
205166
model.parameters(),
@@ -239,7 +200,7 @@ def test_accelerated_adam_cuda(
239200
assert torch.allclose(b, b_ref), f'{b!r} != {b_ref!r}'
240201

241202

242-
@parametrize(
203+
@helpers.parametrize(
243204
dtype=[torch.float32, torch.float64],
244205
lr=[1e-3, 1e-4, 1e-5],
245206
alpha=[0.9, 0.99],
@@ -250,7 +211,7 @@ def test_accelerated_adam_cuda(
250211
def test_rmsprop(
251212
dtype: torch.dtype, lr: float, alpha: float, eps: float, momentum: float, centered: bool
252213
) -> None:
253-
model, model_ref, loader = get_models(device='cpu', dtype=dtype)
214+
model, model_ref, loader = helpers.get_models(device='cpu', dtype=dtype)
254215

255216
optim = torchopt.RMSProp(
256217
model.parameters(),

tests/unit/low_level/test_low_level_inplace.py renamed to tests/low_level/test_low_level_inplace.py

Lines changed: 12 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -13,64 +13,25 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16-
import copy
17-
import itertools
18-
import random
19-
from typing import Optional, Tuple, Union
16+
from typing import Tuple
2017

2118
import functorch
22-
import numpy as np
2319
import pytest
2420
import torch
25-
import torch.nn as nn
2621
import torch.nn.functional as F
27-
from torch.utils import data
28-
from torchvision import models
2922

23+
import helpers
3024
import torchopt
3125

3226

33-
def parametrize(**argvalues) -> pytest.mark.parametrize:
34-
arguments = tuple(argvalues)
35-
argvalues = tuple(itertools.product(*tuple(map(argvalues.get, arguments))))
36-
ids = tuple(
37-
'-'.join(f'{arg}({val})' for arg, val in zip(arguments, values)) for values in argvalues
38-
)
39-
40-
return pytest.mark.parametrize(arguments, argvalues, ids=ids)
41-
42-
43-
def get_models(
44-
device: Optional[Union[str, torch.device]] = None, dtype: torch.dtype = torch.float32
45-
) -> Tuple[nn.Module, nn.Module, data.DataLoader]:
46-
random.seed(0)
47-
np.random.seed(0)
48-
torch.manual_seed(0)
49-
torch.cuda.manual_seed_all(0)
50-
51-
model = models.resnet18().to(dtype=dtype)
52-
model_ref = copy.deepcopy(model)
53-
if device is not None:
54-
model = model.to(device=torch.device(device))
55-
model_ref = model_ref.to(device=torch.device(device))
56-
57-
batch_size = 8
58-
dataset = data.TensorDataset(
59-
torch.randn(batch_size * 2, 3, 224, 224), torch.randint(0, 1000, (batch_size * 2,))
60-
)
61-
loader = data.DataLoader(dataset, batch_size, shuffle=False)
62-
63-
return model, model_ref, loader
64-
65-
66-
@parametrize(
27+
@helpers.parametrize(
6728
dtype=[torch.float32, torch.float64],
6829
lr=[1e-3, 1e-4, 1e-5],
6930
momentum=[0.0, 0.1, 0.2],
7031
nesterov=[False, True],
7132
)
7233
def test_sgd(dtype: torch.dtype, lr: float, momentum: float, nesterov: bool) -> None:
73-
model, model_ref, loader = get_models(device='cpu', dtype=dtype)
34+
model, model_ref, loader = helpers.get_models(device='cpu', dtype=dtype)
7435

7536
fun, params, buffers = functorch.make_functional_with_buffers(model)
7637
optim = torchopt.sgd(lr, momentum=(momentum if momentum != 0.0 else None), nesterov=nesterov)
@@ -103,14 +64,14 @@ def test_sgd(dtype: torch.dtype, lr: float, momentum: float, nesterov: bool) ->
10364
assert torch.allclose(b, b_ref), f'{b!r} != {b_ref!r}'
10465

10566

106-
@parametrize(
67+
@helpers.parametrize(
10768
dtype=[torch.float32, torch.float64],
10869
lr=[1e-3, 1e-4, 1e-5],
10970
betas=[(0.9, 0.999), (0.95, 0.9995)],
11071
eps=[1e-8, 1e-6],
11172
)
11273
def test_adam(dtype: torch.dtype, lr: float, betas: Tuple[float, float], eps: float) -> None:
113-
model, model_ref, loader = get_models(device='cpu', dtype=dtype)
74+
model, model_ref, loader = helpers.get_models(device='cpu', dtype=dtype)
11475

11576
fun, params, buffers = functorch.make_functional_with_buffers(model)
11677
optim = torchopt.adam(lr, b1=betas[0], b2=betas[1], eps=eps, eps_root=0.0)
@@ -143,7 +104,7 @@ def test_adam(dtype: torch.dtype, lr: float, betas: Tuple[float, float], eps: fl
143104
assert torch.allclose(b, b_ref), f'{b!r} != {b_ref!r}'
144105

145106

146-
@parametrize(
107+
@helpers.parametrize(
147108
dtype=[torch.float32, torch.float64],
148109
lr=[1e-3, 1e-4, 1e-5],
149110
betas=[(0.9, 0.999), (0.95, 0.9995)],
@@ -152,7 +113,7 @@ def test_adam(dtype: torch.dtype, lr: float, betas: Tuple[float, float], eps: fl
152113
def test_accelerated_adam_cpu(
153114
dtype: torch.dtype, lr: float, betas: Tuple[float, float], eps: float
154115
) -> None:
155-
model, model_ref, loader = get_models(device='cpu', dtype=dtype)
116+
model, model_ref, loader = helpers.get_models(device='cpu', dtype=dtype)
156117

157118
fun, params, buffers = functorch.make_functional_with_buffers(model)
158119
optim = torchopt.adam(
@@ -188,7 +149,7 @@ def test_accelerated_adam_cpu(
188149

189150

190151
@pytest.mark.skipif(not torch.cuda.is_available(), reason='No CUDA device available.')
191-
@parametrize(
152+
@helpers.parametrize(
192153
dtype=[torch.float32, torch.float64],
193154
lr=[1e-3, 1e-4, 1e-5],
194155
betas=[(0.9, 0.999), (0.95, 0.9995)],
@@ -198,7 +159,7 @@ def test_accelerated_adam_cuda(
198159
dtype: torch.dtype, lr: float, betas: Tuple[float, float], eps: float
199160
) -> None:
200161
device = 'cuda'
201-
model, model_ref, loader = get_models(device=device, dtype=dtype)
162+
model, model_ref, loader = helpers.get_models(device=device, dtype=dtype)
202163

203164
fun, params, buffers = functorch.make_functional_with_buffers(model)
204165
optim = torchopt.adam(
@@ -234,7 +195,7 @@ def test_accelerated_adam_cuda(
234195
assert torch.allclose(b, b_ref), f'{b!r} != {b_ref!r}'
235196

236197

237-
@parametrize(
198+
@helpers.parametrize(
238199
dtype=[torch.float32, torch.float64],
239200
lr=[1e-3, 1e-4, 1e-5],
240201
alpha=[0.9, 0.99],
@@ -245,7 +206,7 @@ def test_accelerated_adam_cuda(
245206
def test_rmsprop(
246207
dtype: torch.dtype, lr: float, alpha: float, eps: float, momentum: float, centered: bool
247208
) -> None:
248-
model, model_ref, loader = get_models(device='cpu', dtype=dtype)
209+
model, model_ref, loader = helpers.get_models(device='cpu', dtype=dtype)
249210

250211
fun, params, buffers = functorch.make_functional_with_buffers(model)
251212
optim = torchopt.rmsprop(

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy