Skip to content

test: refactor tests using pytest.mark.parametrize #55

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 40 commits into from
Aug 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
dfb1b63
test: use `torch.allclose` to compare tensors
XuehaiPan Aug 8, 2022
63d5012
test: refactor tests using `pytest.mark.parametrize`
XuehaiPan Aug 8, 2022
d2ea7fc
chore: update test ids
XuehaiPan Aug 8, 2022
29284c4
chore: update test ids
XuehaiPan Aug 8, 2022
591999f
chore(workflows): increase timeout
XuehaiPan Aug 8, 2022
803c7b9
test: reorganize tests
XuehaiPan Aug 8, 2022
99a561d
feat: parallel testing
XuehaiPan Aug 8, 2022
5b7dedb
test: reorganize tests
XuehaiPan Aug 8, 2022
4992712
test: loop for more update iterations
XuehaiPan Aug 8, 2022
85bddb0
test: reduce number of tests
XuehaiPan Aug 8, 2022
92fa5c7
test: reduce number of tests
XuehaiPan Aug 8, 2022
acdb41a
test: update assert
XuehaiPan Aug 8, 2022
cf45c8e
test: rename variable
XuehaiPan Aug 9, 2022
aba2efc
test: update assert
XuehaiPan Aug 9, 2022
2d28f49
docs(CHANGELOG): update CHANGELOG.md
XuehaiPan Aug 10, 2022
2745eb1
test: cov project name
XuehaiPan Aug 10, 2022
de69946
test: update tests
XuehaiPan Aug 11, 2022
468c4dc
test: set `CUBLAS_WORKSPACE_CONFIG`
XuehaiPan Aug 12, 2022
c7e72c5
test: compare diffs
XuehaiPan Aug 15, 2022
83717bb
test: update tol
XuehaiPan Aug 15, 2022
998b400
test: test float32
XuehaiPan Aug 15, 2022
933eb5d
test: show tol in assert messages
XuehaiPan Aug 15, 2022
876e9d3
test: use smaller network for tests
XuehaiPan Aug 16, 2022
9450669
Merge remote-tracking branch 'upstream/main' into fix-tests
Benjamin-eecs Aug 22, 2022
7034ec3
fix: change the order of add and multiply in sgd
Benjamin-eecs Aug 22, 2022
4b30fc0
Merge branch 'main' into fix-tests
XuehaiPan Aug 22, 2022
1b0a35c
fix: correct test writing and pass sgd
Benjamin-eecs Aug 23, 2022
2a09c8c
Merge branch 'fix-tests' of https://github.com/XuehaiPan/TorchOpt int…
Benjamin-eecs Aug 23, 2022
d6179f6
fix: correct test writing and fix other optims
Benjamin-eecs Aug 23, 2022
48f4a94
to(tests): high level non differentiable optimizer unfixed
Benjamin-eecs Aug 23, 2022
4ddf147
to(tests): high level non differentiable optimizer unfixed
Benjamin-eecs Aug 23, 2022
d77cdd8
to(tests): high level non differentiable optimizer unfixed
Benjamin-eecs Aug 23, 2022
f32ef9f
test: disable parallel testing
XuehaiPan Aug 24, 2022
629a36f
fix(transform): fix momentum trace
XuehaiPan Aug 24, 2022
ca580a6
refactor: refactor transform and utils
XuehaiPan Aug 24, 2022
6771907
test: test inplace operators
XuehaiPan Aug 24, 2022
579e786
fix: fix RMSProp optimizer
XuehaiPan Aug 24, 2022
0cb3c44
fix: fix Makefile
XuehaiPan Aug 24, 2022
51d15b3
lint: appease linters
XuehaiPan Aug 24, 2022
1e1d86f
chore: update unused function
XuehaiPan Aug 24, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ concurrency:
jobs:
test:
runs-on: ubuntu-latest
timeout-minutes: 30
timeout-minutes: 60
steps:
- name: Checkout
uses: actions/checkout@v3
Expand Down
7 changes: 4 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add question/help/support issue template by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#43](https://github.com/metaopt/TorchOpt/pull/43).
- Add parallel training on one GPU using functorch.vmap example by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#32](https://github.com/metaopt/TorchOpt/pull/32).
- Refactor tests using `pytest.mark.parametrize` and enabling parallel testing by [@XuehaiPan](https://github.com/XuehaiPan) in [#55](https://github.com/metaopt/TorchOpt/pull/55).
- Add maml-omniglot few-shot classification example using functorch.vmap by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#39](https://github.com/metaopt/TorchOpt/pull/39).

- Add parallel training on one GPU using functorch.vmap example by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#32](https://github.com/metaopt/TorchOpt/pull/32).
- Add question/help/support issue template by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#43](https://github.com/metaopt/TorchOpt/pull/43).

### Changed

Expand All @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fix RMSProp optimizer by [@XuehaiPan](https://github.com/XuehaiPan) in [#55](https://github.com/metaopt/TorchOpt/pull/55).
- Fix momentum tracing by [@XuehaiPan](https://github.com/XuehaiPan) in [#58](https://github.com/metaopt/TorchOpt/pull/58).
- Fix CUDA build for accelerated OP by [@XuehaiPan](https://github.com/XuehaiPan) in [#53](https://github.com/metaopt/TorchOpt/pull/53).
- Fix gamma error in MAML-RL implementation by [@Benjamin-eecs](https://github.com/Benjamin-eecs) [#47](https://github.com/metaopt/TorchOpt/pull/47).
Expand Down
5 changes: 3 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ addlicense-install: go-install

pytest: pytest-install
cd tests && \
$(PYTHON) -m pytest unit --verbose --color=yes --durations=0 \
--cov="$(PROJECT_PATH)" --cov-report=xml --cov-report=term-missing
$(PYTHON) -m pytest --verbose --color=yes --durations=0 \
--cov="$(PROJECT_NAME)" --cov-report=xml --cov-report=term-missing \
.

test: pytest

Expand Down
1 change: 0 additions & 1 deletion docs/conda-recipe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ dependencies:

# Learning
- pytorch::pytorch >= 1.12
- pytorch::torchvision
- pytorch::pytorch-mutex = *=*cpu*
- pip:
- functorch >= 0.2
Expand Down
1 change: 0 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/cpu
torch >= 1.12
torchvision
functorch >= 0.2

--requirement ../requirements.txt
Expand Down
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ lint = [
"pre-commit",
]
test = [
'torchvision',
'functorch >= 0.2',
'pytest',
'pytest-cov',
Expand Down Expand Up @@ -123,10 +122,10 @@ test-command = """
TORCH_VERSION="$(python -c 'print(__import__("torch").__version__.partition("+")[0])')"
TEST_TORCH_SPECS="${TEST_TORCH_SPECS:-"${DEFAULT_TEST_TORCH_SPECS}"}"
for spec in ${TEST_TORCH_SPECS}; do
python -m pip uninstall -y torch torchvision
python -m pip uninstall -y torch
export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/${spec}"
echo "PIP_EXTRA_INDEX_URL='${PIP_EXTRA_INDEX_URL}'"
python -m pip install "torch==${TORCH_VERSION}" torchvision
python -m pip install "torch==${TORCH_VERSION}"
echo "ls ${TORCH_LIB_PATH}"; ls -lh "${TORCH_LIB_PATH}"
find "${SITE_PACKAGES}/torchopt" -name "*.so" -print0 |
xargs -0 -I '{}' bash -c "echo 'ldd {}'; ldd '{}'; echo 'patchelf --print-rpath {}'; patchelf --print-rpath '{}'"
Expand Down
20 changes: 20 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os


os.environ['PYTHONHASHSEED'] = '0'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'
170 changes: 170 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import copy
import itertools
import os
import random
from typing import Iterable, Optional, Tuple, Union

import numpy as np
import pytest
import torch
import torch.nn as nn
from torch.utils import data


BATCH_SIZE = 4
NUM_UPDATES = 3

MODEL_NUM_INPUTS = 28 * 28 # MNIST
MODEL_NUM_CLASSES = 10
MODEL_HIDDEN_SIZE = 64


def parametrize(**argvalues) -> pytest.mark.parametrize:
arguments = list(argvalues)

if 'dtype' in argvalues:
dtypes = argvalues['dtype']
argvalues['dtype'] = dtypes[:1]
arguments.remove('dtype')
arguments.insert(0, 'dtype')

argvalues = list(itertools.product(*tuple(map(argvalues.get, arguments))))
first_product = argvalues[0]
argvalues.extend((dtype,) + first_product[1:] for dtype in dtypes[1:])

ids = tuple(
'-'.join(f'{arg}({val})' for arg, val in zip(arguments, values)) for values in argvalues
)

return pytest.mark.parametrize(arguments, argvalues, ids=ids)


def seed_everything(seed: int) -> None:
os.environ['PYTHONHASHSEED'] = str(seed)

random.seed(seed)
np.random.seed(seed)

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
try:
torch.use_deterministic_algorithms(True)
except AttributeError:
pass


@torch.no_grad()
def get_models(
device: Optional[Union[str, torch.device]] = None, dtype: torch.dtype = torch.float32
) -> Tuple[nn.Module, nn.Module, nn.Module, data.DataLoader]:
seed_everything(seed=42)

model_base = nn.Sequential(
nn.Linear(
in_features=MODEL_NUM_INPUTS,
out_features=MODEL_HIDDEN_SIZE,
bias=True,
dtype=dtype,
),
nn.ReLU(),
nn.Linear(
in_features=MODEL_HIDDEN_SIZE,
out_features=MODEL_HIDDEN_SIZE,
bias=True,
dtype=dtype,
),
nn.ReLU(),
nn.Linear(
in_features=MODEL_HIDDEN_SIZE,
out_features=MODEL_NUM_CLASSES,
bias=True,
dtype=dtype,
),
nn.Softmax(dim=-1),
)
for name, param in model_base.named_parameters(recurse=True):
if name.endswith('weight'):
nn.init.orthogonal_(param)
if name.endswith('bias'):
param.data.normal_(0, 0.1)

model = copy.deepcopy(model_base)
model_ref = copy.deepcopy(model_base)
if device is not None:
model_base = model_base.to(device=torch.device(device))
model = model.to(device=torch.device(device))
model_ref = model_ref.to(device=torch.device(device))

dataset = data.TensorDataset(
torch.randint(0, 1, (BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS)),
torch.randint(0, MODEL_NUM_CLASSES, (BATCH_SIZE * NUM_UPDATES,)),
)
loader = data.DataLoader(dataset, BATCH_SIZE, shuffle=False)

return model, model_ref, model_base, loader


@torch.no_grad()
def assert_model_all_close(
model: Union[nn.Module, Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]],
model_ref: nn.Module,
model_base: nn.Module,
dtype: torch.dtype = torch.float32,
rtol: Optional[float] = None,
atol: Optional[float] = None,
equal_nan: bool = False,
):

if isinstance(model, tuple):
params, buffers = model
elif isinstance(model, nn.Module):
params = model.parameters()
buffers = model.buffers()

for p, p_ref, p_base in zip(params, model_ref.parameters(), model_base.parameters()):
assert_all_close(p, p_ref, base=p_base, rtol=rtol, atol=atol, equal_nan=equal_nan)
for b, b_ref, b_base in zip(buffers, model_ref.buffers(), model_base.buffers()):
b = b.to(dtype=dtype) if not b.is_floating_point() else b
b_ref = b_ref.to(dtype=dtype) if not b_ref.is_floating_point() else b_ref
b_base = b_base.to(dtype=dtype) if not b_base.is_floating_point() else b_base
assert_all_close(b, b_ref, base=b_base, rtol=rtol, atol=atol, equal_nan=equal_nan)


@torch.no_grad()
def assert_all_close(
actual: torch.Tensor,
expected: torch.Tensor,
base: torch.Tensor = None,
rtol: Optional[float] = None,
atol: Optional[float] = None,
equal_nan: bool = False,
) -> None:

if base is not None:
actual = actual - base
expected = expected - base

torch.testing.assert_close(
actual,
expected,
rtol=rtol,
atol=atol,
equal_nan=equal_nan,
check_dtype=True,
)
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy