13
13
# limitations under the License.
14
14
# ==============================================================================
15
15
16
- import copy
17
- import itertools
18
- import random
19
- from typing import Optional , Tuple , Union
16
+ from typing import Tuple
20
17
21
18
import functorch
22
- import numpy as np
23
19
import pytest
24
20
import torch
25
- import torch .nn as nn
26
21
import torch .nn .functional as F
27
- from torch .utils import data
28
- from torchvision import models
29
22
23
+ import helpers
30
24
import torchopt
31
25
32
26
33
- def parametrize (** argvalues ) -> pytest .mark .parametrize :
34
- arguments = tuple (argvalues )
35
- argvalues = tuple (itertools .product (* tuple (map (argvalues .get , arguments ))))
36
- ids = tuple (
37
- '-' .join (f'{ arg } ({ val } )' for arg , val in zip (arguments , values )) for values in argvalues
38
- )
39
-
40
- return pytest .mark .parametrize (arguments , argvalues , ids = ids )
41
-
42
-
43
- def get_models (
44
- device : Optional [Union [str , torch .device ]] = None , dtype : torch .dtype = torch .float32
45
- ) -> Tuple [nn .Module , nn .Module , data .DataLoader ]:
46
- random .seed (0 )
47
- np .random .seed (0 )
48
- torch .manual_seed (0 )
49
- torch .cuda .manual_seed_all (0 )
50
-
51
- model = models .resnet18 ().to (dtype = dtype )
52
- model_ref = copy .deepcopy (model )
53
- if device is not None :
54
- model = model .to (device = torch .device (device ))
55
- model_ref = model_ref .to (device = torch .device (device ))
56
-
57
- batch_size = 8
58
- dataset = data .TensorDataset (
59
- torch .randn (batch_size * 2 , 3 , 224 , 224 ), torch .randint (0 , 1000 , (batch_size * 2 ,))
60
- )
61
- loader = data .DataLoader (dataset , batch_size , shuffle = False )
62
-
63
- return model , model_ref , loader
64
-
65
-
66
- @parametrize (
27
+ @helpers .parametrize (
67
28
dtype = [torch .float32 , torch .float64 ],
68
29
lr = [1e-3 , 1e-4 , 1e-5 ],
69
30
momentum = [0.0 , 0.1 , 0.2 ],
70
31
nesterov = [False , True ],
71
32
)
72
33
def test_sgd (dtype : torch .dtype , lr : float , momentum : float , nesterov : bool ) -> None :
73
- model , model_ref , loader = get_models (device = 'cpu' , dtype = dtype )
34
+ model , model_ref , loader = helpers . get_models (device = 'cpu' , dtype = dtype )
74
35
75
36
fun , params , buffers = functorch .make_functional_with_buffers (model )
76
37
optim = torchopt .sgd (lr , momentum = (momentum if momentum != 0.0 else None ), nesterov = nesterov )
@@ -103,14 +64,14 @@ def test_sgd(dtype: torch.dtype, lr: float, momentum: float, nesterov: bool) ->
103
64
assert torch .allclose (b , b_ref ), f'{ b !r} != { b_ref !r} '
104
65
105
66
106
- @parametrize (
67
+ @helpers . parametrize (
107
68
dtype = [torch .float32 , torch .float64 ],
108
69
lr = [1e-3 , 1e-4 , 1e-5 ],
109
70
betas = [(0.9 , 0.999 ), (0.95 , 0.9995 )],
110
71
eps = [1e-8 , 1e-6 ],
111
72
)
112
73
def test_adam (dtype : torch .dtype , lr : float , betas : Tuple [float , float ], eps : float ) -> None :
113
- model , model_ref , loader = get_models (device = 'cpu' , dtype = dtype )
74
+ model , model_ref , loader = helpers . get_models (device = 'cpu' , dtype = dtype )
114
75
115
76
fun , params , buffers = functorch .make_functional_with_buffers (model )
116
77
optim = torchopt .adam (lr , b1 = betas [0 ], b2 = betas [1 ], eps = eps , eps_root = 0.0 )
@@ -143,7 +104,7 @@ def test_adam(dtype: torch.dtype, lr: float, betas: Tuple[float, float], eps: fl
143
104
assert torch .allclose (b , b_ref ), f'{ b !r} != { b_ref !r} '
144
105
145
106
146
- @parametrize (
107
+ @helpers . parametrize (
147
108
dtype = [torch .float32 , torch .float64 ],
148
109
lr = [1e-3 , 1e-4 , 1e-5 ],
149
110
betas = [(0.9 , 0.999 ), (0.95 , 0.9995 )],
@@ -152,7 +113,7 @@ def test_adam(dtype: torch.dtype, lr: float, betas: Tuple[float, float], eps: fl
152
113
def test_accelerated_adam_cpu (
153
114
dtype : torch .dtype , lr : float , betas : Tuple [float , float ], eps : float
154
115
) -> None :
155
- model , model_ref , loader = get_models (device = 'cpu' , dtype = dtype )
116
+ model , model_ref , loader = helpers . get_models (device = 'cpu' , dtype = dtype )
156
117
157
118
fun , params , buffers = functorch .make_functional_with_buffers (model )
158
119
optim = torchopt .adam (
@@ -188,7 +149,7 @@ def test_accelerated_adam_cpu(
188
149
189
150
190
151
@pytest .mark .skipif (not torch .cuda .is_available (), reason = 'No CUDA device available.' )
191
- @parametrize (
152
+ @helpers . parametrize (
192
153
dtype = [torch .float32 , torch .float64 ],
193
154
lr = [1e-3 , 1e-4 , 1e-5 ],
194
155
betas = [(0.9 , 0.999 ), (0.95 , 0.9995 )],
@@ -198,7 +159,7 @@ def test_accelerated_adam_cuda(
198
159
dtype : torch .dtype , lr : float , betas : Tuple [float , float ], eps : float
199
160
) -> None :
200
161
device = 'cuda'
201
- model , model_ref , loader = get_models (device = device , dtype = dtype )
162
+ model , model_ref , loader = helpers . get_models (device = device , dtype = dtype )
202
163
203
164
fun , params , buffers = functorch .make_functional_with_buffers (model )
204
165
optim = torchopt .adam (
@@ -234,7 +195,7 @@ def test_accelerated_adam_cuda(
234
195
assert torch .allclose (b , b_ref ), f'{ b !r} != { b_ref !r} '
235
196
236
197
237
- @parametrize (
198
+ @helpers . parametrize (
238
199
dtype = [torch .float32 , torch .float64 ],
239
200
lr = [1e-3 , 1e-4 , 1e-5 ],
240
201
alpha = [0.9 , 0.99 ],
@@ -245,7 +206,7 @@ def test_accelerated_adam_cuda(
245
206
def test_rmsprop (
246
207
dtype : torch .dtype , lr : float , alpha : float , eps : float , momentum : float , centered : bool
247
208
) -> None :
248
- model , model_ref , loader = get_models (device = 'cpu' , dtype = dtype )
209
+ model , model_ref , loader = helpers . get_models (device = 'cpu' , dtype = dtype )
249
210
250
211
fun , params , buffers = functorch .make_functional_with_buffers (model )
251
212
optim = torchopt .rmsprop (
0 commit comments