Skip to content

Commit 4c726f2

Browse files
committed
test: update test parameters
1 parent 9976f96 commit 4c726f2

File tree

3 files changed

+10
-10
lines changed

3 files changed

+10
-10
lines changed

tests/helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,8 @@ def assert_all_close(
174174
from torch.testing._comparison import get_tolerances
175175

176176
rtol, atol = get_tolerances(actual, expected, rtol=rtol, atol=atol)
177-
rtol *= 10 * NUM_UPDATES
178-
atol *= 10 * NUM_UPDATES
177+
rtol *= 4 * NUM_UPDATES
178+
atol *= 4 * NUM_UPDATES
179179

180180
torch.testing.assert_close(
181181
actual,

tests/test_alias.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def test_sgd(
8888

8989

9090
@helpers.parametrize(
91-
dtype=[torch.float64, torch.float32],
91+
dtype=[torch.float64],
9292
lr=[1e-2, 1e-3, 1e-4],
9393
betas=[(0.9, 0.999), (0.95, 0.9995)],
9494
eps=[1e-8],
@@ -146,7 +146,7 @@ def test_adam(
146146

147147

148148
@helpers.parametrize(
149-
dtype=[torch.float64, torch.float32],
149+
dtype=[torch.float64],
150150
lr=[1e-2, 1e-3, 1e-4],
151151
betas=[(0.9, 0.999), (0.95, 0.9995)],
152152
eps=[1e-8],
@@ -206,7 +206,7 @@ def test_adam_accelerated_cpu(
206206

207207
@pytest.mark.skipif(not torch.cuda.is_available(), reason='No CUDA device available.')
208208
@helpers.parametrize(
209-
dtype=[torch.float64, torch.float32],
209+
dtype=[torch.float64],
210210
lr=[1e-2, 1e-3, 1e-4],
211211
betas=[(0.9, 0.999), (0.95, 0.9995)],
212212
eps=[1e-8],
@@ -267,7 +267,7 @@ def test_adam_accelerated_cuda(
267267

268268

269269
@helpers.parametrize(
270-
dtype=[torch.float64, torch.float32],
270+
dtype=[torch.float64],
271271
lr=[1e-2, 1e-3, 1e-4],
272272
alpha=[0.9, 0.99],
273273
eps=[1e-8],

tests/test_optimizer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_SGD(
8484

8585

8686
@helpers.parametrize(
87-
dtype=[torch.float64, torch.float32],
87+
dtype=[torch.float64],
8888
lr=[1e-2, 1e-3, 1e-4],
8989
betas=[(0.9, 0.999), (0.95, 0.9995)],
9090
eps=[1e-8],
@@ -139,7 +139,7 @@ def test_Adam(
139139

140140

141141
@helpers.parametrize(
142-
dtype=[torch.float64, torch.float32],
142+
dtype=[torch.float64],
143143
lr=[1e-2, 1e-3, 1e-4],
144144
betas=[(0.9, 0.999), (0.95, 0.9995)],
145145
eps=[1e-8],
@@ -196,7 +196,7 @@ def test_Adam_accelerated_cpu(
196196

197197
@pytest.mark.skipif(not torch.cuda.is_available(), reason='No CUDA device available.')
198198
@helpers.parametrize(
199-
dtype=[torch.float64, torch.float32],
199+
dtype=[torch.float64],
200200
lr=[1e-2, 1e-3, 1e-4],
201201
betas=[(0.9, 0.999), (0.95, 0.9995)],
202202
eps=[1e-8],
@@ -254,7 +254,7 @@ def test_Adam_accelerated_cuda(
254254

255255

256256
@helpers.parametrize(
257-
dtype=[torch.float64, torch.float32],
257+
dtype=[torch.float64],
258258
lr=[1e-2, 1e-3, 1e-4],
259259
alpha=[0.9, 0.99],
260260
eps=[1e-8],

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy