diff --git a/CHANGELOG.md b/CHANGELOG.md
index 644cd7eb..a01f6751 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
+- Fix implicit MAML omniglot few-shot classification example by [@XuehaiPan](https://github.com/XuehaiPan) in [#108](https://github.com/metaopt/torchopt/pull/108).
 - Align results of distributed examples by [@XuehaiPan](https://github.com/XuehaiPan) in [#95](https://github.com/metaopt/torchopt/pull/95).
 - Fix `None` in module containers by [@XuehaiPan](https://github.com/XuehaiPan).
 - Fix backward errors when using inplace `sqrt_` and `add_` by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@JieRen98](https://github.com/JieRen98) and [@XuehaiPan](https://github.com/XuehaiPan).
diff --git a/examples/iMAML/imaml-accs-functional.png b/examples/iMAML/imaml-accs-functional.png
index a23132cf..34922bc0 100644
Binary files a/examples/iMAML/imaml-accs-functional.png and b/examples/iMAML/imaml-accs-functional.png differ
diff --git a/examples/iMAML/imaml-accs.png b/examples/iMAML/imaml-accs.png
index c0296be8..1a6a5636 100644
Binary files a/examples/iMAML/imaml-accs.png and b/examples/iMAML/imaml-accs.png differ
diff --git a/examples/iMAML/imaml_omniglot_functional.py b/examples/iMAML/imaml_omniglot_functional.py
index 080541c6..88314366 100644
--- a/examples/iMAML/imaml_omniglot_functional.py
+++ b/examples/iMAML/imaml_omniglot_functional.py
@@ -101,23 +101,21 @@ def main():
     # We will use Adam to (meta-)optimize the initial parameters
     # to be adapted.
     net.train()
-    fnet, params = functorch.make_functional(net)
+    fnet, meta_params = model = functorch.make_functional(net)
     meta_opt = torchopt.adam(lr=1e-3)
-    meta_opt_state = meta_opt.init(params)
+    meta_opt_state = meta_opt.init(meta_params)
 
     log = []
-    test(db, [params, fnet], epoch=-1, log=log, args=args)
+    test(db, model, epoch=-1, log=log, args=args)
     for epoch in range(10):
-        meta_opt, meta_opt_state = train(
-            db, [params, fnet], (meta_opt, meta_opt_state), epoch, log, args
-        )
-        test(db, [params, fnet], epoch, log, args)
+        meta_opt, meta_opt_state = train(db, model, (meta_opt, meta_opt_state), epoch, log, args)
+        test(db, model, epoch, log, args)
         plot(log)
 
 
-def train(db, net, meta_opt_and_state, epoch, log, args):
+def train(db, model, meta_opt_and_state, epoch, log, args):
     n_train_iter = db.x_train.shape[0] // db.batchsz
-    params, fnet = net
+    fnet, meta_params = model
     meta_opt, meta_opt_state = meta_opt_and_state
     # Given this module we've created, rip out the parameters and buffers
     # and return a functional version of the module. `fnet` is stateless
@@ -133,21 +131,22 @@ def train(db, net, meta_opt_and_state, epoch, log, args):
         n_inner_iter = args.inner_steps
         reg_param = args.reg_params
+
         qry_losses = []
         qry_accs = []
 
-        init_params_copy = pytree.tree_map(
-            lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), params
-        )
-
         for i in range(task_num):
             # Optimize the likelihood of the support set by taking
             # gradient steps w.r.t. the model's parameters.
             # This adapts the model's meta-parameters to the task.
+            init_params = pytree.tree_map(
+                lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad),
+                meta_params,
+            )
             optimal_params = train_imaml_inner_solver(
-                init_params_copy,
-                params,
+                init_params,
+                meta_params,
                 (x_spt[i], y_spt[i]),
                 (fnet, n_inner_iter, reg_param),
             )
@@ -156,17 +155,15 @@ def train(db, net, meta_opt_and_state, epoch, log, args):
             # These will be used to update the model's meta-parameters.
             qry_logits = fnet(optimal_params, x_qry[i])
             qry_loss = F.cross_entropy(qry_logits, y_qry[i])
-            # Update the model's meta-parameters to optimize the query
-            # losses across all of the tasks sampled in this batch.
-            # qry_loss = qry_loss / task_num  # scale gradients
-            meta_grads = torch.autograd.grad(qry_loss / task_num, params)
-            meta_updates, meta_opt_state = meta_opt.update(meta_grads, meta_opt_state)
-            params = torchopt.apply_updates(params, meta_updates)
             qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean()
-            qry_losses.append(qry_loss.item())
+            qry_losses.append(qry_loss)
             qry_accs.append(qry_acc.item())
 
-        qry_losses = np.mean(qry_losses)
+        qry_losses = torch.mean(torch.stack(qry_losses))
+        meta_grads = torch.autograd.grad(qry_losses, meta_params)
+        meta_updates, meta_opt_state = meta_opt.update(meta_grads, meta_opt_state)
+        meta_params = torchopt.apply_updates(meta_params, meta_updates)
+        qry_losses = qry_losses.item()
         qry_accs = 100.0 * np.mean(qry_accs)
         i = epoch + float(batch_idx) / n_train_iter
         iter_time = time.time() - start_time
@@ -188,26 +185,19 @@ def train(db, net, meta_opt_and_state, epoch, log, args):
     return (meta_opt, meta_opt_state)
 
 
-def test(db, net, epoch, log, args):
+def test(db, model, epoch, log, args):
     # Crucially in our testing procedure here, we do *not* fine-tune
     # the model during testing for simplicity.
     # Most research papers using MAML for this task do an extra
     # stage of fine-tuning here that should be added if you are
     # adapting this code for research.
-    params, fnet = net
-    # fnet, params, buffers = functorch.make_functional_with_buffers(net)
+    fnet, meta_params = model
    n_test_iter = db.x_test.shape[0] // db.batchsz
-    qry_losses = []
-    qry_accs = []
-
-    # TODO: Maybe pull this out into a separate module so it
-    # doesn't have to be duplicated between `train` and `test`?
     n_inner_iter = args.inner_steps
     reg_param = args.reg_params
 
-    init_params_copy = pytree.tree_map(
-        lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), params
-    )
+    qry_losses = []
+    qry_accs = []
 
     for batch_idx in range(n_test_iter):
         x_spt, y_spt, x_qry, y_qry = db.next('test')
@@ -219,9 +209,13 @@ def test(db, net, epoch, log, args):
             # gradient steps w.r.t. the model's parameters.
             # This adapts the model's meta-parameters to the task.
+            init_params = pytree.tree_map(
+                lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad),
+                meta_params,
+            )
             optimal_params = test_imaml_inner_solver(
-                init_params_copy,
-                params,
+                init_params,
+                meta_params,
                 (x_spt[i], y_spt[i]),
                 (fnet, n_inner_iter, reg_param),
             )
@@ -249,12 +243,12 @@ def test(db, net, epoch, log, args):
     )
 
 
-def imaml_objective(optimal_params, init_params, data, aux):
+def imaml_objective(params, meta_params, data, aux):
     x_spt, y_spt = data
     fnet, n_inner_iter, reg_param = aux
-    y_pred = fnet(optimal_params, x_spt)
+    y_pred = fnet(params, x_spt)
     regularization_loss = 0
-    for p1, p2 in zip(optimal_params, init_params):
+    for p1, p2 in zip(params, meta_params):
         regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2))
     loss = F.cross_entropy(y_pred, y_spt) + regularization_loss
     return loss
@@ -266,11 +260,10 @@ def imaml_objective(optimal_params, init_params, data, aux):
     has_aux=False,
     solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),
 )
-def train_imaml_inner_solver(init_params_copy, init_params, data, aux):
+def train_imaml_inner_solver(params, meta_params, data, aux):
     x_spt, y_spt = data
     fnet, n_inner_iter, reg_param = aux
     # Initial functional optimizer based on TorchOpt
-    params = init_params_copy
     inner_opt = torchopt.sgd(lr=1e-1)
     inner_opt_state = inner_opt.init(params)
     with torch.enable_grad():
@@ -280,20 +273,21 @@ def train_imaml_inner_solver(init_params_copy, init_params, data, aux):
             loss = F.cross_entropy(pred, y_spt)  # compute loss
             # Compute regularization loss
             regularization_loss = 0
-            for p1, p2 in zip(params, init_params):
+            for p1, p2 in zip(params, meta_params):
                 regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2))
             final_loss = loss + regularization_loss
             grads = torch.autograd.grad(final_loss, params)  # compute gradients
-            updates, inner_opt_state = inner_opt.update(grads, inner_opt_state)  # get updates
-            params = torchopt.apply_updates(params, updates)
+            updates, inner_opt_state = inner_opt.update(
+                grads, inner_opt_state, inplace=True
+            )  # get updates
+            params = torchopt.apply_updates(params, updates, inplace=True)
     return params
 
 
-def test_imaml_inner_solver(init_params_copy, init_params, data, aux):
+def test_imaml_inner_solver(params, meta_params, data, aux):
     x_spt, y_spt = data
     fnet, n_inner_iter, reg_param = aux
     # Initial functional optimizer based on TorchOpt
-    params = init_params_copy
     inner_opt = torchopt.sgd(lr=1e-1)
     inner_opt_state = inner_opt.init(params)
     with torch.enable_grad():
@@ -303,12 +297,14 @@ def test_imaml_inner_solver(init_params_copy, init_params, data, aux):
             loss = F.cross_entropy(pred, y_spt)  # compute loss
             # Compute regularization loss
             regularization_loss = 0
-            for p1, p2 in zip(params, init_params):
+            for p1, p2 in zip(params, meta_params):
                 regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2))
             final_loss = loss + regularization_loss
             grads = torch.autograd.grad(final_loss, params)  # compute gradients
-            updates, inner_opt_state = inner_opt.update(grads, inner_opt_state)  # get updates
-            params = torchopt.apply_updates(params, updates)
+            updates, inner_opt_state = inner_opt.update(
+                grads, inner_opt_state, inplace=True
+            )  # get updates
+            params = torchopt.apply_updates(params, updates, inplace=True)
     return params
diff --git a/tests/test_implicit.py b/tests/test_implicit.py
index 06d180d4..661a6627 100644
--- a/tests/test_implicit.py
+++ b/tests/test_implicit.py
@@ -126,11 +126,11 @@ def test_imaml(dtype: torch.dtype, lr: float, inner_lr: float, inner_update: int
     optim_jax = optax.sgd(lr)
     optim_state_jax = optim_jax.init(jax_params)
 
-    def imaml_objective_torchopt(optimal_params, init_params, data):
+    def imaml_objective_torchopt(params, meta_params, data):
         x, y, f = data
-        y_pred = f(optimal_params, x)
+        y_pred = f(params, x)
         regularization_loss = 0
-        for p1, p2 in zip(optimal_params, init_params):
+        for p1, p2 in zip(params, meta_params):
             regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2))
         loss = F.cross_entropy(y_pred, y) + regularization_loss
         return loss
@@ -138,10 +138,9 @@ def imaml_objective_torchopt(optimal_params, init_params, data):
     @torchopt.diff.implicit.custom_root(
         functorch.grad(imaml_objective_torchopt, argnums=0), argnums=1, has_aux=True
     )
-    def inner_solver_torchopt(init_params_copy, init_params, data):
+    def inner_solver_torchopt(params, meta_params, data):
         # Initial functional optimizer based on TorchOpt
         x, y, f = data
-        params = init_params_copy
         optimizer = torchopt.sgd(lr=inner_lr)
         opt_state = optimizer.init(params)
         with torch.enable_grad():
@@ -151,43 +150,42 @@ def inner_solver_torchopt(init_params_copy, init_params, data):
             loss = F.cross_entropy(pred, y)  # compute loss
             # Compute regularization loss
             regularization_loss = 0
-            for p1, p2 in zip(params, init_params):
+            for p1, p2 in zip(params, meta_params):
                 regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2))
             final_loss = loss + regularization_loss
             grads = torch.autograd.grad(final_loss, params)  # compute gradients
-            updates, opt_state = optimizer.update(grads, opt_state)  # get updates
-            params = torchopt.apply_updates(params, updates)
+            updates, opt_state = optimizer.update(grads, opt_state, inplace=True)  # get updates
+            params = torchopt.apply_updates(params, updates, inplace=True)
         return params, (0, {'a': 1, 'b': 2})
 
-    def imaml_objective_jax(optimal_params, init_params, x, y):
-        y_pred = jax_model(optimal_params, x)
+    def imaml_objective_jax(params, meta_params, x, y):
+        y_pred = jax_model(params, x)
         loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(y_pred, y))
         regularization_loss = 0
-        for p1, p2 in zip(optimal_params.values(), init_params.values()):
+        for p1, p2 in zip(params.values(), meta_params.values()):
             regularization_loss += 0.5 * jnp.sum(jnp.square((p1 - p2)))
         loss = loss + regularization_loss
         return loss
 
     @jaxopt.implicit_diff.custom_root(jax.grad(imaml_objective_jax, argnums=0), has_aux=True)
-    def inner_solver_jax(init_params_copy, init_params, x, y):
+    def inner_solver_jax(params, meta_params, x, y):
         """Solve ridge regression by conjugate gradient."""
         # Initial functional optimizer based on torchopt
-        params = init_params_copy
         optimizer = optax.sgd(inner_lr)
         opt_state = optimizer.init(params)
 
-        def compute_loss(params, init_params, x, y):
+        def compute_loss(params, meta_params, x, y):
             pred = jax_model(params, x)
             loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(pred, y))
             # Compute regularization loss
             regularization_loss = 0
-            for p1, p2 in zip(params.values(), init_params.values()):
+            for p1, p2 in zip(params.values(), meta_params.values()):
                 regularization_loss += 0.5 * jnp.sum(jnp.square((p1 - p2)))
             final_loss = loss + regularization_loss
             return final_loss
 
         for i in range(inner_update):
-            grads = jax.grad(compute_loss)(params, init_params, x, y)  # compute gradients
+            grads = jax.grad(compute_loss)(params, meta_params, x, y)  # compute gradients
             updates, opt_state = optimizer.update(grads, opt_state)  # get updates
             params = optax.apply_updates(params, updates)
         return params, (0, {'a': 1, 'b': 2})
@@ -195,10 +193,10 @@ def compute_loss(params, init_params, x, y):
     for xs, ys in loader:
         xs = xs.to(dtype=dtype)
         data = (xs, ys, fmodel)
-        init_params_copy = pytree.tree_map(
+        meta_params_copy = pytree.tree_map(
             lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), params
         )
-        optimal_params, aux = inner_solver_torchopt(init_params_copy, params, data)
+        optimal_params, aux = inner_solver_torchopt(meta_params_copy, params, data)
         assert aux == (0, {'a': 1, 'b': 2})
 
         outer_loss = fmodel(optimal_params, xs).mean()
@@ -275,35 +273,34 @@ def solve(self, x, y):
     optim_jax = optax.sgd(lr)
     optim_state_jax = optim_jax.init(jax_params)
 
-    def imaml_objective_jax(optimal_params, init_params, x, y):
-        y_pred = jax_model(optimal_params, x)
+    def imaml_objective_jax(params, meta_params, x, y):
+        y_pred = jax_model(params, x)
         loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(y_pred, y))
         regularization_loss = 0
-        for p1, p2 in zip(optimal_params.values(), init_params.values()):
+        for p1, p2 in zip(params.values(), meta_params.values()):
             regularization_loss += 0.5 * jnp.sum(jnp.square((p1 - p2)))
         loss = loss + regularization_loss
         return loss
 
     @jaxopt.implicit_diff.custom_root(jax.grad(imaml_objective_jax, argnums=0), has_aux=True)
-    def inner_solver_jax(init_params_copy, init_params, x, y):
+    def inner_solver_jax(params, meta_params, x, y):
         """Solve ridge regression by conjugate gradient."""
         # Initial functional optimizer based on torchopt
-        params = init_params_copy
         optimizer = optax.sgd(inner_lr)
         opt_state = optimizer.init(params)
 
-        def compute_loss(params, init_params, x, y):
+        def compute_loss(params, meta_params, x, y):
             pred = jax_model(params, x)
             loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(pred, y))
             # Compute regularization loss
             regularization_loss = 0
-            for p1, p2 in zip(params.values(), init_params.values()):
+            for p1, p2 in zip(params.values(), meta_params.values()):
                 regularization_loss += 0.5 * jnp.sum(jnp.square((p1 - p2)))
             final_loss = loss + regularization_loss
             return final_loss
 
         for i in range(inner_update):
-            grads = jax.grad(compute_loss)(params, init_params, x, y)  # compute gradients
+            grads = jax.grad(compute_loss)(params, meta_params, x, y)  # compute gradients
             updates, opt_state = optimizer.update(grads, opt_state)  # get updates
             params = optax.apply_updates(params, updates)
         return params, (0, {'a': 1, 'b': 2})
@@ -374,7 +371,7 @@ def ridge_objective_torch(params, l2reg, data):
         return 0.5 * torch.mean(torch.square(residuals)) + regularization_loss
 
     @torchopt.diff.implicit.custom_root(functorch.grad(ridge_objective_torch, argnums=0), argnums=1)
-    def ridge_solver_torch(init_params, l2reg, data):
+    def ridge_solver_torch(params, l2reg, data):
         """Solve ridge regression by conjugate gradient."""
         X_tr, y_tr = data
 
@@ -383,7 +380,7 @@ def matvec(u):
         solve = torchopt.linear_solve.solve_cg(
             ridge=len(y_tr) * l2reg.item(),
-            init=init_params,
+            init=params,
             maxiter=20,
         )
 
@@ -396,7 +393,7 @@ def ridge_objective_jax(params, l2reg, X_tr, y_tr):
         return 0.5 * jnp.mean(jnp.square(residuals)) + regularization_loss
 
     @jaxopt.implicit_diff.custom_root(jax.grad(ridge_objective_jax, argnums=0))
-    def ridge_solver_jax(init_params, l2reg, X_tr, y_tr):
+    def ridge_solver_jax(params, l2reg, X_tr, y_tr):
         """Solve ridge regression by conjugate gradient."""
 
         def matvec(u):
@@ -406,7 +403,7 @@ def matvec(u):
             matvec=matvec,
             b=X_tr.T @ y_tr,
             ridge=len(y_tr) * l2reg.item(),
-            init=init_params,
+            init=params,
             maxiter=20,
         )
 
@@ -428,8 +425,8 @@ def matvec(u):
     xq = jnp.array(xq.numpy(), dtype=np_dtype)
     yq = jnp.array(yq.numpy(), dtype=np_dtype)
 
-    def outer_level(init_params_jax, l2reg_jax, xs, ys, xq, yq):
-        w_fit = ridge_solver_jax(init_params_jax, l2reg_jax, xs, ys)
+    def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq):
+        w_fit = ridge_solver_jax(params_jax, l2reg_jax, xs, ys)
         y_pred = xq @ w_fit
         loss_value = jnp.mean(jnp.square(y_pred - yq))
         return loss_value
diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py
index 4ad48d32..47527b1a 100644
--- a/torchopt/diff/implicit/nn/module.py
+++ b/torchopt/diff/implicit/nn/module.py
@@ -14,9 +14,12 @@
 # ==============================================================================
 """The base class for differentiable implicit meta-gradient models."""
 
+# pylint: disable=redefined-builtin
+
+import contextlib
 import functools
 import itertools
-from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, Type, Union
 
 import functorch
 import torch
@@ -31,19 +34,43 @@
 __all__ = ['ImplicitMetaGradientModule']
 
 
+def update_containers(
+    dst_containers: Iterable[Dict[str, Optional[torch.Tensor]]],
+    src_containers: Iterable[Dict[str, Optional[torch.Tensor]]],
+) -> None:
+    """Update the tensor containers in ``dst_containers`` with the ones in ``src_containers``."""
+    for src_container, dst_container in zip(src_containers, dst_containers):
+        dst_container.update(src_container)
+
+
+@contextlib.contextmanager
+def container_context(
+    orig_containers: Iterable[Dict[str, Optional[torch.Tensor]]],
+    args_containers: Iterable[Dict[str, Optional[torch.Tensor]]],
+) -> Generator[None, None, None]:
+    # pylint: disable-next=line-too-long
+    """A context manager that temporarily updates the containers in ``orig_containers`` with the ones in ``args_containers``."""
+    if not isinstance(orig_containers, (list, tuple)):
+        orig_containers = list(orig_containers)
+    orig_containers_backups = [container.copy() for container in orig_containers]
+    try:
+        update_containers(orig_containers, args_containers)
+        yield
+    finally:
+        update_containers(orig_containers, orig_containers_backups)
+
+
 def make_optimality_from_objective(
     objective: Callable[..., torch.Tensor]
 ) -> Callable[..., TupleOfTensors]:
     """Make a function that computes the optimality function of the objective function."""
-    # pylint: disable-next=redefined-builtin
+
     def optimality(self: 'ImplicitMetaGradientModule', *input, **kwargs) -> TupleOfTensors:
         params_containers = extract_module_containers(self, with_buffers=False)[0]
-        params_containers_backups = [container.copy() for container in params_containers]
 
         flat_params: TupleOfTensors
         # pylint: disable-next=line-too-long
         flat_params, params_containers_treespec = pytree.tree_flatten_as_tuple(params_containers)  # type: ignore[arg-type]
 
-        # pylint: disable-next=redefined-builtin
         def objective_fn(flat_params: TupleOfTensors, *input, **kwargs) -> torch.Tensor:
             flat_grad_tracking_params = flat_params
             grad_tracking_params_containers: Tuple[
@@ -52,18 +79,8 @@ def objective_fn(flat_params: TupleOfTensors, *input, **kwargs) -> torch.Tensor:
                 params_containers_treespec, flat_grad_tracking_params
             )
 
-            try:
-                for container, grad_tracking_container in zip(
-                    params_containers, grad_tracking_params_containers
-                ):
-                    container.update(grad_tracking_container)
-
+            with container_context(params_containers, grad_tracking_params_containers):
                 return objective(self, *input, **kwargs)
-            finally:
-                for container, container_backup in zip(
-                    params_containers, params_containers_backups
-                ):
-                    container.update(container_backup)
 
         objective_grad_fn = functorch.grad(objective_fn, argnums=0)
         flat_grads = objective_grad_fn(flat_params, *input, **kwargs)
@@ -87,7 +104,7 @@ def enable_implicit_gradients(
 
     @functools.wraps(cls_solve)
     def wrapped(  # pylint: disable=too-many-locals
-        self: 'ImplicitMetaGradientModule', *input, **kwargs  # pylint: disable=redefined-builtin
+        self: 'ImplicitMetaGradientModule', *input, **kwargs
     ) -> Union['ImplicitMetaGradientModule', Tuple['ImplicitMetaGradientModule', Any]]:
         """Solve the optimization problem."""
         params_containers = extract_module_containers(self, with_buffers=False)[0]
@@ -97,10 +114,6 @@ def wrapped(  # pylint: disable=too-many-locals
                 extract_module_containers(meta_module, with_buffers=False)[0]
             )
         meta_params_containers = tuple(meta_params_containers)
-        params_containers_backups = tuple(container.copy() for container in params_containers)
-        meta_params_containers_backups = tuple(
-            container.copy() for container in meta_params_containers
-        )
 
         flat_params: TupleOfTensors
         flat_meta_params: TupleOfTensors
@@ -114,7 +127,7 @@ def wrapped(  # pylint: disable=too-many-locals
         def optimality_fn(
             flat_params: TupleOfTensors,
             flat_meta_params: TupleOfTensors,
-            *input,  # pylint: disable=redefined-builtin
+            *input,
             **kwargs,
         ) -> TupleOfTensors:
             flat_grad_tracking_params = flat_params
@@ -130,26 +143,23 @@ def optimality_fn(
                 meta_params_containers_treespec, flat_grad_tracking_meta_params
             )
 
-            try:
-                for container, grad_tracking_container in itertools.chain(
-                    zip(params_containers, grad_tracking_params_containers),
-                    zip(meta_params_containers, grad_tracking_meta_params_containers),
-                ):
-                    container.update(grad_tracking_container)
-
+            with container_context(
+                itertools.chain(
+                    params_containers,
+                    meta_params_containers,
+                ),
+                itertools.chain(
+                    grad_tracking_params_containers,
+                    grad_tracking_meta_params_containers,
+                ),
+            ):
                 return self.optimality(*input, **kwargs)
-            finally:
-                for container, container_backup in itertools.chain(
-                    zip(params_containers, params_containers_backups),
-                    zip(meta_params_containers, meta_params_containers_backups),
-                ):
-                    container.update(container_backup)
 
         @custom_root(optimality_fn, argnums=1, **custom_root_kwargs)  # type: ignore[arg-type]
         def solver_fn(
             flat_params: TupleOfTensors,  # pylint: disable=unused-argument
             flat_meta_params: TupleOfTensors,  # pylint: disable=unused-argument
-            *input,  # pylint: disable=redefined-builtin
+            *input,
             **kwargs,
         ) -> Union[TupleOfTensors, Tuple[TupleOfTensors, Any]]:
             output = cls_solve(self, *input, **kwargs)
@@ -227,7 +237,7 @@ def __init_subclass__(
         enable_implicit_gradients(cls)
 
     def solve(
-        self, *input, **kwargs  # pylint: disable=redefined-builtin
+        self, *input, **kwargs
     ) -> Union['ImplicitMetaGradientModule', Tuple['ImplicitMetaGradientModule', Any]]:
         """Solves the inner optimization problem.
@@ -259,7 +269,6 @@ def solve(self, batch, labels):
         """
         raise NotImplementedError  # update parameters
 
-    # pylint: disable-next=redefined-builtin
     def optimality(self, *input, **kwargs) -> TensorTree:
         r"""Computes the optimality residual.
@@ -302,7 +311,6 @@ def optimality(self, *input, **kwargs) -> TensorTree:
         """  # pylint: disable=line-too-long
         raise NotImplementedError
 
-    # pylint: disable-next=redefined-builtin
     def objective(self, *input, **kwargs) -> torch.Tensor:
         """Computes the objective function value.
diff --git a/torchopt/optim/meta/adam.py b/torchopt/optim/meta/adam.py
index 8d934e2c..9340b513 100644
--- a/torchopt/optim/meta/adam.py
+++ b/torchopt/optim/meta/adam.py
@@ -37,7 +37,7 @@ class MetaAdam(MetaOptimizer):
     # pylint: disable-next=too-many-arguments
     def __init__(
         self,
-        net: nn.Module,
+        module: nn.Module,
         lr: ScalarOrSchedule = 1e-3,
         betas: Tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
@@ -51,7 +51,7 @@ def __init__(
         """The :meth:`init` function.
 
         Args:
-            net: (nn.Module)
+            module: (nn.Module)
                 A network whose parameters should be optimized.
             lr: (default: :const:`1e-3`)
                 This is a fixed global scaling factor.
@@ -75,7 +75,7 @@ def __init__(
                 If :data:`True` use our implemented fused operator.
         """
         super().__init__(
-            net,
+            module,
             alias.adam(
                 lr=lr,
                 betas=betas,
diff --git a/torchopt/optim/meta/adamw.py b/torchopt/optim/meta/adamw.py
index cb91f38f..70f3a80a 100644
--- a/torchopt/optim/meta/adamw.py
+++ b/torchopt/optim/meta/adamw.py
@@ -37,7 +37,7 @@ class MetaAdamW(MetaOptimizer):
     # pylint: disable-next=too-many-arguments
     def __init__(
         self,
-        net: nn.Module,
+        module: nn.Module,
         lr: ScalarOrSchedule = 1e-3,
         betas: Tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
@@ -52,7 +52,7 @@ def __init__(
         """The :meth:`init` function.
 
         Args:
-            net: (nn.Module)
+            module: (nn.Module)
                 A network whose parameters should be optimized.
             lr: (default: :const:`1e-3`)
                 This is a fixed global scaling factor.
@@ -85,7 +85,7 @@ def __init__(
                 If :data:`True` use our implemented fused operator.
         """
         super().__init__(
-            net,
+            module,
             alias.adamw(
                 lr=lr,
                 betas=betas,
diff --git a/torchopt/optim/meta/base.py b/torchopt/optim/meta/base.py
index b83d3fd2..7af32e17 100644
--- a/torchopt/optim/meta/base.py
+++ b/torchopt/optim/meta/base.py
@@ -14,7 +14,7 @@
 # ==============================================================================
 """The base class for differentiable meta-optimizers."""
 
-from typing import Dict, List, Optional, Sequence, Tuple, cast
+from typing import Dict, List, Optional, Sequence, Tuple
 
 import torch
 import torch.nn as nn
@@ -31,11 +31,11 @@
 class MetaOptimizer:
     """The base class for high-level differentiable optimizers."""
 
-    def __init__(self, net: nn.Module, impl: GradientTransformation):
+    def __init__(self, module: nn.Module, impl: GradientTransformation) -> None:
         """The :meth:`init` function.
 
         Args:
-            net: (nn.Module)
+            module: (nn.Module)
                 A network whose parameters should be optimized.
             impl: (GradientTransformation)
                 A low level optimizer function, it could be a optimizer function provided by
@@ -51,9 +51,9 @@ def __init__(self, net: nn.Module, impl: GradientTransformation):
         self.param_containers_groups: List[Tuple[Dict[str, Optional[torch.Tensor]], ...]] = []
         self.state_groups: List[OptState] = []
 
-        self.add_param_group(net)
+        self.add_param_group(module)
 
-    def step(self, loss: torch.Tensor):  # pylint: disable=too-many-locals
+    def step(self, loss: torch.Tensor) -> None:  # pylint: disable=too-many-locals
         """Compute the gradients of the loss to the network parameters and update network parameters.
 
         Graph of the derivative will be constructed, allowing to compute higher order derivative
@@ -84,16 +84,17 @@ def step(self, loss: torch.Tensor):  # pylint: disable=too-many-locals
             )
             self.state_groups[i] = new_state
             flat_new_params = apply_updates(flat_params, updates, inplace=False)
-            new_params = cast(
-                Tuple[Dict[str, Optional[torch.Tensor]], ...],
-                pytree.tree_unflatten(container_treespec, flat_new_params),
+            new_params: Tuple[
+                Dict[str, Optional[torch.Tensor]], ...
+            ] = pytree.tree_unflatten(  # type: ignore[assignment]
+                container_treespec, flat_new_params
             )
             for container, new_param in zip(param_container, new_params):
                 container.update(new_param)
 
-    def add_param_group(self, net: nn.Module) -> None:
+    def add_param_group(self, module: nn.Module) -> None:
         """Add a param group to the optimizer's :attr:`state_groups`."""
-        params_container = extract_module_containers(net, with_buffers=False)[0]
+        params_container = extract_module_containers(module, with_buffers=False)[0]
         flat_params: TupleOfTensors = tuple(pytree.tree_leaves(params_container))  # type: ignore[arg-type]
         optimizer_state = self.impl.init(flat_params)
         self.param_containers_groups.append(params_container)
diff --git a/torchopt/optim/meta/rmsprop.py b/torchopt/optim/meta/rmsprop.py
index e7bc9b37..47c3e983 100644
--- a/torchopt/optim/meta/rmsprop.py
+++ b/torchopt/optim/meta/rmsprop.py
@@ -35,7 +35,7 @@ class MetaRMSProp(MetaOptimizer):
     # pylint: disable-next=too-many-arguments
     def __init__(
         self,
-        net: nn.Module,
+        module: nn.Module,
         lr: ScalarOrSchedule = 1e-2,
         alpha: float = 0.99,
         eps: float = 1e-8,
@@ -50,7 +50,7 @@ def __init__(
         """The :meth:`init` function.
 
         Args:
-            net: (nn.Module)
+            module: (nn.Module)
                 A network whose parameters should be optimized.
             lr: (default: :const:`1e-2`)
                 This is a fixed global scaling factor.
@@ -76,7 +76,7 @@ def __init__(
                 Maximize the params based on the objective, instead of minimizing.
         """
         super().__init__(
-            net,
+            module,
             alias.rmsprop(
                 lr=lr,
                 alpha=alpha,
diff --git a/torchopt/optim/meta/sgd.py b/torchopt/optim/meta/sgd.py
index 78b8e2fc..f46158a6 100644
--- a/torchopt/optim/meta/sgd.py
+++ b/torchopt/optim/meta/sgd.py
@@ -35,7 +35,7 @@ class MetaSGD(MetaOptimizer):
     # pylint: disable-next=too-many-arguments
     def __init__(
         self,
-        net: nn.Module,
+        module: nn.Module,
         lr: ScalarOrSchedule,
         momentum: float = 0.0,
         weight_decay: float = 0.0,
@@ -47,7 +47,7 @@ def __init__(
         """The :meth:`init` function.
 
         Args:
-            net: (nn.Module)
+            module: (nn.Module)
                 A network whose parameters should be optimized.
             lr: This is a fixed global scaling factor.
             momentum: (default: :const:`0.0`)
@@ -66,7 +66,7 @@ def __init__(
                 Maximize the params based on the objective, instead of minimizing.
""" super().__init__( - net, + module, alias.sgd( lr=lr, momentum=momentum, diff --git a/torchopt/utils.py b/torchopt/utils.py index cfe25e32..404dfee1 100644 --- a/torchopt/utils.py +++ b/torchopt/utils.py @@ -85,9 +85,10 @@ def stop_gradient(target: Union[TensorTree, ModuleState, nn.Module, 'MetaOptimiz # pylint: disable-next=import-outside-toplevel from torchopt.optim.meta.base import MetaOptimizer - def f(obj): + def fn_(obj): if isinstance(obj, torch.Tensor): - obj.detach_().requires_grad_(obj.requires_grad) + requires_grad = obj.requires_grad + obj.detach_().requires_grad_(requires_grad) if isinstance(target, ModuleState): true_target = cast(TensorTree, (target.params, target.buffers)) @@ -98,7 +99,7 @@ def f(obj): else: true_target = cast(TensorTree, target) # tree of tensors - pytree.tree_map(f, true_target) + pytree.tree_map(fn_, true_target) @overload diff --git a/tutorials/5_Implicit_Differentiation.ipynb b/tutorials/5_Implicit_Differentiation.ipynb index f52aceb1..c83e69b8 100644 --- a/tutorials/5_Implicit_Differentiation.ipynb +++ b/tutorials/5_Implicit_Differentiation.ipynb @@ -78,11 +78,11 @@ "outputs": [], "source": [ "# Optimality function\n", - "def imaml_objective(optimal_params, init_params, data):\n", + "def imaml_objective(params, meta_params, data):\n", " x, y, fmodel = data\n", - " y_pred = fmodel(optimal_params, x)\n", + " y_pred = fmodel(params, x)\n", " regularization_loss = 0.0\n", - " for p1, p2 in zip(optimal_params, init_params):\n", + " for p1, p2 in zip(params, meta_params):\n", " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", " loss = F.mse_loss(y_pred, y) + regularization_loss\n", " return loss\n", @@ -94,13 +94,12 @@ "# You can also set argnums as (1, 2) if you want to backpropogate through multiple meta parameters\n", "\n", "# Here we pass argnums=1 to the custom_root. That means we want to compute the gradient of\n", - "# optimal_params w.r.t. the 1-indexed argument in inner_solver, i.e., init_params.\n", + "# optimal_params w.r.t. the 1-indexed argument in inner_solver, i.e., params.\n", "@torchopt.diff.implicit.custom_root(functorch.grad(imaml_objective, argnums=0), argnums=1)\n", - "def inner_solver(init_params_copy, init_params, data):\n", + "def inner_solver(params, meta_params, data):\n", " \"\"\"Solve ridge regression by conjugate gradient.\"\"\"\n", " # Initial functional optimizer based on TorchOpt\n", " x, y, fmodel = data\n", - " params = init_params_copy\n", " optimizer = torchopt.sgd(lr=2e-2)\n", " opt_state = optimizer.init(params)\n", " with torch.enable_grad():\n", @@ -111,13 +110,13 @@ "\n", " # Compute regularization loss\n", " regularization_loss = 0.0\n", - " for p1, p2 in zip(params, init_params):\n", + " for p1, p2 in zip(params, meta_params):\n", " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", " final_loss = loss + regularization_loss\n", "\n", " grads = torch.autograd.grad(final_loss, params) # compute gradients\n", - " updates, opt_state = optimizer.update(grads, opt_state) # get updates\n", - " params = torchopt.apply_updates(params, updates)\n", + " updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates\n", + " params = torchopt.apply_updates(params, updates, inplace=True)\n", "\n", " optimal_params = params\n", " return optimal_params"
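The systematic rename from `(init_params_copy, init_params, ...)` to `(params, meta_params, ...)` is more than cosmetic: `custom_root(..., argnums=1)` differentiates the solver's output with respect to the argument at index 1, so the convention is that the mutable working copy comes first and the meta-parameters being differentiated come second. A condensed sketch of the pattern shared by the example, the tests, and the tutorial; the function names and the five-step loop are illustrative, while the decorator and optimizer calls follow the patch:

```python
import functorch
import torch
import torch.nn.functional as F
import torchopt


def objective(params, meta_params, data):
    x, y, fnet = data
    loss = F.cross_entropy(fnet(params, x), y)
    for p1, p2 in zip(params, meta_params):  # proximal regularization
        loss = loss + 0.5 * torch.sum(torch.square(p1 - p2))
    return loss


# argnums=1: the returned solution is differentiated w.r.t. `meta_params`.
@torchopt.diff.implicit.custom_root(functorch.grad(objective, argnums=0), argnums=1)
def inner_solver(params, meta_params, data):
    optimizer = torchopt.sgd(lr=1e-1)
    opt_state = optimizer.init(params)
    with torch.enable_grad():
        for _ in range(5):
            grads = torch.autograd.grad(objective(params, meta_params, data), params)
            # `inplace=True` mutates the caller's cloned tensors in place
            # instead of allocating new parameter tensors on every inner step.
            updates, opt_state = optimizer.update(grads, opt_state, inplace=True)
            params = torchopt.apply_updates(params, updates, inplace=True)
    return params


# Callers clone the meta-parameters first, exactly as the patched example does:
#     init_params = tuple(
#         t.clone().detach_().requires_grad_(t.requires_grad) for t in meta_params
#     )
#     optimal_params = inner_solver(init_params, meta_params, (x, y, fnet))
```

Because the clone now happens at the call site, and in `train()` it happens inside the per-task loop, each task adapts from a fresh copy of the current meta-parameters that the in-place inner updates are free to overwrite.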
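Finally, the `torchopt/diff/implicit/nn/module.py` refactoring deduplicates the swap-and-restore bookkeeping behind the new `container_context` helper. Condensed to its essence (the patch additionally keeps `update_containers` as a named helper and fast-paths list/tuple inputs):

```python
# Standalone sketch of the pattern `container_context` factors out.
import contextlib
from typing import Dict, Generator, Iterable, Optional

import torch


@contextlib.contextmanager
def container_context(
    orig_containers: Iterable[Dict[str, Optional[torch.Tensor]]],
    args_containers: Iterable[Dict[str, Optional[torch.Tensor]]],
) -> Generator[None, None, None]:
    orig_containers = list(orig_containers)  # may be a lazy `itertools.chain`
    backups = [container.copy() for container in orig_containers]
    try:
        for dst, src in zip(orig_containers, args_containers):
            dst.update(src)  # temporarily expose the gradient-tracking tensors
        yield
    finally:
        for dst, backup in zip(orig_containers, backups):
            dst.update(backup)  # always restore the original tensors
```

Centralizing the backup in the context manager removes the separate `*_containers_backups` variables and guarantees the module's containers are restored even if `objective()` or `optimality()` raises, which the two hand-written try/finally blocks previously had to get right independently.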