
Commit 922a33e

style(implicit): reformat
1 parent d8abe9a commit 922a33e

File tree

1 file changed: +40 -36 lines changed


tests/test_implicit.py

Lines changed: 40 additions & 36 deletions
@@ -200,24 +200,26 @@ def outer_level(p, xs, ys):
 
     for p, p_ref in zip(params, jax_params_as_tensor):
         helpers.assert_all_close(p, p_ref)
-
+
 
 @torch.no_grad()
 def get_dataset_torch_rr(
-    device: Optional[Union[str, torch.device]] = None) -> Tuple[nn.Module, data.DataLoader]:
+    # device: Optional[Union[str, torch.device]] = None
+) -> Tuple[nn.Module, data.DataLoader]:
     helpers.seed_everything(seed=42)
     NUM_UPDATES = 4
     BATCH_SIZE = 1024
     dataset = data.TensorDataset(
         torch.randn((BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS)),
         torch.randn((BATCH_SIZE * NUM_UPDATES)),
         torch.randn((BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS)),
-        torch.randn((BATCH_SIZE * NUM_UPDATES))
+        torch.randn((BATCH_SIZE * NUM_UPDATES)),
     )
     loader = data.DataLoader(dataset, BATCH_SIZE, shuffle=False)
 
     return loader
 
+
 @helpers.parametrize(
     lr=[1e-3, 1e-4],
     dtype=[torch.float64],
@@ -231,27 +233,27 @@ def test_rr(
     helpers.seed_everything(42)
     device = 'cpu'
     input_size = 10
-
+
     init_params_torch = torch.randn(input_size).to(device, dtype=dtype)
     l2reg_torch = torch.rand(1, requires_grad=True).to(device, dtype=dtype)
-
+
     init_params_jax = jnp.array(init_params_torch.detach().numpy(), dtype=jax_dtype)
     l2reg_jax = jnp.array(l2reg_torch.detach().numpy(), dtype=jax_dtype)
-
+
     loader = get_dataset_torch_rr(device='cpu')
 
     optim = torchopt.sgd(lr)
     optim_state = optim.init(l2reg_torch)
-
+
     optim_jax = optax.sgd(lr)
     opt_state_jax = optim_jax.init(l2reg_jax)
-
+
     def ridge_objective_torch(params, l2reg, data):
         """Ridge objective function."""
         x_tr, y_tr = data
         params = params
-        residuals = x_tr @ params - y_tr
-        return 0.5 * torch.mean(residuals ** 2) + 0.5 * l2reg.sum() * torch.sum(params ** 2)
+        residuals = x_tr @ params - y_tr
+        return 0.5 * torch.mean(residuals**2) + 0.5 * l2reg.sum() * torch.sum(params**2)
 
     @torchopt.implicit_diff.custom_root(functorch.grad(ridge_objective_torch, argnums=0), argnums=1)
     def ridge_solver_torch(init_params, l2reg, data):
@@ -261,61 +263,63 @@ def ridge_solver_torch(init_params, l2reg, data):
         def matvec(u):
             return torch.matmul(X_tr.T, torch.matmul(X_tr, u))
 
-        return torchopt.linear_solve.solve_cg(matvec=matvec,
-                                              b=torch.matmul(X_tr.T, y_tr),
-                                              ridge=len(y_tr) * l2reg.item(),
-                                              init=init_params,
-                                              maxiter=20)
-
+        return torchopt.linear_solve.solve_cg(
+            matvec=matvec,
+            b=torch.matmul(X_tr.T, y_tr),
+            ridge=len(y_tr) * l2reg.item(),
+            init=init_params,
+            maxiter=20,
+        )
+
     def ridge_objective_jax(params, l2reg, X_tr, y_tr):
         """Ridge objective function."""
-        #X_tr, y_tr = data
+        # X_tr, y_tr = data
         residuals = jnp.dot(X_tr, params) - y_tr
-        return 0.5 * jnp.mean(residuals ** 2) + 0.5 * jnp.sum(l2reg) * jnp.sum(params ** 2)
-
+        return 0.5 * jnp.mean(residuals**2) + 0.5 * jnp.sum(l2reg) * jnp.sum(params**2)
 
     @jaxopt.implicit_diff.custom_root(jax.grad(ridge_objective_jax, argnums=0))
     def ridge_solver_jax(init_params, l2reg, X_tr, y_tr):
-        """Solve ridge regression by conjugate gradient."""
+        """Solve ridge regression by conjugate gradient."""
 
-        def matvec(u):
-            return jnp.dot(X_tr.T, jnp.dot(X_tr, u))
+        def matvec(u):
+            return jnp.dot(X_tr.T, jnp.dot(X_tr, u))
+
+        return jaxopt.linear_solve.solve_cg(
+            matvec=matvec,
+            b=jnp.dot(X_tr.T, y_tr),
+            ridge=len(y_tr) * l2reg.item(),
+            init=init_params,
+            maxiter=20,
+        )
 
-        return jaxopt.linear_solve.solve_cg(matvec=matvec,
-                                            b=jnp.dot(X_tr.T, y_tr),
-                                            ridge=len(y_tr) * l2reg.item(),
-                                            init=init_params,
-                                            maxiter=20)
-
     for xs, ys, xq, yq in loader:
         xs = xs.to(dtype=dtype)
         ys = ys.to(dtype=dtype)
         xq = xq.to(dtype=dtype)
         yq = yq.to(dtype=dtype)
-
-        data = (xs, ys)
-        #print(init_params_torch.shape, l2reg_torch.shape, xs.shape, ys.shape)
+
+        # print(init_params_torch.shape, l2reg_torch.shape, xs.shape, ys.shape)
         w_fit = ridge_solver_torch(init_params_torch, l2reg_torch, (xs, ys))
         outer_loss = F.mse_loss(xq @ w_fit, yq)
-
+
         grad = torch.autograd.grad(outer_loss, l2reg_torch)[0]
         updates, optim_state = optim.update(grad, optim_state)
        l2reg_torch = torchopt.apply_updates(l2reg_torch, updates)
-
+
        xs = jnp.array(xs.numpy(), dtype=jax_dtype)
        ys = jnp.array(ys.numpy(), dtype=jax_dtype)
        xq = jnp.array(xq.numpy(), dtype=jax_dtype)
        yq = jnp.array(yq.numpy(), dtype=jax_dtype)
-
+
        def outer_level(init_params_jax, l2reg_jax, xs, ys, xq, yq):
            w_fit = ridge_solver_jax(init_params_jax, l2reg_jax, xs, ys)
            y_pred = jnp.dot(xq, w_fit)
            loss_value = jnp.mean((y_pred - yq) ** 2)
            return loss_value
-
+
        grads_jax = jax.grad(outer_level, argnums=1)(init_params_jax, l2reg_jax, xs, ys, xq, yq)
        updates_jax, opt_state_jax = optim_jax.update(grads_jax, opt_state_jax)  # get updates
        jax_params = optax.apply_updates(l2reg_jax, updates_jax)
-
+
        jax_p = torch.tensor(np.array(jax_params)).to(dtype=dtype)
        helpers.assert_all_close(l2reg_torch, jax_p, rtol=5e-5, atol=5e-5)
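
The reformatted test checks TorchOpt's implicit differentiation against a JAXopt reference: an inner ridge-regression solve is wrapped in torchopt.implicit_diff.custom_root, and the gradient of the outer query loss with respect to the regularization strength l2reg flows through the conjugate-gradient solver rather than through an unrolled optimization. The sketch below is not part of the commit; it is a minimal, self-contained illustration of the PyTorch half of that pattern, using only the calls that appear in the diff (custom_root, torchopt.linear_solve.solve_cg, functorch.grad), with batch shapes chosen arbitrarily and assuming the solve_cg keyword signature used at this revision.

# Minimal sketch (not part of the commit): implicit differentiation through a
# ridge-regression solver, mirroring the PyTorch half of test_rr above.
import functorch
import torch
import torch.nn.functional as F

import torchopt


def ridge_objective(params, l2reg, data):
    """Inner ridge objective; its stationarity condition defines the solver output."""
    x_tr, y_tr = data
    residuals = x_tr @ params - y_tr
    return 0.5 * torch.mean(residuals**2) + 0.5 * l2reg.sum() * torch.sum(params**2)


# custom_root registers the optimality condition (gradient of the inner objective
# w.r.t. params, argnums=0) and marks l2reg (argnums=1) as the differentiable argument.
@torchopt.implicit_diff.custom_root(functorch.grad(ridge_objective, argnums=0), argnums=1)
def ridge_solver(init_params, l2reg, data):
    """Solve ridge regression by conjugate gradient."""
    x_tr, y_tr = data

    def matvec(u):
        return torch.matmul(x_tr.T, torch.matmul(x_tr, u))

    return torchopt.linear_solve.solve_cg(
        matvec=matvec,
        b=torch.matmul(x_tr.T, y_tr),
        ridge=len(y_tr) * l2reg.item(),
        init=init_params,
        maxiter=20,
    )


input_size = 10
xs, ys = torch.randn(1024, input_size), torch.randn(1024)  # inner (training) batch, shapes assumed
xq, yq = torch.randn(1024, input_size), torch.randn(1024)  # outer (query) batch, shapes assumed

init_params = torch.randn(input_size)
l2reg = torch.rand(1, requires_grad=True)

w_fit = ridge_solver(init_params, l2reg, (xs, ys))  # inner solve
outer_loss = F.mse_loss(xq @ w_fit, yq)             # outer objective on the query split
grad = torch.autograd.grad(outer_loss, l2reg)[0]    # hypergradient via the implicit function theorem

In test_rr itself, this hypergradient is then applied with torchopt.sgd / torchopt.apply_updates and compared (helpers.assert_all_close) against the same update computed by jax.grad over the jaxopt.implicit_diff.custom_root solver and optax.sgd.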
