
Commit d8abe9a

new test implicit gradient
1 parent 02d28b9 commit d8abe9a

File tree

1 file changed: +119 -0 lines changed


tests/test_implicit.py

Lines changed: 119 additions & 0 deletions
@@ -200,3 +200,122 @@ def outer_level(p, xs, ys):
    for p, p_ref in zip(params, jax_params_as_tensor):
        helpers.assert_all_close(p, p_ref)


@torch.no_grad()
def get_dataset_torch_rr(
    device: Optional[Union[str, torch.device]] = None
) -> data.DataLoader:
    helpers.seed_everything(seed=42)
    NUM_UPDATES = 4
    BATCH_SIZE = 1024
    dataset = data.TensorDataset(
        torch.randn((BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS)),
        torch.randn((BATCH_SIZE * NUM_UPDATES,)),
        torch.randn((BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS)),
        torch.randn((BATCH_SIZE * NUM_UPDATES,)),
    )
    loader = data.DataLoader(dataset, BATCH_SIZE, shuffle=False)

    return loader


@helpers.parametrize(
    lr=[1e-3, 1e-4],
    dtype=[torch.float64],
    jax_dtype=[jnp.float64],
)
def test_rr(
    lr: float,
    dtype: torch.dtype,
    jax_dtype: jnp.dtype,
) -> None:
    helpers.seed_everything(42)
    device = 'cpu'
    input_size = 10

    init_params_torch = torch.randn(input_size).to(device, dtype=dtype)
    l2reg_torch = torch.rand(1, requires_grad=True).to(device, dtype=dtype)

    init_params_jax = jnp.array(init_params_torch.detach().numpy(), dtype=jax_dtype)
    l2reg_jax = jnp.array(l2reg_torch.detach().numpy(), dtype=jax_dtype)

    loader = get_dataset_torch_rr(device='cpu')

    optim = torchopt.sgd(lr)
    optim_state = optim.init(l2reg_torch)

    optim_jax = optax.sgd(lr)
    opt_state_jax = optim_jax.init(l2reg_jax)

    def ridge_objective_torch(params, l2reg, data):
        """Ridge objective function."""
        x_tr, y_tr = data
        residuals = x_tr @ params - y_tr
        return 0.5 * torch.mean(residuals ** 2) + 0.5 * l2reg.sum() * torch.sum(params ** 2)

    @torchopt.implicit_diff.custom_root(functorch.grad(ridge_objective_torch, argnums=0), argnums=1)
    def ridge_solver_torch(init_params, l2reg, data):
        """Solve ridge regression by conjugate gradient."""
        X_tr, y_tr = data

        def matvec(u):
            return torch.matmul(X_tr.T, torch.matmul(X_tr, u))

        return torchopt.linear_solve.solve_cg(
            matvec=matvec,
            b=torch.matmul(X_tr.T, y_tr),
            ridge=len(y_tr) * l2reg.item(),
            init=init_params,
            maxiter=20,
        )

    def ridge_objective_jax(params, l2reg, X_tr, y_tr):
        """Ridge objective function."""
        residuals = jnp.dot(X_tr, params) - y_tr
        return 0.5 * jnp.mean(residuals ** 2) + 0.5 * jnp.sum(l2reg) * jnp.sum(params ** 2)

    @jaxopt.implicit_diff.custom_root(jax.grad(ridge_objective_jax, argnums=0))
    def ridge_solver_jax(init_params, l2reg, X_tr, y_tr):
        """Solve ridge regression by conjugate gradient."""

        def matvec(u):
            return jnp.dot(X_tr.T, jnp.dot(X_tr, u))

        return jaxopt.linear_solve.solve_cg(
            matvec=matvec,
            b=jnp.dot(X_tr.T, y_tr),
            ridge=len(y_tr) * l2reg.item(),
            init=init_params,
            maxiter=20,
        )

    for xs, ys, xq, yq in loader:
        xs = xs.to(dtype=dtype)
        ys = ys.to(dtype=dtype)
        xq = xq.to(dtype=dtype)
        yq = yq.to(dtype=dtype)

        w_fit = ridge_solver_torch(init_params_torch, l2reg_torch, (xs, ys))
        outer_loss = F.mse_loss(xq @ w_fit, yq)

        grad = torch.autograd.grad(outer_loss, l2reg_torch)[0]
        updates, optim_state = optim.update(grad, optim_state)
        l2reg_torch = torchopt.apply_updates(l2reg_torch, updates)

        xs = jnp.array(xs.numpy(), dtype=jax_dtype)
        ys = jnp.array(ys.numpy(), dtype=jax_dtype)
        xq = jnp.array(xq.numpy(), dtype=jax_dtype)
        yq = jnp.array(yq.numpy(), dtype=jax_dtype)

        def outer_level(init_params_jax, l2reg_jax, xs, ys, xq, yq):
            w_fit = ridge_solver_jax(init_params_jax, l2reg_jax, xs, ys)
            y_pred = jnp.dot(xq, w_fit)
            loss_value = jnp.mean((y_pred - yq) ** 2)
            return loss_value

        grads_jax = jax.grad(outer_level, argnums=1)(init_params_jax, l2reg_jax, xs, ys, xq, yq)
        updates_jax, opt_state_jax = optim_jax.update(grads_jax, opt_state_jax)  # get updates
        jax_params = optax.apply_updates(l2reg_jax, updates_jax)

        jax_p = torch.tensor(np.array(jax_params)).to(dtype=dtype)
        helpers.assert_all_close(l2reg_torch, jax_p, rtol=5e-5, atol=5e-5)
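For context, the pattern this test exercises is implicit differentiation through an inner ridge-regression solve: the solver is wrapped with custom_root so that the outer gradient with respect to the regularization strength flows through the implicit function theorem instead of through unrolled solver iterations. Below is a minimal, hypothetical JAX sketch of the same idea, not part of this commit; the toy data, the closed-form solve via jnp.linalg.solve (in place of conjugate gradient), and all names are assumptions made only for illustration.

import jax
import jax.numpy as jnp
import jaxopt


def ridge_objective(params, l2reg, X, y):
    # Inner objective: mean squared error plus L2 penalty (assumed form).
    residuals = X @ params - y
    return 0.5 * jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)


@jaxopt.implicit_diff.custom_root(jax.grad(ridge_objective, argnums=0))
def ridge_solver(init_params, l2reg, X, y):
    # Closed-form solve of the stationarity condition
    # X^T (X w - y) / n + l2reg * w = 0 (CG is swapped for a direct solve here).
    # init_params is unused; custom_root expects the init as the first argument.
    del init_params
    n = X.shape[0]
    A = X.T @ X / n + l2reg * jnp.eye(X.shape[1])
    return jnp.linalg.solve(A, X.T @ y / n)


def outer_loss(l2reg, X_tr, y_tr, X_val, y_val):
    # Outer objective: validation MSE of the inner ridge solution.
    w = ridge_solver(jnp.zeros(X_tr.shape[1]), l2reg, X_tr, y_tr)
    return jnp.mean((X_val @ w - y_val) ** 2)


# Toy data (assumption, for illustration only).
keys = jax.random.split(jax.random.PRNGKey(0), 4)
X_tr = jax.random.normal(keys[0], (32, 10))
y_tr = jax.random.normal(keys[1], (32,))
X_val = jax.random.normal(keys[2], (32, 10))
y_val = jax.random.normal(keys[3], (32,))

# Hypergradient of the validation loss w.r.t. l2reg, obtained through the
# implicit function theorem rather than by unrolling the inner solver.
hypergrad = jax.grad(outer_loss)(0.1, X_tr, y_tr, X_val, y_val)
print(hypergrad)

Because custom_root registers the solution's derivative via the optimality condition, jax.grad differentiates through the solver without backpropagating through its iterations; the test above checks that TorchOpt's custom_root produces the same l2reg updates as this jaxopt/optax reference.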

0 commit comments