@@ -200,3 +200,122 @@ def outer_level(p, xs, ys):
     for p, p_ref in zip(params, jax_params_as_tensor):
         helpers.assert_all_close(p, p_ref)
+
+
+@torch.no_grad()
+def get_dataset_torch_rr(
+        device: Optional[Union[str, torch.device]] = None) -> data.DataLoader:
+    helpers.seed_everything(seed=42)
+    NUM_UPDATES = 4
+    BATCH_SIZE = 1024
+    # Synthetic (x_train, y_train, x_query, y_query) batches for ridge regression.
+    dataset = data.TensorDataset(
+        torch.randn((BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS)),
+        torch.randn((BATCH_SIZE * NUM_UPDATES,)),
+        torch.randn((BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS)),
+        torch.randn((BATCH_SIZE * NUM_UPDATES,)),
+    )
+    loader = data.DataLoader(dataset, BATCH_SIZE, shuffle=False)
+
+    return loader
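+
+
+# The test below checks torchopt's implicit differentiation against a JAX
+# reference (jaxopt + optax): both sides tune an L2 regularization strength
+# for ridge regression by gradient descent on a held-out query loss, and the
+# updated hyperparameters are asserted to match on every batch.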
+@helpers.parametrize(
+    lr=[1e-3, 1e-4],
+    dtype=[torch.float64],
+    jax_dtype=[jnp.float64],
+)
+def test_rr(
+    lr: float,
+    dtype: torch.dtype,
+    jax_dtype: jnp.dtype,
+) -> None:
+    helpers.seed_everything(42)
+    device = 'cpu'
+    input_size = 10
+
+    init_params_torch = torch.randn(input_size).to(device, dtype=dtype)
+    l2reg_torch = torch.rand(1, requires_grad=True).to(device, dtype=dtype)
+
+    init_params_jax = jnp.array(init_params_torch.detach().numpy(), dtype=jax_dtype)
+    l2reg_jax = jnp.array(l2reg_torch.detach().numpy(), dtype=jax_dtype)
+
+    loader = get_dataset_torch_rr(device='cpu')
+
+    optim = torchopt.sgd(lr)
+    optim_state = optim.init(l2reg_torch)
+
+    optim_jax = optax.sgd(lr)
+    opt_state_jax = optim_jax.init(l2reg_jax)
+
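+    # Inner-level problem: ridge regression
+    #   w*(l2reg) = argmin_w 0.5 * mean((x_tr @ w - y_tr)**2) + 0.5 * l2reg * sum(w**2).
+    # Its stationarity condition grad_w(objective) = 0 is what the implicit
+    # differentiation below is built on.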
+    def ridge_objective_torch(params, l2reg, data):
+        """Ridge objective function."""
+        x_tr, y_tr = data
+        residuals = x_tr @ params - y_tr
+        return 0.5 * torch.mean(residuals ** 2) + 0.5 * l2reg.sum() * torch.sum(params ** 2)
+
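+    # custom_root registers an implicit-diff rule for the solver: gradients of
+    # the solution w.r.t. l2reg (argnums=1) come from the optimality condition
+    # functorch.grad(ridge_objective_torch)(w*, l2reg, data) = 0 via the
+    # implicit function theorem, rather than from backpropagating through the
+    # CG iterations. That is also why l2reg.item() inside the solver is safe.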
+    @torchopt.implicit_diff.custom_root(functorch.grad(ridge_objective_torch, argnums=0), argnums=1)
+    def ridge_solver_torch(init_params, l2reg, data):
+        """Solve ridge regression by conjugate gradient."""
+        X_tr, y_tr = data
+
+        def matvec(u):
+            return torch.matmul(X_tr.T, torch.matmul(X_tr, u))
+
+        # Stationarity of the objective gives the shifted normal equations
+        # (X_tr.T @ X_tr + n * l2reg * I) w = X_tr.T @ y_tr, hence the
+        # ridge=len(y_tr) * l2reg shift passed to CG.
+        return torchopt.linear_solve.solve_cg(
+            matvec=matvec,
+            b=torch.matmul(X_tr.T, y_tr),
+            ridge=len(y_tr) * l2reg.item(),
+            init=init_params,
+            maxiter=20,
+        )
+
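+    # JAX reference implementation: the same objective and CG solver, written
+    # with jaxopt's implicit differentiation, serve as ground truth below.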
+    def ridge_objective_jax(params, l2reg, X_tr, y_tr):
+        """Ridge objective function."""
+        residuals = jnp.dot(X_tr, params) - y_tr
+        return 0.5 * jnp.mean(residuals ** 2) + 0.5 * jnp.sum(l2reg) * jnp.sum(params ** 2)
+
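+    # Unlike the torchopt decorator above, jaxopt's custom_root takes no
+    # argnums: which argument gets differentiated is decided at the jax.grad
+    # call site below.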
+    @jaxopt.implicit_diff.custom_root(jax.grad(ridge_objective_jax, argnums=0))
+    def ridge_solver_jax(init_params, l2reg, X_tr, y_tr):
+        """Solve ridge regression by conjugate gradient."""
+
+        def matvec(u):
+            return jnp.dot(X_tr.T, jnp.dot(X_tr, u))
+
+        return jaxopt.linear_solve.solve_cg(
+            matvec=matvec,
+            b=jnp.dot(X_tr.T, y_tr),
+            ridge=len(y_tr) * l2reg.item(),
+            init=init_params,
+            maxiter=20,
+        )
+
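+    # Bilevel loop: on every batch, solve the inner ridge problem, take the
+    # gradient of the outer (query) loss w.r.t. l2reg through the implicit-diff
+    # rule, apply one SGD step, then replay the identical step in JAX and
+    # compare the updated hyperparameters.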
+    for xs, ys, xq, yq in loader:
+        xs = xs.to(dtype=dtype)
+        ys = ys.to(dtype=dtype)
+        xq = xq.to(dtype=dtype)
+        yq = yq.to(dtype=dtype)
+
+        w_fit = ridge_solver_torch(init_params_torch, l2reg_torch, (xs, ys))
+        outer_loss = F.mse_loss(xq @ w_fit, yq)
+
+        grad = torch.autograd.grad(outer_loss, l2reg_torch)[0]
+        updates, optim_state = optim.update(grad, optim_state)
+        l2reg_torch = torchopt.apply_updates(l2reg_torch, updates)
+
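+        # Replay the same batch on the JAX side.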
+        xs = jnp.array(xs.numpy(), dtype=jax_dtype)
+        ys = jnp.array(ys.numpy(), dtype=jax_dtype)
+        xq = jnp.array(xq.numpy(), dtype=jax_dtype)
+        yq = jnp.array(yq.numpy(), dtype=jax_dtype)
+
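+        # Outer-level problem as a pure function of the hyperparameter: fit on
+        # (xs, ys), evaluate the query MSE on (xq, yq).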
+        def outer_level(init_params_jax, l2reg_jax, xs, ys, xq, yq):
+            w_fit = ridge_solver_jax(init_params_jax, l2reg_jax, xs, ys)
+            y_pred = jnp.dot(xq, w_fit)
+            loss_value = jnp.mean((y_pred - yq) ** 2)
+            return loss_value
+
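+        # argnums=1 differentiates the outer loss w.r.t. l2reg_jax only,
+        # matching torch.autograd.grad(outer_loss, l2reg_torch) above.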
+        grads_jax = jax.grad(outer_level, argnums=1)(init_params_jax, l2reg_jax, xs, ys, xq, yq)
+        updates_jax, opt_state_jax = optim_jax.update(grads_jax, opt_state_jax)
+        # Reassign l2reg_jax so the next iteration uses the updated value,
+        # keeping the JAX side in lockstep with the torch side.
+        l2reg_jax = optax.apply_updates(l2reg_jax, updates_jax)
+
+        jax_p = torch.tensor(np.array(l2reg_jax)).to(dtype=dtype)
+        helpers.assert_all_close(l2reg_torch, jax_p, rtol=5e-5, atol=5e-5)