@@ -200,24 +200,26 @@ def outer_level(p, xs, ys):
 
     for p, p_ref in zip(params, jax_params_as_tensor):
         helpers.assert_all_close(p, p_ref)
-
+
 
 @torch.no_grad()
 def get_dataset_torch_rr(
-    device: Optional[Union[str, torch.device]] = None) -> Tuple[nn.Module, data.DataLoader]:
+    # device: Optional[Union[str, torch.device]] = None
+) -> Tuple[nn.Module, data.DataLoader]:
     helpers.seed_everything(seed=42)
     NUM_UPDATES = 4
     BATCH_SIZE = 1024
     dataset = data.TensorDataset(
         torch.randn((BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS)),
         torch.randn((BATCH_SIZE * NUM_UPDATES)),
         torch.randn((BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS)),
-        torch.randn((BATCH_SIZE * NUM_UPDATES))
+        torch.randn((BATCH_SIZE * NUM_UPDATES)),
     )
     loader = data.DataLoader(dataset, BATCH_SIZE, shuffle=False)
 
     return loader
 
+
 @helpers.parametrize(
     lr=[1e-3, 1e-4],
     dtype=[torch.float64],
@@ -231,27 +233,27 @@ def test_rr(
     helpers.seed_everything(42)
     device = 'cpu'
     input_size = 10
-
+
     init_params_torch = torch.randn(input_size).to(device, dtype=dtype)
     l2reg_torch = torch.rand(1, requires_grad=True).to(device, dtype=dtype)
-
+
     init_params_jax = jnp.array(init_params_torch.detach().numpy(), dtype=jax_dtype)
     l2reg_jax = jnp.array(l2reg_torch.detach().numpy(), dtype=jax_dtype)
-
+
     loader = get_dataset_torch_rr(device='cpu')
 
     optim = torchopt.sgd(lr)
     optim_state = optim.init(l2reg_torch)
-
+
     optim_jax = optax.sgd(lr)
     opt_state_jax = optim_jax.init(l2reg_jax)
-
+
     def ridge_objective_torch(params, l2reg, data):
         """Ridge objective function."""
         x_tr, y_tr = data
         params = params
-        residuals = x_tr @ params - y_tr
-        return 0.5 * torch.mean(residuals ** 2) + 0.5 * l2reg.sum() * torch.sum(params ** 2)
+        residuals = x_tr @ params - y_tr
+        return 0.5 * torch.mean(residuals ** 2) + 0.5 * l2reg.sum() * torch.sum(params ** 2)
 
     @torchopt.implicit_diff.custom_root(functorch.grad(ridge_objective_torch, argnums=0), argnums=1)
     def ridge_solver_torch(init_params, l2reg, data):
@@ -261,61 +263,63 @@ def ridge_solver_torch(init_params, l2reg, data):
         def matvec(u):
             return torch.matmul(X_tr.T, torch.matmul(X_tr, u))
 
-        return torchopt.linear_solve.solve_cg(matvec=matvec,
-                                              b=torch.matmul(X_tr.T, y_tr),
-                                              ridge=len(y_tr) * l2reg.item(),
-                                              init=init_params,
-                                              maxiter=20)
-
+        return torchopt.linear_solve.solve_cg(
+            matvec=matvec,
+            b=torch.matmul(X_tr.T, y_tr),
+            ridge=len(y_tr) * l2reg.item(),
+            init=init_params,
+            maxiter=20,
+        )
+
     def ridge_objective_jax(params, l2reg, X_tr, y_tr):
         """Ridge objective function."""
-        #X_tr, y_tr = data
+        # X_tr, y_tr = data
         residuals = jnp.dot(X_tr, params) - y_tr
-        return 0.5 * jnp.mean(residuals ** 2) + 0.5 * jnp.sum(l2reg) * jnp.sum(params ** 2)
-
+        return 0.5 * jnp.mean(residuals ** 2) + 0.5 * jnp.sum(l2reg) * jnp.sum(params ** 2)
 
     @jaxopt.implicit_diff.custom_root(jax.grad(ridge_objective_jax, argnums=0))
     def ridge_solver_jax(init_params, l2reg, X_tr, y_tr):
-        """Solve ridge regression by conjugate gradient."""
+        """Solve ridge regression by conjugate gradient."""
 
-        def matvec(u):
-            return jnp.dot(X_tr.T, jnp.dot(X_tr, u))
+        def matvec(u):
+            return jnp.dot(X_tr.T, jnp.dot(X_tr, u))
+
+        return jaxopt.linear_solve.solve_cg(
+            matvec=matvec,
+            b=jnp.dot(X_tr.T, y_tr),
+            ridge=len(y_tr) * l2reg.item(),
+            init=init_params,
+            maxiter=20,
+        )
 
-        return jaxopt.linear_solve.solve_cg(matvec=matvec,
-                                            b=jnp.dot(X_tr.T, y_tr),
-                                            ridge=len(y_tr) * l2reg.item(),
-                                            init=init_params,
-                                            maxiter=20)
-
     for xs, ys, xq, yq in loader:
         xs = xs.to(dtype=dtype)
         ys = ys.to(dtype=dtype)
         xq = xq.to(dtype=dtype)
         yq = yq.to(dtype=dtype)
-
-        data = (xs, ys)
-        #print(init_params_torch.shape, l2reg_torch.shape, xs.shape, ys.shape)
+
+        # print(init_params_torch.shape, l2reg_torch.shape, xs.shape, ys.shape)
         w_fit = ridge_solver_torch(init_params_torch, l2reg_torch, (xs, ys))
         outer_loss = F.mse_loss(xq @ w_fit, yq)
-
+
         grad = torch.autograd.grad(outer_loss, l2reg_torch)[0]
         updates, optim_state = optim.update(grad, optim_state)
         l2reg_torch = torchopt.apply_updates(l2reg_torch, updates)
-
+
         xs = jnp.array(xs.numpy(), dtype=jax_dtype)
         ys = jnp.array(ys.numpy(), dtype=jax_dtype)
         xq = jnp.array(xq.numpy(), dtype=jax_dtype)
         yq = jnp.array(yq.numpy(), dtype=jax_dtype)
-
+
         def outer_level(init_params_jax, l2reg_jax, xs, ys, xq, yq):
             w_fit = ridge_solver_jax(init_params_jax, l2reg_jax, xs, ys)
             y_pred = jnp.dot(xq, w_fit)
             loss_value = jnp.mean((y_pred - yq) ** 2)
             return loss_value
-
+
         grads_jax = jax.grad(outer_level, argnums=1)(init_params_jax, l2reg_jax, xs, ys, xq, yq)
         updates_jax, opt_state_jax = optim_jax.update(grads_jax, opt_state_jax)  # get updates
         jax_params = optax.apply_updates(l2reg_jax, updates_jax)
-
+
         jax_p = torch.tensor(np.array(jax_params)).to(dtype=dtype)
         helpers.assert_all_close(l2reg_torch, jax_p, rtol=5e-5, atol=5e-5)
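
For reference, a minimal standalone sketch (not part of this diff; all names and shapes below are hypothetical) of the quantity test_rr compares between TorchOpt and JAXopt: the hypergradient of the query-set MSE with respect to the ridge coefficient. Here the inner ridge problem is solved in closed form, so plain autograd recovers the same gradient that custom_root obtains through the implicit function theorem, without unrolling a CG solver:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d = 64, 10
x_tr, y_tr = torch.randn(n, d, dtype=torch.float64), torch.randn(n, dtype=torch.float64)
x_q, y_q = torch.randn(n, d, dtype=torch.float64), torch.randn(n, dtype=torch.float64)
l2reg = torch.tensor(0.1, dtype=torch.float64, requires_grad=True)

# Inner problem: w*(l2reg) = argmin_w 0.5 * mean((x_tr @ w - y_tr) ** 2) + 0.5 * l2reg * sum(w ** 2)
# Stationarity gives the linear system (x_tr.T @ x_tr / n + l2reg * I) @ w* = x_tr.T @ y_tr / n.
A = x_tr.T @ x_tr / n + l2reg * torch.eye(d, dtype=torch.float64)
w_star = torch.linalg.solve(A, x_tr.T @ y_tr / n)

# Outer loss on the query split; differentiating through the solve yields dL/dl2reg,
# the hypergradient that the decorated ridge_solver_torch/ridge_solver_jax produce implicitly.
outer_loss = F.mse_loss(x_q @ w_star, y_q)
(grad_l2reg,) = torch.autograd.grad(outer_loss, l2reg)
print(grad_l2reg)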