Skip to content

Commit 5dd8606

Browse files
test: refactor tests using pytest.mark.parametrize (#55)
Co-authored-by: Benjamin-eecs <benjaminliu.eecs@gmail.com>
1 parent 31fd1f6 commit 5dd8606

26 files changed

+891
-654
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ concurrency:
2929
jobs:
3030
test:
3131
runs-on: ubuntu-latest
32-
timeout-minutes: 30
32+
timeout-minutes: 60
3333
steps:
3434
- name: Checkout
3535
uses: actions/checkout@v3

CHANGELOG.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313

1414
### Added
1515

16-
- Add question/help/support issue template by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#43](https://github.com/metaopt/TorchOpt/pull/43).
17-
- Add parallel training on one GPU using functorch.vmap example by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#32](https://github.com/metaopt/TorchOpt/pull/32).
16+
- Refactor tests using `pytest.mark.parametrize` and enabling parallel testing by [@XuehaiPan](https://github.com/XuehaiPan) in [#55](https://github.com/metaopt/TorchOpt/pull/55).
1817
- Add maml-omniglot few-shot classification example using functorch.vmap by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#39](https://github.com/metaopt/TorchOpt/pull/39).
19-
18+
- Add parallel training on one GPU using functorch.vmap example by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#32](https://github.com/metaopt/TorchOpt/pull/32).
19+
- Add question/help/support issue template by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#43](https://github.com/metaopt/TorchOpt/pull/43).
2020

2121
### Changed
2222

@@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2525

2626
### Fixed
2727

28+
- Fix RMSProp optimizer by [@XuehaiPan](https://github.com/XuehaiPan) in [#55](https://github.com/metaopt/TorchOpt/pull/55).
2829
- Fix momentum tracing by [@XuehaiPan](https://github.com/XuehaiPan) in [#58](https://github.com/metaopt/TorchOpt/pull/58).
2930
- Fix CUDA build for accelerated OP by [@XuehaiPan](https://github.com/XuehaiPan) in [#53](https://github.com/metaopt/TorchOpt/pull/53).
3031
- Fix gamma error in MAML-RL implementation by [@Benjamin-eecs](https://github.com/Benjamin-eecs) [#47](https://github.com/metaopt/TorchOpt/pull/47).

Makefile

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,9 @@ addlicense-install: go-install
8383

8484
pytest: pytest-install
8585
cd tests && \
86-
$(PYTHON) -m pytest unit --verbose --color=yes --durations=0 \
87-
--cov="$(PROJECT_PATH)" --cov-report=xml --cov-report=term-missing
86+
$(PYTHON) -m pytest --verbose --color=yes --durations=0 \
87+
--cov="$(PROJECT_NAME)" --cov-report=xml --cov-report=term-missing \
88+
.
8889

8990
test: pytest
9091

docs/conda-recipe.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ dependencies:
3131

3232
# Learning
3333
- pytorch::pytorch >= 1.12
34-
- pytorch::torchvision
3534
- pytorch::pytorch-mutex = *=*cpu*
3635
- pip:
3736
- functorch >= 0.2

docs/requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
--extra-index-url https://download.pytorch.org/whl/cpu
22
torch >= 1.12
3-
torchvision
43
functorch >= 0.2
54

65
--requirement ../requirements.txt

pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ lint = [
7373
"pre-commit",
7474
]
7575
test = [
76-
'torchvision',
7776
'functorch >= 0.2',
7877
'pytest',
7978
'pytest-cov',
@@ -123,10 +122,10 @@ test-command = """
123122
TORCH_VERSION="$(python -c 'print(__import__("torch").__version__.partition("+")[0])')"
124123
TEST_TORCH_SPECS="${TEST_TORCH_SPECS:-"${DEFAULT_TEST_TORCH_SPECS}"}"
125124
for spec in ${TEST_TORCH_SPECS}; do
126-
python -m pip uninstall -y torch torchvision
125+
python -m pip uninstall -y torch
127126
export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/${spec}"
128127
echo "PIP_EXTRA_INDEX_URL='${PIP_EXTRA_INDEX_URL}'"
129-
python -m pip install "torch==${TORCH_VERSION}" torchvision
128+
python -m pip install "torch==${TORCH_VERSION}"
130129
echo "ls ${TORCH_LIB_PATH}"; ls -lh "${TORCH_LIB_PATH}"
131130
find "${SITE_PACKAGES}/torchopt" -name "*.so" -print0 |
132131
xargs -0 -I '{}' bash -c "echo 'ldd {}'; ldd '{}'; echo 'patchelf --print-rpath {}'; patchelf --print-rpath '{}'"

tests/conftest.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2022 MetaOPT Team. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
import os
17+
18+
19+
os.environ['PYTHONHASHSEED'] = '0'
20+
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'

tests/helpers.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Copyright 2022 MetaOPT Team. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
import copy
17+
import itertools
18+
import os
19+
import random
20+
from typing import Iterable, Optional, Tuple, Union
21+
22+
import numpy as np
23+
import pytest
24+
import torch
25+
import torch.nn as nn
26+
from torch.utils import data
27+
28+
29+
BATCH_SIZE = 4
30+
NUM_UPDATES = 3
31+
32+
MODEL_NUM_INPUTS = 28 * 28 # MNIST
33+
MODEL_NUM_CLASSES = 10
34+
MODEL_HIDDEN_SIZE = 64
35+
36+
37+
def parametrize(**argvalues) -> pytest.mark.parametrize:
38+
arguments = list(argvalues)
39+
40+
if 'dtype' in argvalues:
41+
dtypes = argvalues['dtype']
42+
argvalues['dtype'] = dtypes[:1]
43+
arguments.remove('dtype')
44+
arguments.insert(0, 'dtype')
45+
46+
argvalues = list(itertools.product(*tuple(map(argvalues.get, arguments))))
47+
first_product = argvalues[0]
48+
argvalues.extend((dtype,) + first_product[1:] for dtype in dtypes[1:])
49+
50+
ids = tuple(
51+
'-'.join(f'{arg}({val})' for arg, val in zip(arguments, values)) for values in argvalues
52+
)
53+
54+
return pytest.mark.parametrize(arguments, argvalues, ids=ids)
55+
56+
57+
def seed_everything(seed: int) -> None:
58+
os.environ['PYTHONHASHSEED'] = str(seed)
59+
60+
random.seed(seed)
61+
np.random.seed(seed)
62+
63+
torch.manual_seed(seed)
64+
torch.cuda.manual_seed(seed)
65+
torch.cuda.manual_seed_all(seed)
66+
try:
67+
torch.use_deterministic_algorithms(True)
68+
except AttributeError:
69+
pass
70+
71+
72+
@torch.no_grad()
73+
def get_models(
74+
device: Optional[Union[str, torch.device]] = None, dtype: torch.dtype = torch.float32
75+
) -> Tuple[nn.Module, nn.Module, nn.Module, data.DataLoader]:
76+
seed_everything(seed=42)
77+
78+
model_base = nn.Sequential(
79+
nn.Linear(
80+
in_features=MODEL_NUM_INPUTS,
81+
out_features=MODEL_HIDDEN_SIZE,
82+
bias=True,
83+
dtype=dtype,
84+
),
85+
nn.ReLU(),
86+
nn.Linear(
87+
in_features=MODEL_HIDDEN_SIZE,
88+
out_features=MODEL_HIDDEN_SIZE,
89+
bias=True,
90+
dtype=dtype,
91+
),
92+
nn.ReLU(),
93+
nn.Linear(
94+
in_features=MODEL_HIDDEN_SIZE,
95+
out_features=MODEL_NUM_CLASSES,
96+
bias=True,
97+
dtype=dtype,
98+
),
99+
nn.Softmax(dim=-1),
100+
)
101+
for name, param in model_base.named_parameters(recurse=True):
102+
if name.endswith('weight'):
103+
nn.init.orthogonal_(param)
104+
if name.endswith('bias'):
105+
param.data.normal_(0, 0.1)
106+
107+
model = copy.deepcopy(model_base)
108+
model_ref = copy.deepcopy(model_base)
109+
if device is not None:
110+
model_base = model_base.to(device=torch.device(device))
111+
model = model.to(device=torch.device(device))
112+
model_ref = model_ref.to(device=torch.device(device))
113+
114+
dataset = data.TensorDataset(
115+
torch.randint(0, 1, (BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS)),
116+
torch.randint(0, MODEL_NUM_CLASSES, (BATCH_SIZE * NUM_UPDATES,)),
117+
)
118+
loader = data.DataLoader(dataset, BATCH_SIZE, shuffle=False)
119+
120+
return model, model_ref, model_base, loader
121+
122+
123+
@torch.no_grad()
124+
def assert_model_all_close(
125+
model: Union[nn.Module, Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]],
126+
model_ref: nn.Module,
127+
model_base: nn.Module,
128+
dtype: torch.dtype = torch.float32,
129+
rtol: Optional[float] = None,
130+
atol: Optional[float] = None,
131+
equal_nan: bool = False,
132+
):
133+
134+
if isinstance(model, tuple):
135+
params, buffers = model
136+
elif isinstance(model, nn.Module):
137+
params = model.parameters()
138+
buffers = model.buffers()
139+
140+
for p, p_ref, p_base in zip(params, model_ref.parameters(), model_base.parameters()):
141+
assert_all_close(p, p_ref, base=p_base, rtol=rtol, atol=atol, equal_nan=equal_nan)
142+
for b, b_ref, b_base in zip(buffers, model_ref.buffers(), model_base.buffers()):
143+
b = b.to(dtype=dtype) if not b.is_floating_point() else b
144+
b_ref = b_ref.to(dtype=dtype) if not b_ref.is_floating_point() else b_ref
145+
b_base = b_base.to(dtype=dtype) if not b_base.is_floating_point() else b_base
146+
assert_all_close(b, b_ref, base=b_base, rtol=rtol, atol=atol, equal_nan=equal_nan)
147+
148+
149+
@torch.no_grad()
150+
def assert_all_close(
151+
actual: torch.Tensor,
152+
expected: torch.Tensor,
153+
base: torch.Tensor = None,
154+
rtol: Optional[float] = None,
155+
atol: Optional[float] = None,
156+
equal_nan: bool = False,
157+
) -> None:
158+
159+
if base is not None:
160+
actual = actual - base
161+
expected = expected - base
162+
163+
torch.testing.assert_close(
164+
actual,
165+
expected,
166+
rtol=rtol,
167+
atol=atol,
168+
equal_nan=equal_nan,
169+
check_dtype=True,
170+
)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy