diff --git a/CHANGELOG.md b/CHANGELOG.md index 74d23144..3a3ad174 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Use postponed evaluation of annotations and update doctring style by [@XuehaiPan](https://github.com/XuehaiPan) in [#135](https://github.com/metaopt/torchopt/pull/135). - Rewrite setup CUDA Toolkit logic by [@XuehaiPan](https://github.com/XuehaiPan) in [#133](https://github.com/metaopt/torchopt/pull/133). ### Fixed diff --git a/README.md b/README.md index c1fb97ba..321f39e3 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,6 @@ ![CodeCov](https://img.shields.io/codecov/c/gh/metaopt/torchopt) ![Documentation Status](https://img.shields.io/readthedocs/torchopt?logo=readthedocs) ![Downloads](https://static.pepy.tech/personalized-badge/torchopt?period=total&left_color=grey&right_color=blue&left_text=downloads) - ![GitHub Repo Stars](https://img.shields.io/github/stars/metaopt/torchopt?color=brightgreen&logo=github) ![License](https://img.shields.io/github/license/metaopt/torchopt?label=license&logo=) diff --git a/docs/source/api/api.rst b/docs/source/api/api.rst index c7e04e95..b2866407 100644 --- a/docs/source/api/api.rst +++ b/docs/source/api/api.rst @@ -285,6 +285,115 @@ Chain .. autofunction:: chain +Distributed Utilities +===================== + +.. currentmodule:: torchopt.distributed + +Initialization and Synchronization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + auto_init_rpc + barrier + +.. autofunction:: auto_init_rpc +.. autofunction:: barrier + +Process group information +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + get_world_info + get_world_rank + get_rank + get_world_size + get_local_rank + get_local_world_size + get_worker_id + +.. autofunction:: get_world_info +.. autofunction:: get_world_rank +.. autofunction:: get_rank +.. autofunction:: get_world_size +.. autofunction:: get_local_rank +.. autofunction:: get_local_world_size +.. autofunction:: get_worker_id + +Worker selection +~~~~~~~~~~~~~~~~ + +.. autosummary:: + + on_rank + not_on_rank + rank_zero_only + rank_non_zero_only + +.. autofunction:: on_rank +.. autofunction:: not_on_rank +.. autofunction:: rank_zero_only +.. autofunction:: rank_non_zero_only + +Remote Procedure Call (RPC) +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + remote_async_call + remote_sync_call + +.. autofunction:: remote_async_call +.. autofunction:: remote_sync_call + +Predefined partitioners and reducers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + dim_partitioner + batch_partitioner + mean_reducer + sum_reducer + +.. autofunction:: dim_partitioner +.. autofunction:: batch_partitioner +.. autofunction:: mean_reducer +.. autofunction:: sum_reducer + +Function parallelization wrappers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + parallelize + parallelize_async + parallelize_sync + +.. autofunction:: parallelize +.. autofunction:: parallelize_async +.. autofunction:: parallelize_sync + +Distributed Autograd +~~~~~~~~~~~~~~~~~~~~ + +.. currentmodule:: torchopt.distributed.autograd + +.. autosummary:: + + context + get_gradients + backward + grad + +.. autofunction:: context +.. autofunction:: get_gradients +.. autofunction:: backward +.. autofunction:: grad + + General Utilities ================= diff --git a/docs/source/distributed/distributed.rst b/docs/source/distributed/distributed.rst index f85eec3f..b6f00951 100644 --- a/docs/source/distributed/distributed.rst +++ b/docs/source/distributed/distributed.rst @@ -142,7 +142,6 @@ Initialization and Synchronization .. autosummary:: - torchopt.distributed.auto_init_rpc torchopt.distributed.barrier @@ -197,7 +196,6 @@ Process group information .. autosummary:: - torchopt.distributed.get_world_info torchopt.distributed.get_world_rank torchopt.distributed.get_rank @@ -228,7 +226,6 @@ Worker selection .. autosummary:: - torchopt.distributed.on_rank torchopt.distributed.not_on_rank torchopt.distributed.rank_zero_only @@ -275,7 +272,6 @@ Remote Procedure Call (RPC) .. autosummary:: - torchopt.distributed.remote_async_call torchopt.distributed.remote_sync_call @@ -354,7 +350,6 @@ Predefined partitioners and reducers .. autosummary:: - torchopt.distributed.dim_partitioner torchopt.distributed.batch_partitioner torchopt.distributed.mean_reducer @@ -439,7 +434,6 @@ Function parallelization wrappers .. autosummary:: - torchopt.distributed.parallelize torchopt.distributed.parallelize_async torchopt.distributed.parallelize_sync @@ -490,7 +484,6 @@ Distributed Autograd .. autosummary:: - torchopt.distributed.autograd.context torchopt.distributed.autograd.get_gradients torchopt.distributed.autograd.backward diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt index 8f9d6895..aac17046 100644 --- a/docs/source/spelling_wordlist.txt +++ b/docs/source/spelling_wordlist.txt @@ -171,3 +171,4 @@ issubclass abc ABCMeta subclasscheck +ctx diff --git a/tests/helpers.py b/tests/helpers.py index 4bba706e..23e178f0 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -13,11 +13,13 @@ # limitations under the License. # ============================================================================== +from __future__ import annotations + import copy import itertools import os import random -from typing import Iterable, Optional, Tuple, Union +from typing import Iterable import numpy as np import pytest @@ -137,7 +139,7 @@ def get_model(): @torch.no_grad() def get_models( device: torch.types.Device = None, dtype: torch.dtype = torch.float32 -) -> Tuple[nn.Module, nn.Module, nn.Module, data.DataLoader]: +) -> tuple[nn.Module, nn.Module, nn.Module, data.DataLoader]: seed_everything(seed=42) model_base = get_model().to(dtype=dtype) @@ -166,12 +168,12 @@ def get_models( @torch.no_grad() def assert_model_all_close( - model: Union[nn.Module, Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]], + model: nn.Module | tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]], model_ref: nn.Module, model_base: nn.Module, dtype: torch.dtype = torch.float32, - rtol: Optional[float] = None, - atol: Optional[float] = None, + rtol: float | None = None, + atol: float | None = None, equal_nan: bool = False, ) -> None: if isinstance(model, tuple): @@ -194,8 +196,8 @@ def assert_all_close( actual: torch.Tensor, expected: torch.Tensor, base: torch.Tensor = None, - rtol: Optional[float] = None, - atol: Optional[float] = None, + rtol: float | None = None, + atol: float | None = None, equal_nan: bool = False, ) -> None: if base is not None: @@ -223,9 +225,9 @@ def assert_all_close( def assert_pytree_all_close( actual: TensorTree, expected: TensorTree, - base: Optional[TensorTree] = None, - rtol: Optional[float] = None, - atol: Optional[float] = None, + base: TensorTree | None = None, + rtol: float | None = None, + atol: float | None = None, equal_nan: bool = False, ) -> None: actual_leaves, actual_treespec = pytree.tree_flatten(actual) diff --git a/tests/test_alias.py b/tests/test_alias.py index c613d7d5..b609cf58 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -13,7 +13,9 @@ # limitations under the License. # ============================================================================== -from typing import Callable, Tuple +from __future__ import annotations + +from typing import Callable import functorch import pytest @@ -107,7 +109,7 @@ def test_sgd( def test_adam( dtype: torch.dtype, lr: float, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, inplace: bool, weight_decay: float, @@ -177,7 +179,7 @@ def test_maml_adam( outer_lr: float, inner_lr: float, inner_update: int, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, inplace: bool, weight_decay: float, @@ -263,7 +265,7 @@ def maml_inner_solver_torchopt(params, data, use_accelerated_op): def test_adamw( dtype: torch.dtype, lr: float, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, inplace: bool, weight_decay: float, @@ -333,8 +335,8 @@ def test_adamw( def test_adam_accelerated_cuda( dtype: torch.dtype, lr: float, - optimizers: Tuple[Callable, torch.optim.Optimizer], - betas: Tuple[float, float], + optimizers: tuple[Callable, torch.optim.Optimizer], + betas: tuple[float, float], eps: float, inplace: bool, weight_decay: float, diff --git a/tests/test_implicit.py b/tests/test_implicit.py index ce0ee23b..9e3722d3 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -13,10 +13,11 @@ # limitations under the License. # ============================================================================== +from __future__ import annotations + import copy from collections import OrderedDict from types import FunctionType -from typing import Tuple import functorch import jax @@ -55,7 +56,7 @@ def forward(self, x): return self.fc(x) -def get_model_jax(dtype: np.dtype = np.float32) -> Tuple[FunctionType, OrderedDict]: +def get_model_jax(dtype: np.dtype = np.float32) -> tuple[FunctionType, OrderedDict]: helpers.seed_everything(seed=42) def func(params, x): @@ -73,7 +74,7 @@ def func(params, x): @torch.no_grad() def get_model_torch( device: torch.types.Device = None, dtype: torch.dtype = torch.float32 -) -> Tuple[nn.Module, data.DataLoader]: +) -> tuple[nn.Module, data.DataLoader]: helpers.seed_everything(seed=42) model = FcNet(MODEL_NUM_INPUTS, MODEL_NUM_CLASSES).to(dtype=dtype) diff --git a/tests/test_meta_optim.py b/tests/test_meta_optim.py index 2c0966cc..61f8a7ad 100644 --- a/tests/test_meta_optim.py +++ b/tests/test_meta_optim.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -from typing import Tuple +from __future__ import annotations import torch import torch.nn.functional as F @@ -40,7 +40,7 @@ def test_maml_meta_adam( outer_lr: float, inner_lr: float, inner_update: int, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, eps_root: float, weight_decay: float, diff --git a/tests/test_optim.py b/tests/test_optim.py index c43bc438..b2be7500 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -13,7 +13,9 @@ # limitations under the License. # ============================================================================== -from typing import Callable, Tuple +from __future__ import annotations + +from typing import Callable import functorch import pytest @@ -96,7 +98,7 @@ def test_SGD( def test_Adam( dtype: torch.dtype, lr: float, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, weight_decay: float, maximize: bool, @@ -154,7 +156,7 @@ def test_Adam( def test_AdamW( dtype: torch.dtype, lr: float, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, weight_decay: float, maximize: bool, @@ -216,8 +218,8 @@ def test_AdamW( def test_Adam_accelerated_cuda( dtype: torch.dtype, lr: float, - optimizers: Tuple[torchopt.Optimizer, torch.optim.Optimizer], - betas: Tuple[float, float], + optimizers: tuple[torchopt.Optimizer, torch.optim.Optimizer], + betas: tuple[float, float], eps: float, weight_decay: float, maximize: bool, @@ -339,7 +341,7 @@ def test_RMSProp( def test_FuncOptimizer( dtype: torch.dtype, lr: float, - optimizers: Tuple[Callable, torch.optim.Optimizer], + optimizers: tuple[Callable, torch.optim.Optimizer], inplace: bool, weight_decay: float, ) -> None: diff --git a/tests/test_schedule.py b/tests/test_schedule.py index 9590acf8..ae714875 100644 --- a/tests/test_schedule.py +++ b/tests/test_schedule.py @@ -13,7 +13,9 @@ # limitations under the License. # ============================================================================== -from typing import Callable, Tuple +from __future__ import annotations + +from typing import Callable import functorch import numpy as np @@ -62,7 +64,7 @@ def test_lr_linear_schedule( dtype: torch.dtype, lr: float, total_iters: int, - optimizers: Tuple[Callable, torch.optim.Optimizer], + optimizers: tuple[Callable, torch.optim.Optimizer], inplace: bool, weight_decay: float, use_chain_flat: bool, diff --git a/tests/test_transform.py b/tests/test_transform.py index 4dfd034d..9598386d 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -13,13 +13,8 @@ # limitations under the License. # ============================================================================== -from typing import Tuple - -import functorch import torch -import torch.nn.functional as F -import helpers import torchopt diff --git a/torchopt/_C/adam_op.pyi b/torchopt/_C/adam_op.pyi index bc3e8ebc..7ecfe7c2 100644 --- a/torchopt/_C/adam_op.pyi +++ b/torchopt/_C/adam_op.pyi @@ -15,7 +15,7 @@ # pylint: disable=all -from typing import Tuple +from __future__ import annotations import torch @@ -28,7 +28,7 @@ def forward_( eps: float, eps_root: float, count: int, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... def forward_mu(updates: torch.Tensor, mu: torch.Tensor, b1: float) -> torch.Tensor: ... def forward_nu(updates: torch.Tensor, nu: torch.Tensor, b2: float) -> torch.Tensor: ... def forward_updates( @@ -42,10 +42,10 @@ def forward_updates( ) -> torch.Tensor: ... def backward_mu( dmu: torch.Tensor, updates: torch.Tensor, mu: torch.Tensor, b1: float -) -> Tuple[torch.Tensor, torch.Tensor]: ... +) -> tuple[torch.Tensor, torch.Tensor]: ... def backward_nu( dnu: torch.Tensor, updates: torch.Tensor, nu: torch.Tensor, b2: float -) -> Tuple[torch.Tensor, torch.Tensor]: ... +) -> tuple[torch.Tensor, torch.Tensor]: ... def backward_updates( dupdates: torch.Tensor, updates: torch.Tensor, @@ -55,4 +55,4 @@ def backward_updates( b2: float, eps_root: float, count: int, -) -> Tuple[torch.Tensor, torch.Tensor]: ... +) -> tuple[torch.Tensor, torch.Tensor]: ... diff --git a/torchopt/accelerated_op/__init__.py b/torchopt/accelerated_op/__init__.py index 003a8a9f..ede60009 100644 --- a/torchopt/accelerated_op/__init__.py +++ b/torchopt/accelerated_op/__init__.py @@ -14,7 +14,9 @@ # ============================================================================== """The accelerated Ops.""" -from typing import Iterable, Optional, Union +from __future__ import annotations + +from typing import Iterable import torch @@ -22,7 +24,7 @@ from torchopt.typing import Device -def is_available(devices: Optional[Union[Device, Iterable[Device]]] = None) -> bool: +def is_available(devices: Device | Iterable[Device] | None = None) -> bool: """Check the availability of accelerated optimizer.""" op = AdamOp() diff --git a/torchopt/accelerated_op/_src/adam_op.py b/torchopt/accelerated_op/_src/adam_op.py index 9f801b8d..ab5ea195 100644 --- a/torchopt/accelerated_op/_src/adam_op.py +++ b/torchopt/accelerated_op/_src/adam_op.py @@ -16,7 +16,7 @@ # pylint: disable=invalid-name,too-many-arguments,unused-argument -from typing import Tuple +from __future__ import annotations import torch @@ -30,7 +30,7 @@ def forward_( eps: float, eps_root: float, count: int, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Adam forward inplace.""" mu = mu.mul_(b1).add_(updates, alpha=1.0 - b1) nu = nu.mul_(b2).addcmul_(updates, updates, value=1.0 - b2) @@ -80,7 +80,7 @@ def backward_mu( updates: torch.Tensor, mu: torch.Tensor, b1: float, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """Adam backward mu.""" dupdates = dmu.mul(1.0 - b1) dmu = dmu.mul(b1) @@ -92,7 +92,7 @@ def backward_nu( updates: torch.Tensor, nu: torch.Tensor, b2: float, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """Adam backward nu.""" dupdates = updates.mul(dnu).mul_(2.0 * (1.0 - b2)) dnu = dnu.mul(b2) @@ -108,7 +108,7 @@ def backward_updates( b2: float, eps_root: float, count: int, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """Adam backward updates.""" one_minus_pow_b1 = 1.0 - pow(b1, count) inv_one_minus_pow_b2 = 1.0 / (1.0 - pow(b2, count) + eps_root) diff --git a/torchopt/accelerated_op/adam_op.py b/torchopt/accelerated_op/adam_op.py index 6b93bf18..232513d6 100644 --- a/torchopt/accelerated_op/adam_op.py +++ b/torchopt/accelerated_op/adam_op.py @@ -16,8 +16,10 @@ # pylint: disable=c-extension-no-member,invalid-name +from __future__ import annotations + import contextlib -from typing import Any, Optional, Tuple +from typing import Any import torch @@ -132,9 +134,9 @@ def __call__( self, mu: torch.Tensor, nu: torch.Tensor, - updates: Optional[torch.Tensor], + updates: torch.Tensor | None, count: int, - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: """Apply the Adam operator.""" if updates is None: return mu, nu, None diff --git a/torchopt/alias/adam.py b/torchopt/alias/adam.py index a7f90a79..08654577 100644 --- a/torchopt/alias/adam.py +++ b/torchopt/alias/adam.py @@ -31,7 +31,7 @@ # ============================================================================== """Preset :class:`GradientTransformation` for the Adam optimizer.""" -from typing import Tuple +from __future__ import annotations from torchopt.alias.utils import ( _get_use_chain_flat, @@ -49,7 +49,7 @@ # pylint: disable-next=too-many-arguments def adam( lr: ScalarOrSchedule = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), + betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.0, *, @@ -68,26 +68,25 @@ def adam( - Kingma et al, 2014: https://arxiv.org/abs/1412.6980 Args: - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to avoid - dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - moment_requires_grad: (default: :data:`False`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + eps_root (float, optional): A small constant applied to denominator inside the square root + (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example + when computing (meta-)gradients through Adam. (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with + flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms. + (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of minimizing. + (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. + (default: :data:`False`) Returns: The corresponding :class:`GradientTransformation` instance. diff --git a/torchopt/alias/adamw.py b/torchopt/alias/adamw.py index 9aecc8ee..21ef84ef 100644 --- a/torchopt/alias/adamw.py +++ b/torchopt/alias/adamw.py @@ -31,7 +31,9 @@ # ============================================================================== """Preset :class:`GradientTransformation` for the AdamW optimizer.""" -from typing import Any, Callable, Optional, Tuple, Union +from __future__ import annotations + +from typing import Callable from torchopt.alias.utils import ( _get_use_chain_flat, @@ -40,7 +42,7 @@ ) from torchopt.combine import chain from torchopt.transform import add_decayed_weights, scale_by_accelerated_adam, scale_by_adam -from torchopt.typing import GradientTransformation, Params, ScalarOrSchedule +from torchopt.typing import GradientTransformation, OptState, Params, ScalarOrSchedule __all__ = ['adamw'] @@ -49,12 +51,12 @@ # pylint: disable-next=too-many-arguments,too-many-locals def adamw( lr: ScalarOrSchedule = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), + betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 1e-2, *, eps_root: float = 0.0, - mask: Optional[Union[Any, Callable[[Params], Any]]] = None, + mask: OptState | Callable[[Params], OptState] | None = None, moment_requires_grad: bool = False, maximize: bool = False, use_accelerated_op: bool = False, @@ -70,35 +72,34 @@ def adamw( - Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101 Args: - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`1e-2`) - Strength of the weight decay regularization. Note that this weight decay is multiplied - with the learning rate. This is consistent with other frameworks such as PyTorch, but - different from (Loshchilov et al, 2019) where the weight decay is only multiplied with - the "schedule multiplier", but not the base learning rate. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to avoid - dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - mask: (default: :data:`None`) - A tree with same structure as (or a prefix of) the params PyTree, or a Callable that + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with other + frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight + decay is only multiplied with the "schedule multiplier", but not the base learning rate. + (default: :const:`1e-2`) + eps_root (float, optional): A small constant applied to denominator inside the square root + (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example + when computing (meta-)gradients through Adam. (default: :const:`0.0`) + mask (tree of Tensor, callable, or None, optional): + A tree with same structure as (or a prefix of) the params pytree, or a function that returns such a pytree given the params/updates. The leaves should be booleans, :data:`True` for leaves/subtrees you want to apply the weight decay to, and - :data:`False` for those you want to skip. Note that the Adam gradient - transformations are applied to all parameters. - moment_requires_grad: (default: :data:`False`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. + :data:`False` for those you want to skip. Note that the Adam gradient transformations + are applied to all parameters. (default: :data:`None`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with + flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms. + (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. + (default: :data:`False`) Returns: The corresponding :class:`GradientTransformation` instance. diff --git a/torchopt/alias/rmsprop.py b/torchopt/alias/rmsprop.py index 18a5c5e8..f0eb92cd 100644 --- a/torchopt/alias/rmsprop.py +++ b/torchopt/alias/rmsprop.py @@ -69,28 +69,25 @@ def rmsprop( - Graves, 2013: https://arxiv.org/abs/1308.0850 Args: - lr: (default: :const:`1e-2`) - This is a fixed global scaling factor. - alpha: (default: :const:`0.99`) - Smoothing constant, the decay used to track the magnitude of previous gradients. - eps: (default: :const:`1e-8`) - A small numerical constant to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - centered: (default: :data:`False`) - If :data:`True`, use the variance of the past gradients to rescale the latest - gradients. - initial_scale: (default: :data:`0.0`) - Initialization of accumulators tracking the magnitude of previous updates. PyTorch - uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When reproducing results from a - paper, verify the value used by the authors. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-2`) + alpha (float, optional): Smoothing constant, the decay used to track the magnitude of + previous gradients. (default: :const:`0.99`) + eps (float, optional): A small numerical constant to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + momentum (float, optional): The decay rate used by the momentum term. The momentum is not + used when it is set to :const:`0.0`. (default: :const:`0.0`) + centered (bool, optional): If :data:`True`, use the variance of the past gradients to + rescale the latest gradients. (default: :data:`False`) + initial_scale (float, optional): Initialization of accumulators tracking the magnitude of + previous updates. PyTorch uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When + reproducing results from a paper, verify the value used by the authors. + (default: :data:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) Returns: The corresponding :class:`GradientTransformation` instance. diff --git a/torchopt/alias/sgd.py b/torchopt/alias/sgd.py index 61b3d6e4..7d86b538 100644 --- a/torchopt/alias/sgd.py +++ b/torchopt/alias/sgd.py @@ -64,21 +64,19 @@ def sgd( - Sutskever et al, 2013: http://proceedings.mlr.press/v28/sutskever13.pdf Args: - lr: This is a fixed global scaling factor. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - dampening: (default: :const:`0.0`) - Dampening for momentum. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - moment_requires_grad: (default: :data:`False`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. + lr (float or callable): This is a fixed global scaling factor or a learning rate + scheduler. + momentum (float, optional): The decay rate used by the momentum term. The momentum is not + used when it is set to :const:`0.0`. (default: :const:`0.0`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + dampening (float, optional): Dampening for momentum. (default: :const:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with + flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms. + (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) Returns: The corresponding :class:`GradientTransformation` instance. diff --git a/torchopt/alias/utils.py b/torchopt/alias/utils.py index 869aad87..b5088164 100644 --- a/torchopt/alias/utils.py +++ b/torchopt/alias/utils.py @@ -13,8 +13,9 @@ # limitations under the License. r"""Utilities for the aliases of preset :class:`GradientTransformation`\s for optimizers.""" +from __future__ import annotations + import threading -from typing import Optional, Tuple from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation, identity @@ -93,9 +94,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, + params: Params | None = None, inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: assert params is not None, ( 'Parameters are required for weight decay. ' 'Call `update(updates, state, params=params)` instead.' @@ -126,9 +127,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: if inplace: def f(g): @@ -151,9 +152,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, + params: Params | None = None, inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: assert params is not None, ( 'Parameters are required for weight decay. ' 'Call `update(updates, state, params=params)` instead.' diff --git a/torchopt/base.py b/torchopt/base.py index bb37b147..b250c387 100644 --- a/torchopt/base.py +++ b/torchopt/base.py @@ -31,9 +31,11 @@ # ============================================================================== """The base classes for gradient transformation.""" +from __future__ import annotations + import itertools from abc import abstractmethod -from typing import TYPE_CHECKING, Callable, NamedTuple, Optional, Tuple +from typing import TYPE_CHECKING, Callable, NamedTuple from typing_extensions import Protocol # Python 3.8+ @@ -67,12 +69,11 @@ class TransformInitFn(Protocol): # pylint: disable=too-few-public-methods """ @abstractmethod - def __call__(self, params: 'Params') -> 'OptState': + def __call__(self, params: Params) -> OptState: """Initialize the gradient transformation state. Args: - params: - The initial value of the parameters. + params (tree of Tensor): The initial value of the parameters. Returns: The initial state of the gradient transformation. @@ -93,21 +94,21 @@ class TransformUpdateFn(Protocol): # pylint: disable=too-few-public-methods @abstractmethod def __call__( self, - updates: 'Updates', - state: 'OptState', + updates: Updates, + state: OptState, *, - params: Optional['Params'] = None, + params: Params | None = None, inplace: bool = True, - ) -> Tuple['Updates', 'OptState']: + ) -> tuple[Updates, OptState]: """Transform the updates and state. Args: - updates: A tree of candidate updates. - state: The state of the gradient transformation. - params: (optional) - The current value of the parameters. - inplace: (optional) - If :data:`True`, modify updates and state using inplace operations. + updates (tree of Tensor): A tree of candidate updates. + state (tree of Tensor): The state of the gradient transformation. + params (tree of Tensor or None, optional): The current value of the parameters. + (default: :data:`None`) + inplace (bool, optional): If :data:`True`, modify updates and state using inplace + operations. (default: :data:`True`) Returns: The transformed ``updates``, and the updated ``state``. @@ -134,9 +135,9 @@ class GradientTransformation(NamedTuple): optimizer state. update: A pure function which takes as input a pytree of updates (with the same tree structure - as the original params ``pytree`` passed to :attr:`init`), the previous optimizer state - (which may have been initialized using the :attr:`init` function), and optionally the - ``inplace`` flag. The :attr:`update` function then returns the computed gradient + as the original params ``pytree`` passed to ``init``), the previous optimizer state + (which may have been initialized using the ``init`` function), and optionally the + ``inplace`` flag. The ``update`` function then returns the computed gradient updates, and a updates optimizer state. If the ``inplace`` flag is :data:`True`, the output results are the same instance as the input. """ @@ -145,7 +146,7 @@ class GradientTransformation(NamedTuple): update: TransformUpdateFn # pylint: disable-next=redefined-builtin - def chain(self, next: 'GradientTransformation') -> 'ChainedGradientTransformation': + def chain(self, next: GradientTransformation) -> ChainedGradientTransformation: """Chain two gradient transformations together.""" return ChainedGradientTransformation(self, next) @@ -157,9 +158,9 @@ class ChainedGradientTransformation(GradientTransformation): gradient transformations. """ - transformations: Tuple[GradientTransformation, ...] + transformations: tuple[GradientTransformation, ...] - def __new__(cls, *transformations: GradientTransformation) -> 'ChainedGradientTransformation': + def __new__(cls, *transformations: GradientTransformation) -> ChainedGradientTransformation: """Create a new chained gradient transformation.""" transformations = tuple( itertools.chain.from_iterable( @@ -175,16 +176,16 @@ def __new__(cls, *transformations: GradientTransformation) -> 'ChainedGradientTr init_fns, update_fns = tuple(zip(*transformations)) - def init_fn(params: 'Params') -> 'OptState': + def init_fn(params: Params) -> OptState: return tuple(fn(params) for fn in init_fns) def update_fn( - updates: 'Updates', - state: 'OptState', + updates: Updates, + state: OptState, *, - params: Optional['Params'] = None, + params: Params | None = None, inplace: bool = True, - ) -> Tuple['Updates', 'OptState']: + ) -> tuple[Updates, OptState]: if len(update_fns) != len(state): raise ValueError( 'The number of updates and states has to be the same in chain! Make sure you' @@ -219,15 +220,15 @@ def __hash__(self) -> int: """Return the hash of the chained gradient transformation.""" return hash(self.transformations) - def __getstate__(self) -> Tuple[GradientTransformation, ...]: + def __getstate__(self) -> tuple[GradientTransformation, ...]: """Return the state of the chained gradient transformation for serialization.""" return self.transformations - def __setstate__(self, state: Tuple[GradientTransformation, ...]) -> None: + def __setstate__(self, state: tuple[GradientTransformation, ...]) -> None: """Set the state of the chained gradient transformation from serialization.""" self.transformations = state - def __reduce__(self) -> Tuple[Callable, Tuple[Tuple[GradientTransformation, ...]]]: + def __reduce__(self) -> tuple[Callable, tuple[tuple[GradientTransformation, ...]]]: """Serialize the chained gradient transformation.""" return ChainedGradientTransformation, (self.transformations,) @@ -240,18 +241,18 @@ def __new__(cls): return super().__new__(cls, init=cls.init_fn, update=cls.update_fn) @staticmethod - def init_fn(params: 'Params') -> 'OptState': # pylint: disable=unused-argument + def init_fn(params: Params) -> OptState: # pylint: disable=unused-argument """Return empty state.""" return EmptyState() @staticmethod def update_fn( - updates: 'Updates', - state: 'OptState', + updates: Updates, + state: OptState, *, - params: Optional['Params'] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, # pylint: disable=unused-argument - ) -> Tuple['Updates', 'OptState']: + ) -> tuple[Updates, OptState]: """Return updates unchanged.""" return updates, state diff --git a/torchopt/clip.py b/torchopt/clip.py index 2469d17a..b2aafb48 100644 --- a/torchopt/clip.py +++ b/torchopt/clip.py @@ -17,7 +17,7 @@ # ============================================================================== """Utilities for gradient clipping.""" -from typing import Optional, Tuple, Union +from __future__ import annotations import torch @@ -33,18 +33,19 @@ def clip_grad_norm( - max_norm: Union[float, int], - norm_type: Union[float, int] = 2.0, + max_norm: float | int, + norm_type: float | int = 2.0, error_if_nonfinite: bool = False, ) -> GradientTransformation: """Clip gradient norm of an iterable of parameters. Args: max_norm (float or int): The maximum absolute value for each element in the update. - norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for - infinity norm. - error_if_nonfinite (bool): if :data:`True`, an error is thrown if the total norm of the - gradients from :attr:`updates` is ``nan``, ``inf``, or ``-inf``. + norm_type (float or int, optional): Type of the used p-norm. Can be ``'inf'`` for infinity + norm. (default: :const:`2.0`) + error_if_nonfinite (bool, optional): If :data:`True`, an error is thrown if the total norm + of the gradients from ``updates`` is ``nan``, ``inf``, or ``-inf``. + (default: :data:`False`) Returns: An ``(init_fn, update_fn)`` tuple. @@ -57,9 +58,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: available_updates = pytree.tree_leaves(updates) if len(available_updates) == 0: return updates, state diff --git a/torchopt/combine.py b/torchopt/combine.py index 82297426..0f1ed8ec 100644 --- a/torchopt/combine.py +++ b/torchopt/combine.py @@ -31,7 +31,7 @@ # ============================================================================== """Utilities to define a chained transformation.""" -from typing import Optional, Tuple +from __future__ import annotations from torchopt import pytree from torchopt.base import ChainedGradientTransformation, GradientTransformation, identity @@ -49,8 +49,8 @@ def chain(*transformations: GradientTransformation) -> GradientTransformation: :func:`update_fn` which chains the update transformations feeding the appropriate state to each. Args: - *transformations: - A sequence of chainable ``(init_fn, update_fn)`` tuples. + *transformations (iterable of GradientTransformation): A sequence of chainable + ``(init_fn, update_fn)`` tuples. Returns: A single ``(init_fn, update_fn)`` tuple. @@ -66,8 +66,8 @@ def chain_flat(*transformations: GradientTransformation) -> GradientTransformati """Wrap around the inner transformations that manipulate the flattened tree structure (:class:``list``). Args: - *transformations: - A sequence of chainable ``(init_fn, update_fn)`` tuples. + *transformations (iterable of GradientTransformation): A sequence of chainable + ``(init_fn, update_fn)`` tuples. Returns: A single ``(init_fn, update_fn)`` tuple. @@ -86,9 +86,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, + params: Params | None = None, inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: flat_updates, treespec = pytree.tree_flatten(updates, none_is_leaf=True) if params is not None: flat_params = pytree.tree_leaves(params, none_is_leaf=True) diff --git a/torchopt/diff/implicit/decorator.py b/torchopt/diff/implicit/decorator.py index 377bc1f4..a5908963 100644 --- a/torchopt/diff/implicit/decorator.py +++ b/torchopt/diff/implicit/decorator.py @@ -16,9 +16,11 @@ # pylint: disable=invalid-name +from __future__ import annotations + import functools import inspect -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, Sequence, Tuple import functorch import torch @@ -47,7 +49,7 @@ def __init__( optimality_fn: Callable[..., TensorOrTensors], solution: TensorOrTensors, output_is_tensor: bool, - argnums: Tuple[int, ...], + argnums: tuple[int, ...], *args: Any, ) -> None: self.optimality_fn = optimality_fn @@ -88,7 +90,7 @@ def _root_vjp( args: Args, grad_outputs: TupleOfTensors, output_is_tensor: bool, - argnums: Tuple[int, ...], + argnums: tuple[int, ...], solve: Callable[..., TensorOrTensors] = linear_solve.solve_normal_cg(), ) -> TupleOfOptionalTensors: if output_is_tensor: @@ -145,14 +147,14 @@ def matvec(u: TupleOfTensors) -> TupleOfTensors: return tuple(true_output) -def _extract_kwargs(kwarg_keys: Sequence[str], flat_args: Tuple[Any, ...]) -> Tuple[Args, KwArgs]: +def _extract_kwargs(kwarg_keys: Sequence[str], flat_args: tuple[Any, ...]) -> tuple[Args, KwArgs]: nargs = len(flat_args) - len(kwarg_keys) args, kwarg_vals = flat_args[:nargs], flat_args[nargs:] kwargs = dict(zip(kwarg_keys, kwarg_vals)) return args, kwargs -def _signature_bind(signature: inspect.Signature, *args: Any, **kwargs: Any) -> Tuple[Args, KwArgs]: +def _signature_bind(signature: inspect.Signature, *args: Any, **kwargs: Any) -> tuple[Args, KwArgs]: bound = signature.bind(*args, **kwargs) bound.apply_defaults() return bound.args, bound.kwargs @@ -160,7 +162,7 @@ def _signature_bind(signature: inspect.Signature, *args: Any, **kwargs: Any) -> def _signature_bind_and_match( signature: inspect.Signature, *args: Any, **kwargs: Any -) -> Tuple[Args, KwArgs, Callable[[Args], Tuple[Args, KwArgs]]]: +) -> tuple[Args, KwArgs, Callable[[Args], tuple[Args, KwArgs]]]: # We want to bind *args and **kwargs based on the provided signature, but also to associate the # resulting positional arguments back. To achieve this, we lift arguments to a triple: # @@ -193,13 +195,13 @@ def map_args_back(out_args): def _split_tensor_and_others( - mixed_tuple: Tuple[Any, ...], -) -> Tuple[pytree.PyTreeSpec, Tuple[bool, ...], TupleOfTensors, Tuple[Any, ...]]: - flattened: List[Any] + mixed_tuple: tuple[Any, ...], +) -> tuple[pytree.PyTreeSpec, tuple[bool, ...], TupleOfTensors, tuple[Any, ...]]: + flattened: list[Any] flattened, treespec = pytree.tree_flatten(mixed_tuple, none_is_leaf=True) # type: ignore[arg-type] tensors: ListOfTensors = [] - non_tensors: List[Any] = [] - is_tensor_mask: List[bool] = [] + non_tensors: list[Any] = [] + is_tensor_mask: list[bool] = [] for item in flattened: is_tensor = isinstance(item, torch.Tensor) is_tensor_mask.append(is_tensor) @@ -212,10 +214,10 @@ def _split_tensor_and_others( def _merge_tensor_and_others( treespec: pytree.PyTreeSpec, - is_tensor_mask: Tuple[bool, ...], + is_tensor_mask: tuple[bool, ...], tensors: TupleOfTensors, - non_tensors: Tuple[Any, ...], -) -> Tuple[Any, ...]: + non_tensors: tuple[Any, ...], +) -> tuple[Any, ...]: tensor_counter = 0 non_tensor_counter = 0 results = [] @@ -231,13 +233,13 @@ def _merge_tensor_and_others( # pylint: disable-next=too-many-arguments,too-many-statements def _custom_root( - solver_fn: Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]], + solver_fn: Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]], optimality_fn: Callable[..., TensorOrTensors], solve: Callable[..., TensorOrTensors], - argnums: Tuple[int, ...], + argnums: tuple[int, ...], has_aux: bool, - reference_signature: Optional[Union[inspect.Signature, Callable]] = None, -) -> Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]]: + reference_signature: inspect.Signature | Callable | None = None, +) -> Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]]: solver_fn_signature = inspect.signature(solver_fn) if reference_signature is None: @@ -249,16 +251,16 @@ def _custom_root( reference_signature = inspect.signature(fn) def make_custom_vjp_solver_fn( - solver_fn: Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]], + solver_fn: Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]], kwarg_keys: Sequence[str], - args_signs: Tuple[Tuple[int, int, Optional[Union[Type[tuple], Type[list]]]], ...], - ) -> Type[Function]: + args_signs: tuple[tuple[int, int, type[tuple] | type[list] | None], ...], + ) -> type[Function]: # pylint: disable-next=missing-class-docstring,abstract-method class ImplicitMetaGradient(Function): @staticmethod def forward( # type: ignore[override] # pylint: disable=arguments-differ ctx: Any, *flat_args: Any - ) -> Tuple[Any, ...]: + ) -> tuple[Any, ...]: output, aux, output_is_tensor = None, None, False args = [] @@ -361,12 +363,12 @@ def backward( # pylint: disable=too-many-locals @functools.wraps(solver_fn) def wrapped_solver_fn( *args: Any, **kwargs: Any - ) -> Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]: + ) -> TensorOrTensors | tuple[TensorOrTensors, Any]: args, kwargs = _signature_bind(solver_fn_signature, *args, **kwargs) keys, vals = list(kwargs.keys()), list(kwargs.values()) - args_signs: List[Tuple[int, int, Optional[Union[Type[tuple], Type[list]]]]] = [] - flat_args: List[Any] = [] + args_signs: list[tuple[int, int, type[tuple] | type[list] | None]] = [] + flat_args: list[Any] = [] args_offset = 0 for idx, arg in enumerate(args): if idx in argnums: @@ -410,12 +412,12 @@ def wrapped_solver_fn( def custom_root( optimality_fn: Callable[..., TensorOrTensors], - argnums: Union[int, Tuple[int, ...]], + argnums: int | tuple[int, ...], has_aux: bool = False, solve: Callable[..., TensorOrTensors] = linear_solve.solve_normal_cg(), ) -> Callable[ - [Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]]], - Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]], + [Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]]], + Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]], ]: """Return a decorator for adding implicit differentiation to a root solver. @@ -442,18 +444,17 @@ def solver_fn(params, arg1, arg2, ...): **In best practice, the ``optimality_fn`` should have the same signature as ``solver_fn``.** Args: - optimality_fn: (callable) - An equation function, ``optimality_fn(params, *args)``. The invariant is - ``optimality_fn(solution, *args) == 0`` at the solution / root of ``solution``. - argnums: (int or tuple of ints) - Specifies arguments to compute gradients with respect to. The ``argnums`` can be an - integer or a tuple of integers, which respect to the zero-based indices of the arguments - of the ``solver_fn(params, *args)`` function. The argument ``params`` is included - for the counting, while it is indexed as ``argnums=0``. - has_aux: (default: :data:`False`) - Whether the decorated solver function returns auxiliary data. - solve: (callable, optional, default: :func:`linear_solve.solve_normal_cg`) - a linear solver of the form ``solve(matvec, b)``. + optimality_fn (callable): An equation function, ``optimality_fn(params, *args)``. The + invariant is ``optimality_fn(solution, *args) == 0`` at the solution / root of + ``solution``. + argnums (int or tuple of int): Specifies arguments to compute gradients with respect to. The + ``argnums`` can be an integer or a tuple of integers, which respect to the zero-based + indices of the arguments of the ``solver_fn(params, *args)`` function. The argument + ``params`` is included for the counting, while it is indexed as ``argnums=0``. + has_aux (bool, optional): Whether the decorated solver function returns auxiliary data. + (default: :data:`False`) + solve (callable, optional): A linear solver of the form ``solve(matvec, b)``. + (default: :func:`linear_solve.solve_normal_cg`) Returns: A solver function decorator, i.e., ``custom_root(optimality_fn)(solver_fn)``. diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py index f9bff4de..bbae37c9 100644 --- a/torchopt/diff/implicit/nn/module.py +++ b/torchopt/diff/implicit/nn/module.py @@ -16,10 +16,12 @@ # pylint: disable=redefined-builtin +from __future__ import annotations + import abc import functools import itertools -from typing import Any, Iterable, Optional, Tuple, Type +from typing import Any, Iterable import functorch import torch @@ -38,7 +40,7 @@ def _stateless_objective_fn( __flat_meta_params: TupleOfTensors, __params_names: Iterable[str], __meta_params_names: Iterable[str], - self: 'ImplicitMetaGradientModule', + self: ImplicitMetaGradientModule, *input, **kwargs, ) -> torch.Tensor: @@ -57,7 +59,7 @@ def _stateless_optimality_fn( __flat_meta_params: TupleOfTensors, __params_names: Iterable[str], __meta_params_names: Iterable[str], - self: 'ImplicitMetaGradientModule', + self: ImplicitMetaGradientModule, *input, **kwargs, ) -> TupleOfTensors: @@ -72,8 +74,8 @@ def _stateless_optimality_fn( def make_optimality_from_objective( - cls: Type['ImplicitMetaGradientModule'], -) -> Type['ImplicitMetaGradientModule']: + cls: type[ImplicitMetaGradientModule], +) -> type[ImplicitMetaGradientModule]: """Derives the optimality function of the objective function.""" if ( getattr(cls, 'objective', ImplicitMetaGradientModule.objective) @@ -81,7 +83,7 @@ def make_optimality_from_objective( ): raise TypeError('The objective function is not defined.') - def optimality(self: 'ImplicitMetaGradientModule', *input, **kwargs) -> TupleOfTensors: + def optimality(self: ImplicitMetaGradientModule, *input, **kwargs) -> TupleOfTensors: params_names, flat_params = tuple(zip(*self.named_parameters())) meta_params_names, flat_meta_params = tuple(zip(*self.named_meta_parameters())) @@ -102,8 +104,8 @@ def optimality(self: 'ImplicitMetaGradientModule', *input, **kwargs) -> TupleOfT def enable_implicit_gradients( - cls: Type['ImplicitMetaGradientModule'], -) -> Type['ImplicitMetaGradientModule']: + cls: type[ImplicitMetaGradientModule], +) -> type[ImplicitMetaGradientModule]: """Enable implicit gradients for the :func:`solve` method.""" cls_solve = cls.solve if getattr(cls_solve, '__implicit_gradients_enabled__', False): @@ -122,17 +124,17 @@ def stateless_solver_fn( __params_names: Iterable[str], __meta_params_names: Iterable[str], # pylint: enable=unused-argument - self: 'ImplicitMetaGradientModule', + self: ImplicitMetaGradientModule, *input, **kwargs, - ) -> Tuple[TupleOfTensors, Any]: + ) -> tuple[TupleOfTensors, Any]: """Solve the optimization problem.""" output = cls_solve(self, *input, **kwargs) flat_optimal_params = tuple(p.detach_() for p in self.parameters()) return flat_optimal_params, output @functools.wraps(cls_solve) - def wrapped(self: 'ImplicitMetaGradientModule', *input, **kwargs) -> Any: + def wrapped(self: ImplicitMetaGradientModule, *input, **kwargs) -> Any: """Solve the optimization problem.""" params_names, flat_params = tuple(zip(*self.named_parameters())) meta_params_names, flat_meta_params = tuple(zip(*self.named_meta_parameters())) @@ -159,9 +161,9 @@ class ImplicitMetaGradientModule(MetaGradientModule): _custom_optimality: bool _custom_objective: bool - linear_solve: Optional[LinearSolver] + linear_solve: LinearSolver | None - def __init_subclass__(cls, linear_solve: Optional[LinearSolver] = None) -> None: + def __init_subclass__(cls, linear_solve: LinearSolver | None = None) -> None: """Validate and initialize the subclass.""" super().__init_subclass__() cls.linear_solve = linear_solve diff --git a/torchopt/diff/zero_order/decorator.py b/torchopt/diff/zero_order/decorator.py index 80664d8b..43522028 100644 --- a/torchopt/diff/zero_order/decorator.py +++ b/torchopt/diff/zero_order/decorator.py @@ -14,8 +14,10 @@ # ============================================================================== """Zero-Order Gradient Estimation.""" +from __future__ import annotations + import functools -from typing import Any, Callable, List, Sequence, Tuple, Union +from typing import Any, Callable, Sequence from typing_extensions import Literal # Python 3.8+ from typing_extensions import TypeAlias # Python 3.10+ @@ -33,9 +35,7 @@ def __init__(self, sample_fn: SampleFunc) -> None: """Wrap a sample function to make it a :class:`Samplable` object.""" self.sample_fn = sample_fn - def sample( - self, sample_shape: torch.Size = torch.Size() - ) -> Union[torch.Tensor, Sequence[Numeric]]: + def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor | Sequence[Numeric]: # pylint: disable-next=line-too-long """Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.""" return self.sample_fn(sample_shape) @@ -44,14 +44,14 @@ def sample( def _zero_order_naive( # pylint: disable=too-many-statements fn: Callable[..., torch.Tensor], distribution: Samplable, - argnums: Tuple[int, ...], + argnums: tuple[int, ...], num_samples: int, sigma: Numeric, ) -> Callable[..., torch.Tensor]: @functools.wraps(fn) def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements diff_params = [args[argnum] for argnum in argnums] - flat_diff_params: List[Any] + flat_diff_params: list[Any] flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-method @@ -59,7 +59,7 @@ class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-m def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: flat_diff_params = args[:-1] origin_args = list(args[-1][0]) - flat_args: List[Any] + flat_args: list[Any] flat_args, args_treespec = pytree.tree_flatten(origin_args, none_is_leaf=True) # type: ignore[arg-type] ctx.args_treespec = args_treespec @@ -107,7 +107,7 @@ def backward( # pylint: disable=too-many-locals flat_args.append(non_tensors[non_tensors_counter]) non_tensors_counter += 1 - args: List[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] + args: list[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] def add_perturbation(tensor, noises): return tensor.add(noises, alpha=sigma) @@ -119,7 +119,7 @@ def add_perturbation(tensor, noises): flat_noisy_params = [ add_perturbation(t, n) for t, n in zip(flat_diff_params, noises) ] - noisy_params: List[Any] = pytree.tree_unflatten( # type: ignore[assignment] + noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment] diff_params_treespec, flat_noisy_params ) @@ -145,14 +145,14 @@ def add_perturbation(tensor, noises): def _zero_order_forward( # pylint: disable=too-many-statements fn: Callable[..., torch.Tensor], distribution: Samplable, - argnums: Tuple[int, ...], + argnums: tuple[int, ...], num_samples: int, sigma: Numeric, ) -> Callable[..., torch.Tensor]: @functools.wraps(fn) def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements diff_params = [args[argnum] for argnum in argnums] - flat_diff_params: List[Any] + flat_diff_params: list[Any] flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-method @@ -160,7 +160,7 @@ class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-m def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: flat_diff_params = args[:-1] origin_args = list(args[-1][0]) - flat_args: List[Any] + flat_args: list[Any] flat_args, args_treespec = pytree.tree_flatten(origin_args, none_is_leaf=True) # type: ignore[arg-type] ctx.args_treespec = args_treespec @@ -209,7 +209,7 @@ def backward( # pylint: disable=too-many-locals flat_args.append(non_tensors[non_tensors_counter]) non_tensors_counter += 1 - args: List[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] + args: list[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] def add_perturbation(tensor, noises): return tensor.add(noises, alpha=sigma) @@ -221,7 +221,7 @@ def add_perturbation(tensor, noises): flat_noisy_params = [ add_perturbation(t, n) for t, n in zip(flat_diff_params, noises) ] - noisy_params: List[Any] = pytree.tree_unflatten( # type: ignore[assignment] + noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment] diff_params_treespec, flat_noisy_params ) @@ -248,14 +248,14 @@ def add_perturbation(tensor, noises): def _zero_order_antithetic( # pylint: disable=too-many-statements fn: Callable[..., torch.Tensor], distribution: Samplable, - argnums: Tuple[int, ...], + argnums: tuple[int, ...], num_samples: int, sigma: Numeric, ) -> Callable[..., torch.Tensor]: @functools.wraps(fn) def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements diff_params = [args[argnum] for argnum in argnums] - flat_diff_params: List[Any] + flat_diff_params: list[Any] flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-method @@ -263,7 +263,7 @@ class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-m def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: flat_diff_params = args[:-1] origin_args = list(args[-1][0]) - flat_args: List[Any] + flat_args: list[Any] flat_args, args_treespec = pytree.tree_flatten(origin_args, none_is_leaf=True) # type: ignore[arg-type] ctx.args_treespec = args_treespec @@ -309,7 +309,7 @@ def backward(ctx: Any, *grad_outputs: Any): # pylint: disable=too-many-locals flat_args.append(non_tensors[non_tensors_counter]) non_tensors_counter += 1 - args: List[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] + args: list[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] param_grads: ListOfTensors = [0.0 for _ in range(len(flat_diff_params))] # type: ignore[misc] @@ -318,7 +318,7 @@ def get_output(add_perturbation_fn, noises) -> torch.Tensor: add_perturbation_fn(t, n, alpha=sigma) for t, n in zip(flat_diff_params, noises) ] - noisy_params: List[Any] = pytree.tree_unflatten( # type: ignore[assignment] + noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment] diff_params_treespec, flat_noisy_params ) @@ -349,28 +349,28 @@ def get_output(add_perturbation_fn, noises) -> torch.Tensor: def zero_order( - distribution: Union[SampleFunc, Samplable], + distribution: SampleFunc | Samplable, method: Method = 'naive', - argnums: Union[int, Tuple[int, ...]] = (0,), + argnums: int | tuple[int, ...] = (0,), num_samples: int = 1, sigma: Numeric = 1.0, ) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]: """Return a decorator for applying zero-order differentiation. Args: - distribution: (function or Samplable) - A samplable object that has method ``samplable.sample(sample_shape)`` or a function that - takes the shape as input and returns a shaped batch of samples. This is used to sample - perturbations from the given distribution. The distribution should be sphere symmetric. - method: (str) - The algorithm to use. The currently supported algorithms are :const:`'naive'`, - :const:`'forward'`, and :const:`'antithetic'`. Defaults to :const:`'naive'`. - argnums: (int or tuple of int, default: :const:`0`) - Specifies arguments to compute gradients with respect to. - num_samples: (int, default :const:`1`) - The number of sample to get the averaged estimated gradient. - sigma: (Numeric) - The standard deviation of the perturbation. Defaults to :const:`1.0`. + distribution (callable or Samplable): A samplable object that has method + ``samplable.sample(sample_shape)`` or a function that takes the shape as input and + returns a shaped batch of samples. This is used to sample perturbations from the given + distribution. The distribution should be sphere symmetric. + method (str, optional): The algorithm to use. The currently supported algorithms are + :const:`'naive'`, :const:`'forward'`, and :const:`'antithetic'`. + (default: :const:`'naive'`) + argnums (int or tuple of int, optional): Specifies arguments to compute gradients with + respect to. (default: :const:`0`) + num_samples (int, optional): The number of sample to get the averaged estimated gradient. + (default: :const:`1`) + sigma (float or Tensor, optional): The standard deviation of the perturbation. + (default: :const:`1.0`) Returns: A function decorator that enables zero-order gradient estimation. diff --git a/torchopt/diff/zero_order/nn/module.py b/torchopt/diff/zero_order/nn/module.py index d76ac444..65014fb9 100644 --- a/torchopt/diff/zero_order/nn/module.py +++ b/torchopt/diff/zero_order/nn/module.py @@ -16,9 +16,11 @@ # pylint: disable=redefined-builtin +from __future__ import annotations + import abc import functools -from typing import Sequence, Type, Union +from typing import Sequence import torch import torch.nn as nn @@ -32,11 +34,11 @@ def enable_zero_order_gradients( - cls: Type['ZeroOrderGradientModule'], + cls: type[ZeroOrderGradientModule], method: Method = 'naive', num_samples: int = 1, sigma: Numeric = 1.0, -) -> Type['ZeroOrderGradientModule']: +) -> type[ZeroOrderGradientModule]: """Enable zero-order gradient estimation for the :func:`forward` method.""" cls_forward = cls.forward if getattr(cls_forward, '__zero_order_gradients_enabled__', False): @@ -45,7 +47,7 @@ def enable_zero_order_gradients( ) @functools.wraps(cls_forward) - def wrapped(self: 'ZeroOrderGradientModule', *input, **kwargs) -> torch.Tensor: + def wrapped(self: ZeroOrderGradientModule, *input, **kwargs) -> torch.Tensor: """Do the forward pass calculation.""" params_names, flat_params = tuple(zip(*self.named_parameters())) @@ -91,7 +93,7 @@ def forward(self, *args, **kwargs) -> torch.Tensor: @abc.abstractmethod def sample( self, sample_shape: torch.Size = torch.Size() # pylint: disable=unused-argument - ) -> Union[torch.Tensor, Sequence[Numeric]]: + ) -> torch.Tensor | Sequence[Numeric]: # pylint: disable-next=line-too-long """Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.""" raise NotImplementedError diff --git a/torchopt/distributed/api.py b/torchopt/distributed/api.py index 53f87fba..b46ad67e 100644 --- a/torchopt/distributed/api.py +++ b/torchopt/distributed/api.py @@ -14,6 +14,8 @@ # ============================================================================== """Distributed APIs.""" +from __future__ import annotations + import functools import sys from typing import ( @@ -73,8 +75,8 @@ class TensorDimensionPartitioner: while the non-tensor values will be broadcasted to partitions. Args: - dim: The dimension to partition. - exclusive: Whether to partition the batch exclusively. + dim (int): The dimension to partition. + exclusive (bool, optional): Whether to partition the batch exclusively. (default: :data:`False`) If :data:`True`, the batch will be partitioned into ``batch_size`` partitions, where ``batch_size`` is the size of the batch along the given dimension. Each batch sample will be assigned to a separate RPC call. @@ -82,11 +84,12 @@ class TensorDimensionPartitioner: partitions, where ``num_workers`` is the number of workers in the world. When ``batch_size > num_workers``, there can be multiple batch samples forward in a single RPC call. - keepdim: Whether to keep the partitioned dimension. Defaults to :data:`True`, i.e., keep the - batch dimension. If :data:`False`, use select instead of slicing. This functionality - should be used with ``exclusive=True``. - workers: The workers to partition the batch to. If :data:`None`, the batch will be - partitioned to all workers in the world. + keepdim (bool, optional): Whether to keep the partitioned dimension. (default: :data:`True`) + If :data:`True`, keep the batch dimension. If :data:`False`, use select instead of + slicing. This functionality should be used with ``exclusive=True``. + workers (sequence of int or str, or None, optional): The workers to partition the batch to. + If :data:`None`, the batch will be partitioned to all workers in the world. + (default: :data:`None`) """ def __init__( @@ -95,7 +98,7 @@ def __init__( *, exclusive: bool = False, keepdim: bool = False, - workers: Optional[Sequence[Union[int, str]]] = None, + workers: Sequence[int | str] | None = None, ) -> None: """Initialize the partitioner instance.""" if not keepdim and not exclusive: @@ -111,7 +114,7 @@ def __call__( self, *args: Any, **kwargs: Any, - ) -> List[Tuple[int, Optional[Args], Optional[KwArgs]]]: + ) -> list[tuple[int, Args | None, KwArgs | None]]: """Partition the batch of inputs along the given dimension.""" if self.workers is None: workers = list(range(get_world_size())) @@ -120,7 +123,7 @@ def __call__( num_workers = len(workers) args_tree = (args, kwargs) - flat_args: List[Any] + flat_args: list[Any] flat_args, treespec = pytree.tree_flatten(args_tree) # type: ignore[arg-type] batch_size = None @@ -137,8 +140,8 @@ def __call__( if batch_size is None: return [(get_world_rank(), args, kwargs.copy())] - dim_slices: List[Union[int, slice]] - batch_slices: List[Tuple[Union[int, slice, Ellipsis.__class__], ...]] # type: ignore[name-defined] + dim_slices: list[int | slice] + batch_slices: list[tuple[int | slice | Ellipsis.__class__, ...]] # type: ignore[name-defined] if self.exclusive: num_replicas = batch_size if self.keepdim: @@ -172,7 +175,7 @@ def __call__( for dim_slice in dim_slices ] - flat_args_replicas: List[List[Any]] = [[] for _ in range(num_replicas)] + flat_args_replicas: list[list[Any]] = [[] for _ in range(num_replicas)] for arg in flat_args: if isinstance(arg, torch.Tensor): for i, batch_slice in enumerate(batch_slices): @@ -181,7 +184,7 @@ def __call__( for i in range(num_replicas): flat_args_replicas[i].append(arg) - args_replicas: List[Tuple[Args, KwArgs]] = [ + args_replicas: list[tuple[Args, KwArgs]] = [ pytree.tree_unflatten(treespec, args_replica) # type: ignore[misc] for args_replica in flat_args_replicas ] @@ -193,10 +196,10 @@ def __call__( def __reduce__( self, - ) -> Tuple[ - Callable[..., 'TensorDimensionPartitioner'], - Tuple[int], - Dict[str, Union[bool, Optional[Sequence[Union[int, str]]]]], + ) -> tuple[ + Callable[..., TensorDimensionPartitioner], + tuple[int], + dict[str, bool | Sequence[int | str] | None], ]: """Return a tuple that allows the partitioner to be pickled.""" return ( @@ -211,7 +214,7 @@ def dim_partitioner( *, exclusive: bool = False, keepdim: bool = True, - workers: Optional[Sequence[Union[int, str]]] = None, + workers: Sequence[int | str] | None = None, ) -> PartitionFunction: """Partition a batch of inputs along a given dimension. @@ -219,8 +222,8 @@ def dim_partitioner( while the non-tensor values will be broadcasted to partitions. Args: - dim: The dimension to partition. - exclusive: Whether to partition the batch exclusively. + dim (int, optional): The dimension to partition. (default: :const:`0`) + exclusive (bool, optional): Whether to partition the batch exclusively. (default: :data:`False`) If :data:`True`, the batch will be partitioned into ``batch_size`` partitions, where ``batch_size`` is the size of the batch along the given dimension. Each batch sample will be assigned to a separate RPC call. @@ -228,11 +231,12 @@ def dim_partitioner( partitions, where ``num_workers`` is the number of workers in the world. When ``batch_size > num_workers``, there can be multiple batch samples forward in a single RPC call. - keepdim: Whether to keep the partitioned dimension. Defaults to :data:`True`, i.e., keep the - batch dimension. If :data:`False`, use select instead of slicing. This functionality - should be used with ``exclusive=True``. - workers: The workers to partition the batch to. If :data:`None`, the batch will be - partitioned to all workers in the world. + keepdim (bool, optional): Whether to keep the partitioned dimension. (default: :data:`False`) + If :data:`True`, keep the batch dimension. If :data:`False`, use select instead of + slicing. This functionality should be used with ``exclusive=True``. + workers (sequence of int or str, or None, optional): The workers to partition the batch to. + If :data:`None`, the batch will be partitioned to all workers in the world. + (default: :data:`None`) Returns: A partition function. @@ -273,26 +277,26 @@ def sum_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor: def remote_async_call( func: Callable[..., T], *, - args: Optional[Args] = None, - kwargs: Optional[KwArgs] = None, - partitioner: Optional[Partitioner] = None, - reducer: Optional[Callable[[Iterable[T]], U]] = None, - timeout: Optional[float] = UNSET_RPC_TIMEOUT, -) -> Union[Future[List[T]], Future[U]]: + args: Args | None = None, + kwargs: KwArgs | None = None, + partitioner: Partitioner | None = None, + reducer: Callable[[Iterable[T]], U] | None = None, + timeout: float | None = UNSET_RPC_TIMEOUT, +) -> Future[list[T]] | Future[U]: """Asynchronously do an RPC on remote workers and return the a :class:`torch.Future` instance at the current worker. Args: - func (Callable[..., T]): The function to call. - args (Optional[Args], optional): The arguments to pass to the function. Defaults to - :data:`None`. - kwargs (Optional[KwArgs], optional): The keyword arguments to pass to the function. Defaults - to :data:`None`. - partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple - workers. Defaults to :func:`batch_partitioner`. - reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from - multiple workers. Defaults to :data:`None`. - timeout (float, optional): The timeout for the RPC call. Defaults to - :data:`rpc.api.UNSET_RPC_TIMEOUT`. + func (callable): The function to call. + args (tuple of object or None, optional): The arguments to pass to the function. + (default: :data:`None`) + kwargs (dict[str, object] or None, optional): The keyword arguments to pass to the function. + (default: :data:`None`) + partitioner (int, str, or callable, optional): A partitioner that partitions the arguments + to multiple workers. (default: :func:`batch_partitioner`) + reducer (callable or None, optional): A reducer that reduces the results from multiple + workers. If :data:`None`, do not reduce the results. (default: :data:`None`) + timeout (float, optional): The timeout for the RPC call. + (default: :data:`rpc.api.UNSET_RPC_TIMEOUT`) Returns: A :class:`torch.Future` instance for the result. The result is at the current worker. @@ -330,26 +334,26 @@ def remote_async_call( def remote_sync_call( func: Callable[..., T], *, - args: Optional[Args] = None, - kwargs: Optional[KwArgs] = None, - partitioner: Optional[Partitioner] = None, - reducer: Optional[Callable[[Iterable[T]], U]] = None, - timeout: Optional[float] = UNSET_RPC_TIMEOUT, -) -> Union[List[T], U]: + args: Args | None = None, + kwargs: KwArgs | None = None, + partitioner: Partitioner | None = None, + reducer: Callable[[Iterable[T]], U] | None = None, + timeout: float | None = UNSET_RPC_TIMEOUT, +) -> list[T] | U: """Do an RPC synchronously on remote workers and return the result to the current worker. Args: - func (Callable[..., T]): The function to call. - args (Optional[Args], optional): The arguments to pass to the function. Defaults to - :data:`None`. - kwargs (Optional[KwArgs], optional): The keyword arguments to pass to the function. Defaults - to :data:`None`. - partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple - workers. Defaults to :func:`batch_partitioner`. - reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from - multiple workers. Defaults to :data:`None`. - timeout (float, optional): The timeout for the RPC call. Defaults to - :data:`rpc.api.UNSET_RPC_TIMEOUT`. + func (callable): The function to call. + args (tuple of object or None, optional): The arguments to pass to the function. + (default: :data:`None`) + kwargs (dict[str, object] or None, optional): The keyword arguments to pass to the function. + (default: :data:`None`) + partitioner (int, str, or callable, optional): A partitioner that partitions the arguments + to multiple workers. (default: :func:`batch_partitioner`) + reducer (callable or None, optional): A reducer that reduces the results from multiple + workers. If :data:`None`, do not reduce the results. (default: :data:`None`) + timeout (float, optional): The timeout for the RPC call. + (default: :data:`rpc.api.UNSET_RPC_TIMEOUT`) Returns: The result of the RPC call. The result is at the current worker. @@ -365,10 +369,10 @@ def remote_sync_call( def parallelize_async( - partitioner: Optional[Partitioner] = None, - reducer: Optional[Callable[[Iterable[T]], U]] = None, - timeout: Optional[float] = UNSET_RPC_TIMEOUT, -) -> Callable[[Callable[..., T]], Callable[..., Union[Future[List[T]], Future[U]]]]: + partitioner: Partitioner | None = None, + reducer: Callable[[Iterable[T]], U] | None = None, + timeout: float | None = UNSET_RPC_TIMEOUT, +) -> Callable[[Callable[..., T]], Callable[..., Future[list[T]] | Future[U]]]: """Return a decorator for parallelizing a function. This decorator can be used to parallelize a function call across multiple workers. The @@ -376,13 +380,12 @@ def parallelize_async( return a :class:`torch.Future` instance of the result. Args: - partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple - workers. Defaults to :func:`batch_partitioner`. - reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from - multiple workers. Defaults to :func:`mean_reducer` if the ``partitioner`` is not - specified, i.e., :func:`batch_partitioner`. Otherwise, it defaults to :data:`None`. - timeout (float, optional): The timeout for the RPC call. Defaults to - :data:`rpc.api.UNSET_RPC_TIMEOUT`. + partitioner (int, str, or callable, optional): A partitioner that partitions the arguments + to multiple workers. (default: :func:`batch_partitioner`) + reducer (callable or None, optional): A reducer that reduces the results from multiple + workers. If :data:`None`, do not reduce the results. (default: :data:`None`) + timeout (float, optional): The timeout for the RPC call. + (default: :data:`rpc.api.UNSET_RPC_TIMEOUT`) Returns: The decorator function. @@ -392,9 +395,9 @@ def parallelize_async( if reducer is None: reducer = mean_reducer # type: ignore[assignment] - def wrapper(func: Callable[..., T]) -> Callable[..., Union[Future[List[T]], Future[U]]]: + def wrapper(func: Callable[..., T]) -> Callable[..., Future[list[T]] | Future[U]]: @functools.wraps(func) - def wrapped(*args: Any, **kwargs: Any) -> Union[Future[List[T]], Future[U]]: + def wrapped(*args: Any, **kwargs: Any) -> Future[list[T]] | Future[U]: return remote_async_call( func, args=args, @@ -423,22 +426,21 @@ def wrapped(*args: Any, **kwargs: Any) -> Union[Future[List[T]], Future[U]]: def parallelize( - partitioner: Optional[Partitioner] = None, - reducer: Optional[Callable[[Iterable[T]], U]] = None, - timeout: Optional[float] = UNSET_RPC_TIMEOUT, -) -> Callable[[Callable[..., T]], Callable[..., Union[List[T], U]]]: + partitioner: Partitioner | None = None, + reducer: Callable[[Iterable[T]], U] | None = None, + timeout: float | None = UNSET_RPC_TIMEOUT, +) -> Callable[[Callable[..., T]], Callable[..., list[T] | U]]: """Return a decorator for parallelizing a function. This decorator can be used to parallelize a function call across multiple workers. Args: - partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple - workers. Defaults to :func:`batch_partitioner`. - reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from - multiple workers. Defaults to :func:`mean_reducer` if the ``partitioner`` is not - specified, i.e., :func:`batch_partitioner`. Otherwise, it defaults to :data:`None`. - timeout (float, optional): The timeout for the RPC call. Defaults to - :data:`rpc.api.UNSET_RPC_TIMEOUT`. + partitioner (int, str, or callable, optional): A partitioner that partitions the arguments + to multiple workers. (default: :func:`batch_partitioner`) + reducer (callable or None, optional): A reducer that reduces the results from multiple + workers. If :data:`None`, do not reduce the results. (default: :data:`None`) + timeout (float, optional): The timeout for the RPC call. + (default: :data:`rpc.api.UNSET_RPC_TIMEOUT`) Returns: The decorator function. @@ -448,9 +450,9 @@ def parallelize( if reducer is None: reducer = mean_reducer # type: ignore[assignment] - def wrapper(func: Callable[..., T]) -> Callable[..., Union[List[T], U]]: + def wrapper(func: Callable[..., T]) -> Callable[..., list[T] | U]: @functools.wraps(func) - def wrapped(*args: Any, **kwargs: Any) -> Union[List[T], U]: + def wrapped(*args: Any, **kwargs: Any) -> list[T] | U: return remote_sync_call( func, args=args, diff --git a/torchopt/distributed/autograd.py b/torchopt/distributed/autograd.py index 5fe51278..17fa9463 100644 --- a/torchopt/distributed/autograd.py +++ b/torchopt/distributed/autograd.py @@ -14,14 +14,15 @@ # ============================================================================== """Distributed Autograd.""" +from __future__ import annotations + from threading import Lock -from typing import Optional, overload import torch import torch.distributed.autograd as autograd from torch.distributed.autograd import context -from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors, TupleOfTensors +from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors __all__ = ['is_available', 'context'] @@ -43,22 +44,23 @@ def backward( autograd_ctx_id: int, tensors: TensorOrTensors, retain_graph: bool = False, - inputs: Optional[TensorOrTensors] = None, + inputs: TensorOrTensors | None = None, ) -> None: """Perform distributed backward pass for local parameters. Compute the sum of gradients of given tensors with respect to graph leaves. Args: - autograd_ctx_id: The autograd context id. - tensors (Sequence[Tensor] or Tensor): Tensors of which the derivative will be computed. + autograd_ctx_id (int): The autograd context id. + tensors (Tensor or sequence of Tensor): Tensors of which the derivative will be computed. retain_graph (bool, optional): If :data:`False`, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to :data:`True` is not needed and often can be worked around in a much more efficient way. - inputs (Sequence[Tensor] or Tensor, optional): Inputs w.r.t. which the gradient be will - accumulated into ``.grad``. All other Tensors will be ignored. If not provided, the - gradient is accumulated into all the leaf Tensors that were used to compute the - attr::tensors. + (default: :data:`False`) + inputs (Tensor, sequence of Tensor, or None, optional): Inputs w.r.t. which the gradient + be will accumulated into ``.grad``. All other Tensors will be ignored. If not + provided, the gradient is accumulated into all the leaf Tensors that were used to + compute the ``tensors``. (default: :data:`None`) """ if inputs is not None: if isinstance(inputs, torch.Tensor): @@ -85,25 +87,6 @@ def backward( else: p.grad = g - @overload - def grad( - autograd_ctx_id: int, - outputs: TensorOrTensors, - inputs: TensorOrTensors, - retain_graph: bool = False, - ) -> TupleOfTensors: - ... - - @overload - def grad( - autograd_ctx_id: int, - outputs: TensorOrTensors, - inputs: TensorOrTensors, - retain_graph: bool = False, - allow_unused: bool = False, - ) -> TupleOfOptionalTensors: - ... - def grad( autograd_ctx_id: int, outputs: TensorOrTensors, @@ -114,16 +97,17 @@ def grad( """Compute and return the sum of gradients of outputs with respect to the inputs. Args: - autograd_ctx_id: The autograd context id. - outputs (sequence of Tensor): outputs of the differentiated function. - inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be returned (and not - accumulated into ``.grad``). + autograd_ctx_id (int): The autograd context id. + outputs (Tensor or sequence of Tensor): Outputs of the differentiated function. + inputs (Tensor or sequence of Tensor): Inputs w.r.t. which the gradient will be returned + (and not accumulated into ``.grad``). retain_graph (bool, optional): If :data:`False`, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to :data:`True` is not needed and often can be worked around in a much more efficient way. + (default: :data:`False`) allow_unused (bool, optional): If :data:`False`, specifying inputs that were not used when computing outputs (and therefore their grad is always zero) is an error. - Defaults to :data:`False`. + (default: :data:`False`) """ outputs = [outputs] if isinstance(outputs, torch.Tensor) else list(outputs) inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs) diff --git a/torchopt/distributed/world.py b/torchopt/distributed/world.py index 45140df1..804d4b9d 100644 --- a/torchopt/distributed/world.py +++ b/torchopt/distributed/world.py @@ -14,10 +14,12 @@ # ============================================================================== """Utilities for gathering information about the world.""" +from __future__ import annotations + import atexit import functools import os -from typing import Any, Callable, Iterable, NamedTuple, Optional, TypeVar, Union +from typing import Any, Callable, Iterable, NamedTuple, TypeVar import torch.distributed.rpc as rpc from torch.distributed.elastic.multiprocessing.errors import record @@ -127,32 +129,33 @@ def get_local_world_size() -> int: # pylint: disable-next=redefined-builtin,invalid-name -def get_worker_id(id: Optional[Union[str, int]] = None) -> int: +def get_worker_id(id: str | int | None = None) -> int: """Get the worker id from the given id.""" if isinstance(id, int): return id return rpc.get_worker_info(worker_name=id).id -def barrier(worker_names: Optional[Iterable[str]] = None) -> None: +def barrier(worker_names: Iterable[str] | None = None) -> None: r"""Synchronize local and remote RPC processes. This will block until all local and remote RPC processes specified under worker_names reach this method to wait for all outstanding work to complete. Args: - worker_names: The set of workers to synchronize. If :data:`None`, all workers. + worker_names (iterable of str or None, optional): The set of workers to synchronize. + If :data:`None`, all workers. (default: :data:`None`) """ worker_names = {} if worker_names is None else set(worker_names) rpc.api._barrier(worker_names) # pylint: disable=protected-access def auto_init_rpc( - worker_init_fn: Optional[Callable[[], None]] = None, + worker_init_fn: Callable[[], None] | None = None, worker_name_format: Callable[..., str] = default_worker_name_format, *, - backend: Optional['rpc.BackendType'] = None, - rpc_backend_options: Optional['rpc.RpcBackendOptions'] = None, + backend: rpc.BackendType | None = None, + rpc_backend_options: rpc.RpcBackendOptions | None = None, ) -> Callable[[F], F]: """Return a decorator to automatically initialize RPC on the decorated function.""" global _WORKER_NAME_FORMAT # pylint: disable=global-statement diff --git a/torchopt/hook.py b/torchopt/hook.py index 949c76e7..f188415c 100644 --- a/torchopt/hook.py +++ b/torchopt/hook.py @@ -14,7 +14,9 @@ # ============================================================================== """Hook utilities.""" -from typing import Callable, Optional, Tuple +from __future__ import annotations + +from typing import Callable import torch @@ -32,7 +34,7 @@ def zero_nan_hook(g: torch.Tensor) -> torch.Tensor: def nan_to_num_hook( - nan: float = 0.0, posinf: Optional[float] = None, neginf: Optional[float] = None + nan: float = 0.0, posinf: float | None = None, neginf: float | None = None ) -> Callable[[torch.Tensor], torch.Tensor]: """Return a ``nan`` to num hook to replace ``nan`` / ``+inf`` / ``-inf`` with the given numbers.""" @@ -59,9 +61,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, # pylint: disable=unused-argument - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: def f(g): return g.register_hook(hook) diff --git a/torchopt/linalg/cg.py b/torchopt/linalg/cg.py index 94daee53..5456f076 100644 --- a/torchopt/linalg/cg.py +++ b/torchopt/linalg/cg.py @@ -33,8 +33,10 @@ # pylint: disable=invalid-name +from __future__ import annotations + from functools import partial -from typing import Callable, Optional, Union +from typing import Callable import torch @@ -100,14 +102,14 @@ def body_fn(value): def _isolve( _isolve_solve: Callable, - A: Union[TensorTree, Callable[[TensorTree], TensorTree]], + A: TensorTree | Callable[[TensorTree], TensorTree], b: TensorTree, - x0: Optional[TensorTree] = None, + x0: TensorTree | None = None, *, rtol: float = 1e-5, atol: float = 0.0, - maxiter: Optional[int] = None, - M: Optional[Union[TensorTree, Callable[[TensorTree], TensorTree]]] = None, + maxiter: int | None = None, + M: TensorTree | Callable[[TensorTree], TensorTree] | None = None, ) -> TensorTree: if x0 is None: x0 = pytree.tree_map(torch.zeros_like, b) @@ -133,14 +135,14 @@ def _isolve( def cg( - A: Union[TensorTree, Callable[[TensorTree], TensorTree]], + A: TensorTree | Callable[[TensorTree], TensorTree], b: TensorTree, - x0: Optional[TensorTree] = None, + x0: TensorTree | None = None, *, rtol: float = 1e-5, atol: float = 0.0, - maxiter: Optional[int] = None, - M: Optional[Union[TensorTree, Callable[[TensorTree], TensorTree]]] = None, + maxiter: int | None = None, + M: TensorTree | Callable[[TensorTree], TensorTree] | None = None, ) -> TensorTree: """Use Conjugate Gradient iteration to solve ``Ax = b``. @@ -153,30 +155,30 @@ def cg( solves converge. Args: - A: (tensor or tree of tensors or function) - 2D array or function that calculates the linear map (matrix-vector product) ``Ax`` when - called like ``A(x)``. ``A`` must represent a hermitian, positive definite matrix, and - must return array(s) with the same structure and shape as its argument. - b: (tensor or tree of tensors) - Right hand side of the linear system representing a single vector. Can be stored as an - array or Python container of array(s) with any shape. - x0: (tensor or tree of tensors, optional) - Starting guess for the solution. Must have the same structure as ``b``. - rtol: (float, optional, default: :const:`1e-5`) - Tolerances for convergence, ``norm(residual) <= max(rtol*norm(b), atol)``. We do not - implement SciPy's "legacy" behavior, so TorchOpt's tolerance will differ from SciPy - unless you explicitly pass ``atol`` to SciPy's ``cg``. - atol: (float, optional, default: :const:`0.0`) - Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``. We do not - implement SciPy's "legacy" behavior, so TorchOpt's tolerance will differ from SciPy - unless you explicitly pass ``atol`` to SciPy's ``cg``. - maxiter: (integer, optional) - Maximum number of iterations. Iteration will stop after maxiter steps even if the - specified tolerance has not been achieved. - M: (tensor or tree of tensors or function) - Pre-conditioner for ``A``. The pre-conditioner should approximate the inverse of ``A``. - Effective preconditioning dramatically improves the rate of convergence, which implies - that fewer iterations are needed to reach a given error tolerance. + A (Tensor or tree of Tensor): 2D array or function that calculates the linear map + (matrix-vector product) ``Ax`` when called like ``A(x)``. ``A`` must represent a + hermitian, positive definite matrix, and must return tensor(s) with the same structure + and shape as its argument. + b (Tensor or tree of Tensor): Right hand side of the linear system representing a single + vector. Can be stored as a tensor or Python container of tensor(s) with any shape. + x0 (Tensor, tree of Tensor, or None, optional): Starting guess for the solution. Must have + the same structure as ``b``. If :data:`None`, use zero initialization. + (default: :data:`None`) + rtol (float, optional): Tolerances for convergence, ``norm(residual) <= max(rtol*norm(b), atol)``. + We do not implement SciPy's "legacy" behavior, so TorchOpt's tolerance will differ from + SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``. (default: :const:`1e-5`) + atol (float, optional): Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``. + We do not implement SciPy's "legacy" behavior, so TorchOpt's tolerance will differ from + SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``. (default: :const:`0.0`) + maxiter (int or None, optional): Maximum number of iterations. Iteration will stop after + maxiter steps even if the specified tolerance has not been achieved. If :data:`None`, + ``10 * size`` will be used, where ``size`` is the size of the flattened input tensor(s). + (default: :data:`None`) + M (Tensor, tree of Tensor, function, or None, optional): Pre-conditioner for ``A``. The + pre-conditioner should approximate the inverse of ``A``. Effective preconditioning + dramatically improves the rate of convergence, which implies that fewer iterations are + needed to reach a given error tolerance. If :data:`None`, no pre-conditioner will be + used. (default: :data:`None`) Returns: the Conjugate Gradient (CG) linear solver diff --git a/torchopt/linalg/ns.py b/torchopt/linalg/ns.py index 04f5dd11..c1975203 100644 --- a/torchopt/linalg/ns.py +++ b/torchopt/linalg/ns.py @@ -16,13 +16,15 @@ # pylint: disable=invalid-name +from __future__ import annotations + import functools -from typing import Callable, Optional, Union +from typing import Callable import torch from torchopt import pytree -from torchopt.linalg.utils import cat_shapes, normalize_matvec +from torchopt.linalg.utils import normalize_matvec from torchopt.typing import TensorTree @@ -33,7 +35,7 @@ def _ns_solve( A: torch.Tensor, b: torch.Tensor, maxiter: int, - alpha: Optional[float] = None, + alpha: float | None = None, ) -> torch.Tensor: """Use Neumann Series Matrix Inversion Approximation to solve ``Ax = b``.""" if A.ndim != 2 or A.shape[0] != A.shape[1]: @@ -57,27 +59,26 @@ def _ns_solve( def ns( - A: Union[TensorTree, Callable[[TensorTree], TensorTree]], + A: TensorTree | Callable[[TensorTree], TensorTree], b: TensorTree, - maxiter: Optional[int] = None, + maxiter: int | None = None, *, - alpha: Optional[float] = None, + alpha: float | None = None, ) -> TensorTree: """Use Neumann Series Matrix Inversion Approximation to solve ``Ax = b``. Args: - A: (tensor or tree of tensors or function) - 2D array or function that calculates the linear map (matrix-vector product) ``Ax`` when - called like ``A(x)``. ``A`` must represent a hermitian, positive definite matrix, and - must return array(s) with the same structure and shape as its argument. - b: (tensor or tree of tensors) - Right hand side of the linear system representing a single vector. Can be stored as an - array or Python container of array(s) with any shape. - maxiter: (integer, optional) - Maximum number of iterations. Iteration will stop after maxiter steps even if the - specified tolerance has not been achieved. - alpha: (float, optional) - Decay coefficient. + A (Tensor or tree of Tensor): 2D array or function that calculates the linear map + (matrix-vector product) ``Ax`` when called like ``A(x)``. ``A`` must represent a + hermitian, positive definite matrix, and must return tensor(s) with the same structure + and shape as its argument. + b (Tensor or tree of Tensor): Right hand side of the linear system representing a single + vector. Can be stored as a tensor or Python container of tensor(s) with any shape. + maxiter (int or None, optional): Maximum number of iterations. Iteration will stop after + maxiter steps even if the specified tolerance has not been achieved. If :data:`None`, + :const:`10` will be used. (default: :const:`10`) + alpha: (float or None, optional): Decay coefficient. If :data:`None`, :const:`1.0` will be + used. (default: :const:`1.0`) Returns: The Neumann Series (NS) matrix inversion approximation. @@ -111,7 +112,7 @@ def ns( return inv_A_hat_b -def _ns_inv(A: torch.Tensor, maxiter: int, alpha: Optional[float] = None): +def _ns_inv(A: torch.Tensor, maxiter: int, alpha: float | None = None): """Use Neumann Series iteration to solve ``A^{-1}``.""" if A.ndim != 2 or A.shape[0] != A.shape[1]: raise ValueError(f'`A` must be a square matrix, but has shape: {A.shape}') @@ -134,28 +135,27 @@ def _ns_inv(A: torch.Tensor, maxiter: int, alpha: Optional[float] = None): def ns_inv( A: TensorTree, - maxiter: Optional[int] = None, + maxiter: int | None = None, *, - alpha: Optional[float] = None, + alpha: float | None = None, ) -> TensorTree: """Use Neumann Series iteration to solve ``A^{-1}``. Args: - A: (tensor or tree of tensors or function) - 2D array or function that calculates the linear map (matrix-vector product) ``Ax`` when - called like ``A(x)``. ``A`` must represent a hermitian, positive definite matrix, and - must return array(s) with the same structure and shape as its argument. - maxiter: (integer, optional) - Maximum number of iterations. Iteration will stop after maxiter steps even if the - specified tolerance has not been achieved. - alpha: (float, optional) - Decay coefficient. + A (Tensor or tree of Tensor): 2D array or function that calculates the linear map + (matrix-vector product) ``Ax`` when called like ``A(x)``. ``A`` must represent a + hermitian, positive definite matrix, and must return tensor(s) with the same structure + and shape as its argument. + maxiter (int or None, optional): Maximum number of iterations. Iteration will stop after + maxiter steps even if the specified tolerance has not been achieved. If :data:`None`, + :const:`10` will be used. (default: :const:`10`) + alpha: (float or None, optional): Decay coefficient. If :data:`None`, :const:`1.0` will be + used. (default: :const:`1.0`) Returns: The Neumann Series (NS) matrix inversion approximation. """ if maxiter is None: - size = sum(cat_shapes(A)) - maxiter = 10 * size # copied from SciPy + maxiter = 10 return pytree.tree_map(functools.partial(_ns_inv, maxiter=maxiter, alpha=alpha), A) diff --git a/torchopt/linalg/utils.py b/torchopt/linalg/utils.py index 275232be..f301a624 100644 --- a/torchopt/linalg/utils.py +++ b/torchopt/linalg/utils.py @@ -14,8 +14,10 @@ # ============================================================================== """Utilities for linear algebra.""" +from __future__ import annotations + import itertools -from typing import Callable, Tuple, Union +from typing import Callable import torch @@ -23,14 +25,14 @@ from torchopt.typing import TensorTree -def cat_shapes(tree: TensorTree) -> Tuple[int, ...]: +def cat_shapes(tree: TensorTree) -> tuple[int, ...]: """Concatenate the shapes of the leaves of a tree of tensors.""" leaves = pytree.tree_leaves(tree) return tuple(itertools.chain.from_iterable(tuple(leaf.shape) for leaf in leaves)) def normalize_matvec( - matvec: Union[TensorTree, Callable[[TensorTree], TensorTree]] + matvec: TensorTree | Callable[[TensorTree], TensorTree] ) -> Callable[[TensorTree], TensorTree]: """Normalize an argument for computing matrix-vector product.""" if callable(matvec): diff --git a/torchopt/linear_solve/cg.py b/torchopt/linear_solve/cg.py index f75ef9f4..844c9407 100644 --- a/torchopt/linear_solve/cg.py +++ b/torchopt/linear_solve/cg.py @@ -33,8 +33,10 @@ # pylint: disable=invalid-name +from __future__ import annotations + import functools -from typing import Callable, Optional +from typing import Callable from torchopt import linalg from torchopt.linear_solve.utils import make_ridge_matvec @@ -47,8 +49,8 @@ def _solve_cg( matvec: Callable[[TensorTree], TensorTree], # (x) -> A @ x b: TensorTree, - ridge: Optional[float] = None, - init: Optional[TensorTree] = None, + ridge: float | None = None, + init: TensorTree | None = None, **kwargs, ) -> TensorTree: """Solve ``A x = b`` using conjugate gradient. @@ -56,10 +58,12 @@ def _solve_cg( This assumes that ``A`` is a hermitian, positive definite matrix. Args: - matvec: A function that returns the product between ``A`` and a vector. - b: A tree of tensors for the right hand side of the equation. - ridge: Optional ridge regularization. - init: Optional initialization to be used by conjugate gradient. + matvec (callable): A function that returns the product between ``A`` and a vector. + b (Tensor or tree of Tensor): A tree of tensors for the right hand side of the equation. + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A x + ridge x = b``. (default: :data:`None`) + init (Tensor, tree of Tensor, or None, optional): Optional initialization to be used by + conjugate gradient. If :data:`None`, uses zero initialization. (default: :data:`None`) **kwargs: Additional keyword arguments for the conjugate gradient solver. Returns: @@ -80,8 +84,10 @@ def solve_cg(**kwargs): This assumes that ``A`` is a hermitian, positive definite matrix. Args: - ridge: Optional ridge regularization. Solves the equation for ``(A + ridge * I) @ x = b``. - init: Optional initialization to be used by conjugate gradient. + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A x + ridge x = b``. (default: :data:`None`) + init (Tensor, tree of Tensor, or None, optional): Optional initialization to be used by + conjugate gradient. If :data:`None`, uses zero initialization. (default: :data:`None`) **kwargs: Additional keyword arguments for the conjugate gradient solver :func:`torchopt.linalg.cg`. diff --git a/torchopt/linear_solve/inv.py b/torchopt/linear_solve/inv.py index c3224a52..399a0ef9 100644 --- a/torchopt/linear_solve/inv.py +++ b/torchopt/linear_solve/inv.py @@ -33,8 +33,10 @@ # pylint: disable=invalid-name +from __future__ import annotations + import functools -from typing import Callable, Optional +from typing import Callable import torch @@ -49,7 +51,7 @@ def _solve_inv( matvec: Callable[[TensorTree], TensorTree], # (x) -> A @ x b: TensorTree, - ridge: Optional[float] = None, + ridge: float | None = None, ns: bool = False, **kwargs, ) -> TensorTree: @@ -59,11 +61,13 @@ def _solve_inv( in memory. Args: - matvec: A function that returns the product between ``A`` and a vector. - b: A tensor for the right hand side of the equation. - ridge: Optional ridge regularization. Solves the equation for ``(A + ridge * I) @ x = b``. - ns: Whether to use Neumann Series matrix inversion approximation. If :data:`False`, - materialize the matrix ``A`` in memory and use :func:`torch.linalg.solve` instead. + matvec (callable): A function that returns the product between ``A`` and a vector. + b (Tensor or tree of Tensor): A tree of tensors for the right hand side of the equation. + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A x + ridge x = b``. (default: :data:`None`) + ns (bool, optional): Whether to use Neumann Series matrix inversion approximation. + If :data:`False`, materialize the matrix ``A`` in memory and use :func:`torch.linalg.solve` + instead. (default: :data:`False`) **kwargs: Additional keyword arguments for the Neumann Series matrix inversion approximation solver :func:`torchopt.linalg.ns`. @@ -94,9 +98,11 @@ def solve_inv(**kwargs): in memory. Args: - ridge: Optional ridge regularization. Solves the equation for ``(A + ridge * I) @ x = b``. - ns: Whether to use Neumann Series matrix inversion approximation. If :data:`False`, - materialize the matrix ``A`` in memory and use :func:`torch.linalg.solve` instead. + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A x + ridge x = b``. (default: :data:`None`) + ns (bool, optional): Whether to use Neumann Series matrix inversion approximation. + If :data:`False`, materialize the matrix ``A`` in memory and use :func:`torch.linalg.solve` + instead. (default: :data:`False`) **kwargs: Additional keyword arguments for the Neumann Series matrix inversion approximation solver :func:`torchopt.linalg.ns`. diff --git a/torchopt/linear_solve/normal_cg.py b/torchopt/linear_solve/normal_cg.py index 3199a490..8d38f77a 100644 --- a/torchopt/linear_solve/normal_cg.py +++ b/torchopt/linear_solve/normal_cg.py @@ -33,8 +33,10 @@ # pylint: disable=invalid-name +from __future__ import annotations + import functools -from typing import Callable, Optional +from typing import Callable from torchopt import linalg from torchopt.linear_solve.utils import make_normal_matvec, make_ridge_matvec, make_rmatvec @@ -47,8 +49,8 @@ def _solve_normal_cg( matvec: Callable[[TensorTree], TensorTree], # (x) -> A @ x b: TensorTree, - ridge: Optional[float] = None, - init: Optional[TensorTree] = None, + ridge: float | None = None, + init: TensorTree | None = None, **kwargs, ) -> TensorTree: """Solve the normal equation ``A^T A x = A^T b`` using conjugate gradient. @@ -57,10 +59,12 @@ def _solve_normal_cg( positive definite. Args: - matvec: A function that returns the product between ``A`` and a vector. - b: A tree of tensors for the right hand side of the equation. - ridge: Optional ridge regularization. Solves the equation for ``(A.T @ A + ridge * I) @ x = A.T @ b``. - init: Optional initialization to be used by normal conjugate gradient. + matvec (callable): A function that returns the product between ``A`` and a vector. + b (Tensor or tree of Tensor): A tree of tensors for the right hand side of the equation. + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A^T A x + ridge x = A^T b``. (default: :data:`None`) + init (Tensor, tree of Tensor, or None, optional): Optional initialization to be used by + conjugate gradient. If :data:`None`, uses zero initialization. (default: :data:`None`) **kwargs: Additional keyword arguments for the conjugate gradient solver :func:`torchopt.linalg.cg`. @@ -93,8 +97,10 @@ def solve_normal_cg(**kwargs): positive definite. Args: - ridge: Optional ridge regularization. Solves the equation for ``(A.T @ A + ridge * I) @ x = A.T @ b``. - init: Optional initialization to be used by normal conjugate gradient. + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A^T A x + ridge x = A^T b``. (default: :data:`None`) + init (Tensor, tree of Tensor, or None, optional): Optional initialization to be used by + conjugate gradient. If :data:`None`, uses zero initialization. (default: :data:`None`) **kwargs: Additional keyword arguments for the conjugate gradient solver :func:`torchopt.linalg.cg`. diff --git a/torchopt/linear_solve/utils.py b/torchopt/linear_solve/utils.py index 9c2f7ced..f4f34e2a 100644 --- a/torchopt/linear_solve/utils.py +++ b/torchopt/linear_solve/utils.py @@ -31,7 +31,9 @@ # ============================================================================== """Utilities for linear algebra solvers.""" -from typing import Callable, Tuple +from __future__ import annotations + +from typing import Callable import functorch @@ -75,7 +77,7 @@ def ridge_matvec(y: TensorTree) -> TensorTree: def materialize_matvec( matvec: Callable[[TensorTree], TensorTree], x: TensorTree -) -> Tuple[ +) -> tuple[ TensorTree, Callable[[TensorTree], TensorTree], Callable[[TensorTree], TensorTree], diff --git a/torchopt/nn/module.py b/torchopt/nn/module.py index 3716f674..f8804864 100644 --- a/torchopt/nn/module.py +++ b/torchopt/nn/module.py @@ -14,8 +14,10 @@ # ============================================================================== """Base class for neural network modules that hold meta-parameters and meta-modules.""" +from __future__ import annotations + from collections import OrderedDict -from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Set, Tuple, Union +from typing import Any, Iterator, NamedTuple import torch import torch.nn as nn @@ -27,8 +29,8 @@ class MetaInputsContainer(NamedTuple): """Container for parameters and modules in the constructor input arguments.""" - meta_parameters: Set[torch.Tensor] - meta_modules: Set[nn.Module] + meta_parameters: set[torch.Tensor] + meta_modules: set[nn.Module] class MetaGradientModule(nn.Module): # pylint: disable=abstract-method @@ -36,12 +38,12 @@ class MetaGradientModule(nn.Module): # pylint: disable=abstract-method _meta_inputs: MetaInputsContainer _meta_parameters: TensorContainer - _meta_modules: Dict[str, Optional[nn.Module]] + _meta_modules: dict[str, nn.Module | None] - def __new__(cls, *args, **kwargs) -> 'MetaGradientModule': + def __new__(cls, *args, **kwargs) -> MetaGradientModule: """Create a new module instance.""" instance = super().__new__(cls) - flat_args: List[Any] + flat_args: list[Any] flat_args = pytree.tree_leaves((args, kwargs)) # type: ignore[arg-type] meta_parameters = {x for x in flat_args if isinstance(x, torch.Tensor) and x.requires_grad} meta_modules = {x for x in flat_args if isinstance(x, nn.Module) and x.training} @@ -51,14 +53,14 @@ def __new__(cls, *args, **kwargs) -> 'MetaGradientModule': instance._meta_inputs = MetaInputsContainer(meta_parameters, meta_modules) instance._meta_parameters: TensorContainer = OrderedDict() # type: ignore[misc] - instance._meta_modules: Dict[str, Optional[nn.Module]] = OrderedDict() # type: ignore[misc] + instance._meta_modules: dict[str, nn.Module | None] = OrderedDict() # type: ignore[misc] return instance def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused-argument """Initialize a new module instance.""" super().__init__() - def __getattr__(self, name: str) -> Union[torch.Tensor, nn.Module]: + def __getattr__(self, name: str) -> torch.Tensor | nn.Module: """Get an attribute of the module.""" if '_parameters' in self.__dict__: _parameters = self.__dict__['_parameters'] @@ -83,7 +85,7 @@ def __getattr__(self, name: str) -> Union[torch.Tensor, nn.Module]: raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") # pylint: disable-next=too-many-branches,too-many-statements - def __setattr__(self, name: str, value: Union[torch.Tensor, nn.Module]) -> None: + def __setattr__(self, name: str, value: torch.Tensor | nn.Module) -> None: """Set an attribute of the module.""" def remove_from(*dicts_or_sets): @@ -186,18 +188,17 @@ def __delattr__(self, name: str) -> None: else: object.__delattr__(self, name) - def register_parameter(self, name: str, param: Optional[torch.Tensor]) -> None: + def register_parameter(self, name: str, param: torch.Tensor | None) -> None: r"""Add a parameter to the module. The parameter can be accessed as an attribute using given name. Args: - name (string): name of the parameter. The parameter can be accessed - from this module using the given name - param (torch.Tensor or None): parameter to be added to the module. If - ``None``, then operations that run on parameters, such as :attr:`cuda`, - are ignored. If ``None``, the parameter is **not** included in the - module's :attr:`state_dict`. + name (str): The name of the parameter. The parameter can be accessed from this module + using the given name. + param (Tensor or None): The parameter to be added to the module. If :data:`None`, then + operations that run on parameters, such as ``cuda``, are ignored. If :data:`None`, + the parameter is **not** included in the module's ``state_dict``. """ if '_parameters' not in self.__dict__: raise AttributeError('cannot assign parameter before Module.__init__() call') @@ -231,18 +232,17 @@ def register_parameter(self, name: str, param: Optional[torch.Tensor]) -> None: self._parameters[name] = param # type: ignore - def register_meta_parameter(self, name: str, param: Optional[torch.Tensor]) -> None: + def register_meta_parameter(self, name: str, param: torch.Tensor | None) -> None: r"""Add a meta-parameter to the module. The meta-parameter can be accessed as an attribute using given name. Args: - name (string): name of the parameter. The parameter can be accessed - from this module using the given name - param (torch.Tensor or None): parameter to be added to the module. If - ``None``, then operations that run on parameters, such as :attr:`cuda`, - are ignored. If ``None``, the parameter is **not** included in the - module's :attr:`state_dict`. + name (str): The name of the meta-parameter. The meta-parameter can be accessed from this + module using the given name. + param (Tensor or None): The meta-parameter to be added to the module. If :data:`None`, + then operations that run on meta-parameters, such as ``cuda``, are ignored. If + :data:`None`, the meta-parameter is **not** included in the module's ``state_dict``. """ if '_meta_parameters' not in self.__dict__: raise AttributeError( @@ -273,15 +273,15 @@ def register_meta_parameter(self, name: str, param: Optional[torch.Tensor]) -> N self._meta_parameters[name] = param - def add_module(self, name: str, module: Optional[nn.Module]) -> None: + def add_module(self, name: str, module: nn.Module | None) -> None: r"""Add a child module to the current module. The module can be accessed as an attribute using the given name. Args: - name (string): name of the child module. The child module can be - accessed from this module using the given name - module (Module): child module to be added to the module. + name (str): The name of the child module. The child module can be accessed from this + module using the given name + module (nn.Module or None): The child module to be added to the module. """ if not isinstance(module, nn.Module) and module is not None: raise TypeError(f'{torch.typename(module)} is not a Module subclass') @@ -301,19 +301,19 @@ def add_module(self, name: str, module: Optional[nn.Module]) -> None: self._modules[name] = module - def register_module(self, name: str, module: Optional[nn.Module]) -> None: + def register_module(self, name: str, module: nn.Module | None) -> None: r"""Alias for :func:`add_module`.""" self.add_module(name, module) - def add_meta_module(self, name: str, meta_module: Optional[nn.Module]) -> None: + def add_meta_module(self, name: str, meta_module: nn.Module | None) -> None: r"""Add a child meta-module to the current module. The meta-module can be accessed as an attribute using the given name. Args: - name (string): name of the child meta-module. The child meta-module can be - accessed from this module using the given name - meta_module (Module): child meta-module to be added to the module. + name (str): The name of the child meta-module. The child meta-module can be accessed + from this module using the given name + meta_module (nn.Module or None): The child meta-module to be added to the module. """ if not isinstance(meta_module, nn.Module) and meta_module is not None: raise TypeError(f'{torch.typename(meta_module)} is not a Module subclass') @@ -328,7 +328,7 @@ def add_meta_module(self, name: str, meta_module: Optional[nn.Module]) -> None: self._meta_modules[name] = meta_module - def register_meta_module(self, name: str, meta_module: Optional[nn.Module]) -> None: + def register_meta_module(self, name: str, meta_module: nn.Module | None) -> None: r"""Alias for :func:`add_meta_module`.""" self.add_meta_module(name, meta_module) @@ -338,9 +338,9 @@ def meta_parameters(self, recurse: bool = True) -> Iterator[torch.Tensor]: This is typically passed to an optimizer. Args: - recurse (bool): if True, then yields parameters of this module and - all submodules. Otherwise, yields only meta-parameters that - are direct members of this module. + recurse (bool, optional): If :data:`True`, then yields parameters of this module and + all submodules. Otherwise, yields only meta-parameters that are direct members of + this module. (default: :data:`True`) Yields: Parameter: module meta-parameter @@ -358,14 +358,15 @@ def meta_parameters(self, recurse: bool = True) -> Iterator[torch.Tensor]: def named_meta_parameters( self, prefix: str = '', recurse: bool = True - ) -> Iterator[Tuple[str, torch.Tensor]]: + ) -> Iterator[tuple[str, torch.Tensor]]: r"""Return an iterator over module meta-parameters, yielding both the name of the meta-parameter as well as the meta-parameter itself. Args: - prefix (str): prefix to prepend to all meta-parameter names. - recurse (bool): if True, then yields meta-parameters of this module - and all submodules. Otherwise, yields only meta-parameters that - are direct members of this module. + prefix (str, optional): The prefix to prepend to all meta-parameter names. + (default: :const:`''`) + recurse (bool, optional): if :data:`True`, then yields meta-parameters of this module + and all submodules. Otherwise, yields only meta-parameters that are direct members + of this module. (default: :data:`True`) Yields: (string, Parameter): Tuple containing the name and parameter @@ -398,7 +399,7 @@ def meta_children(self) -> Iterator[nn.Module]: for _, module in self.named_meta_children(): yield module - def named_meta_children(self) -> Iterator[Tuple[str, nn.Module]]: + def named_meta_children(self) -> Iterator[tuple[str, nn.Module]]: r"""Return an iterator over immediate children meta-modules, yielding both the name of the meta-module as well as the meta-module itself. Yields: @@ -430,15 +431,18 @@ def meta_modules(self) -> Iterator[nn.Module]: yield meta_module def named_meta_modules( - self, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True - ) -> Iterator[Tuple[str, nn.Module]]: + self, memo: set[nn.Module] | None = None, prefix: str = '', remove_duplicate: bool = True + ) -> Iterator[tuple[str, nn.Module]]: r"""Return an iterator over all meta-modules in the network, yielding both the name of the meta-module as well as the meta-module itself. Args: - memo: a memo to store the set of meta-modules already added to the result - prefix: a prefix that will be added to the name of the meta-module - remove_duplicate: whether to remove the duplicated meta-module instances in the result - or not + memo (set of nn.Module or None, optional): A memory to store the set of meta-modules + already added to the result. If not provided, a new set will be created. + (default: :const:`None`) + prefix (str, optional): A prefix that will be added to the name of the meta-module. + (default: :const:`''`) + remove_duplicate (bool, optional): whether to remove the duplicated meta-module + instances in the result or not. (default: :const:`True`) Yields: (string, Module): Tuple of name and meta-module diff --git a/torchopt/nn/stateless.py b/torchopt/nn/stateless.py index 2fc0dbb4..9391352f 100644 --- a/torchopt/nn/stateless.py +++ b/torchopt/nn/stateless.py @@ -14,8 +14,10 @@ # ============================================================================== """Utility functions for stateless module calls.""" +from __future__ import annotations + import contextlib -from typing import Dict, Generator, Iterable, Tuple, Union +from typing import Generator, Iterable import torch import torch.nn as nn @@ -29,9 +31,9 @@ def swap_state( module: nn.Module, - named_tensors: Union[Dict[str, torch.Tensor], Iterable[Tuple[str, torch.Tensor]]], + named_tensors: dict[str, torch.Tensor] | Iterable[tuple[str, torch.Tensor]], allow_missing: bool = False, -) -> Dict[str, torch.Tensor]: +) -> dict[str, torch.Tensor]: """Swap the module parameters and/or buffers.""" if not isinstance(named_tensors, dict): named_tensors = dict(named_tensors) @@ -84,7 +86,7 @@ def recursive_setattr(path: str, value: torch.Tensor) -> torch.Tensor: @contextlib.contextmanager def reparametrize( module: nn.Module, - named_tensors: Union[Dict[str, torch.Tensor], Iterable[Tuple[str, torch.Tensor]]], + named_tensors: dict[str, torch.Tensor] | Iterable[tuple[str, torch.Tensor]], allow_missing: bool = False, ) -> Generator[nn.Module, None, None]: """Reparameterize the module parameters and/or buffers.""" diff --git a/torchopt/optim/adam.py b/torchopt/optim/adam.py index c56956f8..640eea1d 100644 --- a/torchopt/optim/adam.py +++ b/torchopt/optim/adam.py @@ -14,7 +14,9 @@ # ============================================================================== """Adam optimizer.""" -from typing import Iterable, Tuple +from __future__ import annotations + +from typing import Iterable import torch @@ -39,7 +41,7 @@ def __init__( self, params: Iterable[torch.Tensor], lr: ScalarOrSchedule, - betas: Tuple[float, float] = (0.9, 0.999), + betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.0, *, @@ -50,25 +52,27 @@ def __init__( r"""Initialize the Adam optimizer. Args: - params: (iterable of torch.Tensor) - An iterable of :class:`torch.Tensor`\s. Specifies what tensors should be optimized. - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to - avoid dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square + root (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + eps_root (float, optional): A small constant applied to denominator inside the square + root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for + example when computing (meta-)gradients through Adam. (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. + (default: :data:`False`) """ super().__init__( params, diff --git a/torchopt/optim/adamw.py b/torchopt/optim/adamw.py index 19c70678..7db5e750 100644 --- a/torchopt/optim/adamw.py +++ b/torchopt/optim/adamw.py @@ -14,13 +14,15 @@ # ============================================================================== """AdamW optimizer.""" -from typing import Any, Callable, Iterable, Optional, Tuple, Union +from __future__ import annotations + +from typing import Callable, Iterable import torch from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import Params, ScalarOrSchedule +from torchopt.typing import OptState, Params, ScalarOrSchedule __all__ = ['AdamW'] @@ -39,46 +41,48 @@ def __init__( self, params: Iterable[torch.Tensor], lr: ScalarOrSchedule = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), + betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 1e-2, *, eps_root: float = 0.0, - mask: Optional[Union[Any, Callable[[Params], Any]]] = None, + mask: OptState | Callable[[Params], OptState] | None = None, maximize: bool = False, use_accelerated_op: bool = False, ) -> None: r"""Initialize the AdamW optimizer. Args: - params: (iterable of torch.Tensor) - An iterable of :class:`torch.Tensor`\s. Specifies what tensors should be optimized. - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`1e-2`) - Strength of the weight decay regularization. Note that this weight decay is - multiplied with the learning rate. This is consistent with other frameworks such as - PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only - multiplied with the "schedule multiplier", but not the base learning rate. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to - avoid dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - mask: (default: :data:`None`) - A tree with same structure as (or a prefix of) the params PyTree, or a Callable that + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square + root (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Strength of the weight decay regularization. Note that + this weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where + the weight decay is only multiplied with the "schedule multiplier", but not the base + learning rate. (default: :const:`1e-2`) + eps_root (float, optional): A small constant applied to denominator inside the square + root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for + example when computing (meta-)gradients through Adam. (default: :const:`0.0`) + mask (tree of Tensor, callable, or None, optional): + A tree with same structure as (or a prefix of) the params pytree, or a function that returns such a pytree given the params/updates. The leaves should be booleans, :data:`True` for leaves/subtrees you want to apply the weight decay to, and :data:`False` for those you want to skip. Note that the Adam gradient - transformations are applied to all parameters. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. + transformations are applied to all parameters. (default: :data:`None`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. + (default: :data:`False`) """ super().__init__( params, diff --git a/torchopt/optim/base.py b/torchopt/optim/base.py index e894b93b..aac3a782 100644 --- a/torchopt/optim/base.py +++ b/torchopt/optim/base.py @@ -14,7 +14,9 @@ # ============================================================================== """The base class for optimizers.""" -from typing import Callable, Iterable, List, Optional, Sequence, Tuple +from __future__ import annotations + +from typing import Callable, Iterable, Sequence import torch @@ -37,8 +39,8 @@ def __init__(self, params: Iterable[torch.Tensor], impl: GradientTransformation) params (iterable of torch.Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what tensors should be optimized. impl (GradientTransformation): A low level optimizer function, it could be a optimizer - function provided by ``alias.py`` or a customized ``chain`` provided by - ``combine.py``. + function provided in :mod:`torchopt.alias` or a customized :func:`torchopt.chain`\ed + transformation. Note that using ``Optimizer(sgd())`` or ``Optimizer(chain(sgd()))`` is equivalent to :class:`torchopt.SGD`. """ @@ -46,9 +48,9 @@ def __init__(self, params: Iterable[torch.Tensor], impl: GradientTransformation) raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation') self.impl: GradientTransformation = impl - self.param_groups: List[TupleOfTensors] = [] - self.param_treespecs: List[pytree.PyTreeSpec] = [] - self.state_groups: List[OptState] = [] + self.param_groups: list[TupleOfTensors] = [] + self.param_treespecs: list[pytree.PyTreeSpec] = [] + self.state_groups: list[OptState] = [] if not isinstance(params, (list, tuple)): params = tuple(params) @@ -60,7 +62,8 @@ def zero_grad(self, set_to_none: bool = False) -> None: The behavior is similar to :meth:`torch.optim.Optimizer.zero_grad`. Args: - set_to_none (bool): Instead of setting to zero, set the ``grads`` to :data:`None`. + set_to_none (bool, optional): Instead of setting to zero, set the ``grads`` to + :data:`None`. (default: :data:`False`) """ if set_to_none: @@ -80,7 +83,7 @@ def f(p): pytree.tree_map_(f, self.param_groups) # type: ignore[arg-type] - def state_dict(self) -> Tuple[OptState, ...]: + def state_dict(self) -> tuple[OptState, ...]: """Return the state of the optimizer.""" return tuple(self.state_groups) @@ -88,18 +91,19 @@ def load_state_dict(self, state_dict: Sequence[OptState]) -> None: """Load the optimizer state. Args: - state_dict: Optimizer state. Should be an object returned from a call to - :meth:`state_dict`. + state_dict (sequence of tree of Tensor): Optimizer state. Should be an object returned + from a call to :meth:`state_dict`. """ self.state_groups[:] = list(state_dict) - def step(self, closure: Optional[Callable[[], torch.Tensor]] = None) -> Optional[torch.Tensor]: + def step(self, closure: Callable[[], torch.Tensor] | None = None) -> torch.Tensor | None: """Perform a single optimization step. The behavior is similar to :meth:`torch.optim.Optimizer.step`. Args: - closure (callable, optional): A closure that reevaluates the model and returns the loss. + closure (callable or None, optional): A closure that reevaluates the model and returns + the loss. Optional for most optimizers. (default: :data:`None`) """ loss = None if closure is not None: @@ -120,7 +124,7 @@ def f(p): return loss def add_param_group(self, params: Params) -> None: - """Add a param group to the optimizer's :attr:`param_groups`.""" + """Add a param group to the optimizer's ``param_groups``.""" flat_params: TupleOfTensors flat_params, params_treespec = pytree.tree_flatten_as_tuple(params) self.param_groups.append(flat_params) diff --git a/torchopt/optim/func/base.py b/torchopt/optim/func/base.py index 7e51a21b..9dce3412 100644 --- a/torchopt/optim/func/base.py +++ b/torchopt/optim/func/base.py @@ -14,7 +14,7 @@ # ============================================================================== """Functional optimizer wrappers.""" -from typing import Optional +from __future__ import annotations import torch @@ -41,26 +41,27 @@ class FuncOptimizer: # pylint: disable=too-few-public-methods """ def __init__(self, impl: GradientTransformation, *, inplace: bool = False) -> None: - """Initialize the functional optimizer wrapper. + r"""Initialize the functional optimizer wrapper. Args: impl (GradientTransformation): A low level optimizer function, it could be a optimizer - function provided by `alias.py` or a customized `chain` provided by `combine.py`. - inplace (optional): (default: :data:`False`) - The default value of ``inplace`` for each optimization update. + function provided in :mod:`torchopt.alias` or a customized :func:`torchopt.chain`\ed + transformation. + inplace (bool, optional): The default value of ``inplace`` for each optimization update. + (default: :data:`False`) """ if not isinstance(impl, GradientTransformation): raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation') self.impl: GradientTransformation = impl - self.optim_state: Optional[OptState] = UninitializedState() + self.optim_state: OptState | None = UninitializedState() self.inplace: bool = bool(inplace) def step( self, loss: torch.Tensor, params: Params, - inplace: Optional[bool] = None, + inplace: bool | None = None, ) -> Params: r"""Compute the gradients of loss to the network parameters and update network parameters. @@ -69,13 +70,12 @@ def step( gradients and update the network parameters without modifying tensors in-place. Args: - loss: (torch.Tensor) - loss that is used to compute the gradients to network parameters. - params: (tree of torch.Tensor) - An tree of :class:`torch.Tensor`\s. Specifies what tensors should be optimized. - inplace (optional): (default: :data:`None`) - Whether to update the parameters in-place. If :data:`None`, use the default value - specified in the constructor. + loss (Tensor): The loss that is used to compute the gradients to network parameters. + params (tree of Tensor): An tree of :class:`torch.Tensor`\s. Specifies what tensors + should be optimized. + inplace (bool or None, optional): Whether to update the parameters in-place. If + :data:`None`, use the default value specified in the constructor. + (default: :data:`None`) """ if isinstance(self.optim_state, UninitializedState): self.optim_state = self.impl.init(params) diff --git a/torchopt/optim/meta/adam.py b/torchopt/optim/meta/adam.py index 36d54857..bd9804b9 100644 --- a/torchopt/optim/meta/adam.py +++ b/torchopt/optim/meta/adam.py @@ -14,7 +14,7 @@ # ============================================================================== """Differentiable Adam optimizer.""" -from typing import Tuple +from __future__ import annotations import torch.nn as nn @@ -39,7 +39,7 @@ def __init__( self, module: nn.Module, lr: ScalarOrSchedule = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), + betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.0, *, @@ -51,28 +51,26 @@ def __init__( """Initialize the meta-Adam optimizer. Args: - module: (nn.Module) - A network whose parameters should be optimized. - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to - avoid dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - moment_requires_grad: (default: :data:`True`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square + root (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + eps_root (float, optional): A small constant applied to denominator inside the square + root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for + example when computing (meta-)gradients through Adam. (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. + (default: :data:`False`) """ super().__init__( module, diff --git a/torchopt/optim/meta/adamw.py b/torchopt/optim/meta/adamw.py index dc869e30..c8a8ef9c 100644 --- a/torchopt/optim/meta/adamw.py +++ b/torchopt/optim/meta/adamw.py @@ -14,13 +14,15 @@ # ============================================================================== """Differentiable AdamW optimizer.""" -from typing import Any, Callable, Optional, Tuple, Union +from __future__ import annotations + +from typing import Callable import torch.nn as nn from torchopt import alias from torchopt.optim.meta.base import MetaOptimizer -from torchopt.typing import Params, ScalarOrSchedule +from torchopt.typing import OptState, Params, ScalarOrSchedule __all__ = ['MetaAdamW'] @@ -39,12 +41,12 @@ def __init__( self, module: nn.Module, lr: ScalarOrSchedule = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), + betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 1e-2, *, eps_root: float = 0.0, - mask: Optional[Union[Any, Callable[[Params], Any]]] = None, + mask: OptState | Callable[[Params], OptState] | None = None, moment_requires_grad: bool = False, maximize: bool = False, use_accelerated_op: bool = False, @@ -52,37 +54,35 @@ def __init__( """Initialize the meta-AdamW optimizer. Args: - module: (nn.Module) - A network whose parameters should be optimized. - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`1e-2`) - Strength of the weight decay regularization. Note that this weight decay is - multiplied with the learning rate. This is consistent with other frameworks such as - PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only - multiplied with the "schedule multiplier", but not the base learning rate. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to - avoid dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - mask: (default: :data:`None`) - A tree with same structure as (or a prefix of) the params PyTree, or a Callable that + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square + root (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Strength of the weight decay regularization. Note that + this weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where + the weight decay is only multiplied with the "schedule multiplier", but not the base + learning rate. (default: :const:`1e-2`) + eps_root (float, optional): A small constant applied to denominator inside the square + root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for + example when computing (meta-)gradients through Adam. (default: :const:`0.0`) + mask (tree of Tensor, callable, or None, optional): + A tree with same structure as (or a prefix of) the params pytree, or a function that returns such a pytree given the params/updates. The leaves should be booleans, :data:`True` for leaves/subtrees you want to apply the weight decay to, and :data:`False` for those you want to skip. Note that the Adam gradient - transformations are applied to all parameters. - moment_requires_grad: (default: :data:`False`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. + transformations are applied to all parameters. (default: :data:`None`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. + (default: :data:`False`) """ super().__init__( module, diff --git a/torchopt/optim/meta/base.py b/torchopt/optim/meta/base.py index 8db4f0a7..c5c9ad73 100644 --- a/torchopt/optim/meta/base.py +++ b/torchopt/optim/meta/base.py @@ -14,7 +14,9 @@ # ============================================================================== """The base class for differentiable meta-optimizers.""" -from typing import List, Sequence, Tuple +from __future__ import annotations + +from typing import Sequence import torch import torch.nn as nn @@ -33,14 +35,13 @@ class MetaOptimizer: """The base class for high-level differentiable optimizers.""" def __init__(self, module: nn.Module, impl: GradientTransformation) -> None: - """Initialize the meta-optimizer. + r"""Initialize the meta-optimizer. Args: - module: (nn.Module) - A network whose parameters should be optimized. - impl: (GradientTransformation) - A low level optimizer function, it could be a optimizer function provided by - ``alias.py`` or a customized ``chain`` provided by ``combine.py``. + module (nn.Module): A network whose parameters should be optimized. + impl (GradientTransformation): A low level optimizer function, it could be a optimizer + function provided in :mod:`torchopt.alias` or a customized :func:`torchopt.chain`\ed + transformation. Note that using ``MetaOptimizer(sgd(moment_requires_grad=True))`` or ``MetaOptimizer(chain(sgd(moment_requires_grad=True)))`` is equivalent to :class:`torchopt.MetaSGD`. @@ -49,8 +50,8 @@ def __init__(self, module: nn.Module, impl: GradientTransformation) -> None: raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation') self.impl: GradientTransformation = impl - self.param_containers_groups: List[ModuleTensorContainers] = [] - self.state_groups: List[OptState] = [] + self.param_containers_groups: list[ModuleTensorContainers] = [] + self.state_groups: list[OptState] = [] self.add_param_group(module) @@ -62,8 +63,8 @@ def step(self, loss: torch.Tensor) -> None: # pylint: disable=too-many-locals gradients and update the network parameters without modifying tensors in-place. Args: - loss: (torch.Tensor) - The loss that is used to compute the gradients to the network parameters. + loss (torch.Tensor): The loss that is used to compute the gradients to the network + parameters. """ # Step parameter only for i, (param_container, state) in enumerate( @@ -94,12 +95,12 @@ def step(self, loss: torch.Tensor) -> None: # pylint: disable=too-many-locals container.update(new_param) def add_param_group(self, module: nn.Module) -> None: - """Add a param group to the optimizer's :attr:`state_groups`.""" + """Add a param group to the optimizer's ``state_groups``.""" params_container = extract_module_containers(module, with_buffers=False)[0] self.param_containers_groups.append(params_container) self.state_groups.append(UninitializedState()) - def state_dict(self) -> Tuple[OptState, ...]: + def state_dict(self) -> tuple[OptState, ...]: """Extract the references of the optimizer states. Note that the states are references, so any in-place operations will change the states diff --git a/torchopt/optim/meta/rmsprop.py b/torchopt/optim/meta/rmsprop.py index f4dfdae6..3aff20e1 100644 --- a/torchopt/optim/meta/rmsprop.py +++ b/torchopt/optim/meta/rmsprop.py @@ -50,30 +50,26 @@ def __init__( """Initialize the meta-RMSProp optimizer. Args: - module: (nn.Module) - A network whose parameters should be optimized. - lr: (default: :const:`1e-2`) - This is a fixed global scaling factor. - alpha: (default: :const:`0.99`) - Smoothing constant, the decay used to track the magnitude of previous gradients. - eps: (default: :const:`1e-8`) - A small numerical constant to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - centered: (default: :data:`False`) - If :data:`True`, use the variance of the past gradients to rescale the latest - gradients. - initial_scale: (default: :data:`0.0`) - Initialization of accumulators tracking the magnitude of previous updates. PyTorch - uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When reproducing results from a - paper, verify the value used by the authors. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-2`) + alpha (float, optional): Smoothing constant, the decay used to track the magnitude of + previous gradients. (default: :const:`0.99`) + eps (float, optional): A small numerical constant to avoid dividing by zero when + rescaling. (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + momentum (float, optional): The decay rate used by the momentum term. The momentum is + not used when it is set to :const:`0.0`. (default: :const:`0.0`) + centered (bool, optional): If :data:`True`, use the variance of the past gradients to + rescale the latest gradients. (default: :data:`False`) + initial_scale (float, optional): Initialization of accumulators tracking the magnitude + of previous updates. PyTorch uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When + reproducing results from a paper, verify the value used by the authors. + (default: :data:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) """ super().__init__( module, diff --git a/torchopt/optim/meta/sgd.py b/torchopt/optim/meta/sgd.py index 5f9177e1..476ed9d6 100644 --- a/torchopt/optim/meta/sgd.py +++ b/torchopt/optim/meta/sgd.py @@ -47,23 +47,20 @@ def __init__( """Initialize the meta-SGD optimizer. Args: - module: (nn.Module) - A network whose parameters should be optimized. - lr: This is a fixed global scaling factor. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - dampening: (default: :const:`0.0`) - Dampening for momentum. - nesterov: (default: :const:`False`) - Whether to use Nesterov momentum. - moment_requires_grad: (default: :data:`True`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable): This is a fixed global scaling factor or a learning rate + scheduler. + momentum (float, optional): The decay rate used by the momentum term. The momentum is + not used when it is set to :const:`0.0`. (default: :const:`0.0`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + dampening (float, optional): Dampening for momentum. (default: :const:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) """ super().__init__( module, diff --git a/torchopt/optim/rmsprop.py b/torchopt/optim/rmsprop.py index 9101984f..5c4e536f 100644 --- a/torchopt/optim/rmsprop.py +++ b/torchopt/optim/rmsprop.py @@ -52,30 +52,27 @@ def __init__( r"""Initialize the RMSProp optimizer. Args: - params: (iterable of torch.Tensor) - An iterable of :class:`torch.Tensor`\s. Specifies what Tensors should be optimized. - lr: (default: :const:`1e-2`) - This is a fixed global scaling factor. - alpha: (default: :const:`0.99`) - Smoothing constant, the decay used to track the magnitude of previous gradients. - eps: (default: :const:`1e-8`) - A small numerical constant to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - centered: (default: :data:`False`) - If :data:`True`, use the variance of the past gradients to rescale the latest - gradients. - initial_scale: (default: :data:`0.0`) - Initialization of accumulators tracking the magnitude of previous updates. PyTorch - uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When reproducing results from a - paper, verify the value used by the authors. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-2`) + alpha (float, optional): Smoothing constant, the decay used to track the magnitude of + previous gradients. (default: :const:`0.99`) + eps (float, optional): A small numerical constant to avoid dividing by zero when + rescaling. (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + momentum (float, optional): The decay rate used by the momentum term. The momentum is + not used when it is set to :const:`0.0`. (default: :const:`0.0`) + centered (bool, optional): If :data:`True`, use the variance of the past gradients to + rescale the latest gradients. (default: :data:`False`) + initial_scale (float, optional): Initialization of accumulators tracking the magnitude + of previous updates. PyTorch uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When + reproducing results from a paper, verify the value used by the authors. + (default: :data:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) """ super().__init__( params, diff --git a/torchopt/optim/sgd.py b/torchopt/optim/sgd.py index 223e856e..3da9595a 100644 --- a/torchopt/optim/sgd.py +++ b/torchopt/optim/sgd.py @@ -48,20 +48,21 @@ def __init__( r"""Initialize the SGD optimizer. Args: - params: (iterable of torch.Tensor) - An iterable of :class:`torch.Tensor`\s. Specifies what tensors should be optimized. - lr: This is a fixed global scaling factor. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - dampening: (default: :const:`0.0`) - Dampening for momentum. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable): This is a fixed global scaling factor or a learning rate + scheduler. + momentum (float, optional): The decay rate used by the momentum term. The momentum is + not used when it is set to :const:`0.0`. (default: :const:`0.0`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + dampening (float, optional): Dampening for momentum. (default: :const:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) """ super().__init__( params, diff --git a/torchopt/pytree.py b/torchopt/pytree.py index 0abcf4fd..d3b2d181 100644 --- a/torchopt/pytree.py +++ b/torchopt/pytree.py @@ -14,9 +14,11 @@ # ============================================================================== """The PyTree utilities.""" +from __future__ import annotations + import functools import operator -from typing import Callable, List, Optional, Tuple +from typing import Callable import optree import optree.typing as typing # pylint: disable=unused-import @@ -47,19 +49,20 @@ def tree_flatten_as_tuple( tree: PyTree[T], - is_leaf: Optional[Callable[[T], bool]] = None, + is_leaf: Callable[[T], bool] | None = None, *, none_is_leaf: bool = False, namespace: str = '', -) -> Tuple[Tuple[T, ...], PyTreeSpec]: +) -> tuple[tuple[T, ...], PyTreeSpec]: """Flatten a pytree to a tuple of leaves and a PyTreeSpec. Args: - tree: The pytree to flatten. - is_leaf: A function that returns :data:`True` if a given node is a leaf. - none_is_leaf: If :data:`True`, None is considered a leaf rather than a internal node with no - children. - namespace: The namespace of custom tree node types. + tree (pytree): The pytree to flatten. + is_leaf (callable or None, optional): An optionally specified function that returns + :data:`True` if a given node is a leaf. (default: :data:`None`) + none_is_leaf (bool, optional): If :data:`True`, :data:`None` is considered a leaf rather + than a internal node with no children. (default: :data:`False`) + namespace (str, optional): The namespace of custom tree node types. (default: :const:`''`) Returns: A tuple of (leaves, treespec). @@ -99,7 +102,7 @@ def tree_add(*trees: PyTree[T]) -> PyTree[T]: def tree_add_scalar_mul( - tree_x: TensorTree, tree_y: TensorTree, alpha: Optional[Scalar] = None + tree_x: TensorTree, tree_y: TensorTree, alpha: Scalar | None = None ) -> TensorTree: """Compute ``tree_x + alpha * tree_y``.""" if alpha is None: @@ -113,7 +116,7 @@ def tree_sub(minuend_tree: PyTree[T], subtrahend_tree: PyTree[T]) -> PyTree[T]: def tree_sub_scalar_mul( - tree_x: TensorTree, tree_y: TensorTree, alpha: Optional[Scalar] = None + tree_x: TensorTree, tree_y: TensorTree, alpha: Scalar | None = None ) -> TensorTree: """Compute ``tree_x - alpha * tree_y``.""" if alpha is None: @@ -190,4 +193,4 @@ def tree_local_value(rref_tree: PyTree[RRef[T]]) -> PyTree[T]: __all__.extend(['tree_as_rref', 'tree_to_here']) -del Callable, List, Optional, Tuple, optree, rpc, Scalar, T, RRef +del Callable, optree, rpc, Scalar, T, RRef diff --git a/torchopt/schedule/polynomial.py b/torchopt/schedule/polynomial.py index 8a8e51e8..d54dbf17 100644 --- a/torchopt/schedule/polynomial.py +++ b/torchopt/schedule/polynomial.py @@ -52,18 +52,17 @@ def polynomial_schedule( """Construct a schedule with polynomial transition from init to end value. Args: - init_value: Initial value for the scalar to be annealed. - end_value: End value of the scalar to be annealed. - power: The power of the polynomial used to transition from ``init`` to ``end``. - transition_steps: - Number of steps over which annealing takes place, the scalar starts changing at - ``transition_begin`` steps and completes the transition by - ``transition_begin + transition_steps`` steps. - If ``transition_steps <= 0``, then the entire annealing process is disabled and the - value is held fixed at ``init_value``. - transition_begin: - Must be *positive*. After how many steps to start annealing (before this many steps the - scalar value is held fixed at ``init_value``). + init_value (float or Tensor): Initial value for the scalar to be annealed. + end_value (float or Tensor): End value of the scalar to be annealed. + power (float or Tensor): The power of the polynomial used to transition from ``init`` to + ``end``. + transition_steps (int): Number of steps over which annealing takes place, the scalar starts + changing at ``transition_begin`` steps and completes the transition by + ``transition_begin + transition_steps`` steps. If ``transition_steps <= 0``, then the + entire annealing process is disabled and the value is held fixed at ``init_value``. + transition_begin (int, optional): Must be *positive*. After how many steps to start + annealing (before this many steps the scalar value is held fixed at ``init_value``). + (default: :const:`0`) Returns: schedule: diff --git a/torchopt/transform/add_decayed_weights.py b/torchopt/transform/add_decayed_weights.py index 772e6291..14745766 100644 --- a/torchopt/transform/add_decayed_weights.py +++ b/torchopt/transform/add_decayed_weights.py @@ -32,7 +32,9 @@ # ============================================================================== """Preset transformations for adding weight decay to updates.""" -from typing import Any, Callable, NamedTuple, Optional, Tuple, Union +from __future__ import annotations + +from typing import Any, Callable, NamedTuple from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation, identity @@ -59,7 +61,7 @@ class MaskedNode(NamedTuple): def masked( inner: GradientTransformation, - mask: Union[Any, Callable[[Params], Any]], + mask: OptState | Callable[[Params], OptState] | None = None, ) -> GradientTransformation: """Mask updates so only some are transformed, the rest are passed through. @@ -75,11 +77,12 @@ def masked( of :data:`True`. Args: - inner: Inner transformation to mask. - mask: A tree with same structure as (or a prefix of) the params tree, or a Callable that - returns such a tree given the params/updates. The leaves should be booleans, :data:`True` - for leaves/subtrees you want to apply the transformation to, and :data:`False` for those - you want to skip. The mask must be static for the gradient transformation to be jit-compilable. + inner (GradientTransformation): Inner transformation to mask. + mask (tree of Tensor, callable, or None, optional): A tree with same structure as (or a + prefix of) the params tree, or a function that returns such a tree given the + params/updates. The leaves should be booleans, :data:`True` for leaves/subtrees you want + to apply the transformation to, and :data:`False` for those you want to skip. + (default: :data:`None`) Returns: A :class:`GradientTransformation` wrapping ``inner``. @@ -89,14 +92,14 @@ def masked( def _masked_flat( inner: GradientTransformation, - mask: Union[Any, Callable[[Params], Any]], + mask: OptState | Callable[[Params], OptState] | None = None, ) -> GradientTransformation: return _masked(inner, mask, already_flattened=True) def _masked( inner: GradientTransformation, - mask: Union[Any, Callable[[Params], Any]], + mask: OptState | Callable[[Params], OptState] | None = None, *, already_flattened: bool = False, ) -> GradientTransformation: @@ -117,9 +120,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, + params: Params | None = None, inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: mask_tree = mask(updates) if callable(mask) else mask masked_updates = tree_mask(updates, mask_tree) masked_params = None if params is None else tree_mask(params, mask_tree) @@ -145,16 +148,17 @@ def update_fn( def add_decayed_weights( weight_decay: float = 0.0, - mask: Optional[Union[Any, Callable[[Params], Any]]] = None, + mask: OptState | Callable[[Params], OptState] | None = None, ) -> GradientTransformation: """Add parameter scaled by `weight_decay`. Args: - weight_decay: a scalar weight decay rate. - mask: a tree with same structure as (or a prefix of) the params tree, or a Callable that - returns such a pytree given the params/updates. The leaves should be booleans, - :data:`True` for leaves/subtrees you want to apply the transformation to, and - :data:`False` for those you want to skip. + weight_decay (float, optional): A scalar weight decay rate. (default: :const:`0.0`) + mask (tree of Tensor, callable, or None, optional): A tree with same structure as (or a + prefix of) the params tree, or a function that returns such a tree given the + params/updates. The leaves should be booleans, :data:`True` for leaves/subtrees you want + to apply the transformation to, and :data:`False` for those you want to skip. + (default: :data:`None`) Returns: An (init_fn, update_fn) tuple. @@ -168,7 +172,7 @@ def add_decayed_weights( def _add_decayed_weights_flat( weight_decay: float = 0.0, - mask: Optional[Union[Any, Callable[[Params], Any]]] = None, + mask: OptState | Callable[[Params], OptState] | None = None, ) -> GradientTransformation: return _add_decayed_weights( weight_decay=weight_decay, @@ -179,7 +183,7 @@ def _add_decayed_weights_flat( def _add_decayed_weights( weight_decay: float = 0.0, - mask: Optional[Union[Any, Callable[[Params], Any]]] = None, + mask: OptState | Callable[[Params], OptState] | None = None, *, already_flattened: bool = False, ) -> GradientTransformation: @@ -204,9 +208,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, + params: Params | None = None, inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: assert params is not None, ( 'Parameters are required for weight decay. ' 'Call `update(updates, state, params=params)` instead.' diff --git a/torchopt/transform/nan_to_num.py b/torchopt/transform/nan_to_num.py index 2c0b9d5e..804f8219 100644 --- a/torchopt/transform/nan_to_num.py +++ b/torchopt/transform/nan_to_num.py @@ -14,7 +14,7 @@ # ============================================================================== """Preset transformations that replaces updates with non-finite values to the given numbers.""" -from typing import Optional, Tuple +from __future__ import annotations from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation @@ -23,8 +23,8 @@ def nan_to_num( nan: float = 0.0, - posinf: Optional[float] = None, - neginf: Optional[float] = None, + posinf: float | None = None, + neginf: float | None = None, ) -> GradientTransformation: """Replace updates with values ``nan`` / ``+inf`` / ``-inf`` to the given numbers. @@ -39,9 +39,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: if inplace: def f(g): diff --git a/torchopt/transform/scale.py b/torchopt/transform/scale.py index 4afac163..639c903e 100644 --- a/torchopt/transform/scale.py +++ b/torchopt/transform/scale.py @@ -31,7 +31,7 @@ # ============================================================================== """Preset transformation for scaling updates by learning rate.""" -from typing import Optional, Tuple +from __future__ import annotations from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation @@ -49,7 +49,7 @@ def scale(step_size: float) -> GradientTransformation: """Scale updates by some fixed scalar ``step_size``. Args: - step_size: A scalar corresponding to a fixed scaling factor for updates. + step_size (float): A scalar corresponding to a fixed scaling factor for updates. Returns: An ``(init_fn, update_fn)`` tuple. @@ -80,9 +80,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: if inplace: def f(g): diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py index 039d31fb..36f30be9 100644 --- a/torchopt/transform/scale_by_adam.py +++ b/torchopt/transform/scale_by_adam.py @@ -33,7 +33,9 @@ # pylint: disable=invalid-name -from typing import NamedTuple, Optional, Tuple +from __future__ import annotations + +from typing import NamedTuple import torch @@ -88,17 +90,17 @@ def scale_by_adam( [Kingma et al, 2014](https://arxiv.org/abs/1412.6980) Args: - b1: (default: :const:`0.9`) - Decay rate for the exponentially weighted average of grads. - b2: (default: :const:`0.999`) - Decay rate for the exponentially weighted average of squared grads. - eps: (default: :const:`1e-8`) - Term added to the denominator to improve numerical stability. - eps_root: (default: :const:`0.0`) - Term added to the denominator inside the square-root to improve + b1 (float, optional): Decay rate for the exponentially weighted average of grads. + (default: :const:`0.9`) + b2 (float, optional): Decay rate for the exponentially weighted average of squared grads. + (default: :const:`0.999`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-8`) + eps_root (float, optional): Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling. - moment_requires_grad: (default: :data:`False`) - If :data:`True`, states will be created with flag `requires_grad = True`. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag + ``requires_grad = True``. (default: :data:`False`) Returns: An (init_fn, update_fn) tuple. @@ -169,9 +171,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: mu = update_moment.impl( # type: ignore[attr-defined] updates, state.mu, b1, order=1, inplace=inplace, already_flattened=already_flattened ) @@ -218,17 +220,17 @@ def scale_by_accelerated_adam( [Kingma et al, 2014](https://arxiv.org/abs/1412.6980) Args: - b1: (default: :const:`0.9`) - Decay rate for the exponentially weighted average of grads. - b2: (default: :const:`0.999`) - Decay rate for the exponentially weighted average of squared grads. - eps: (default: :const:`1e-8`) - Term added to the denominator to improve numerical stability. - eps_root: (default: :const:`0.0`) - Term added to the denominator inside the square-root to improve + b1 (float, optional): Decay rate for the exponentially weighted average of grads. + (default: :const:`0.9`) + b2 (float, optional): Decay rate for the exponentially weighted average of squared grads. + (default: :const:`0.999`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-8`) + eps_root (float, optional): Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling. - moment_requires_grad: (default: :data:`False`) - If :data:`True`, states will be created with flag `requires_grad = True`. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag + ``requires_grad = True``. (default: :data:`False`) Returns: An (init_fn, update_fn) tuple. @@ -285,9 +287,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: count_inc = inc_count.impl(updates, state.count, already_flattened=True) # type: ignore[attr-defined] op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace) @@ -303,9 +305,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: count_inc = inc_count.impl(updates, state.count, already_flattened=False) # type: ignore[attr-defined] treespec = pytree.tree_structure(updates, none_is_leaf=True) diff --git a/torchopt/transform/scale_by_rms.py b/torchopt/transform/scale_by_rms.py index 7a685f6b..7a0c8c20 100644 --- a/torchopt/transform/scale_by_rms.py +++ b/torchopt/transform/scale_by_rms.py @@ -31,7 +31,9 @@ # ============================================================================== """Preset transformations for scaling updates by exponential root mean-squared (RMS).""" -from typing import NamedTuple, Optional, Tuple +from __future__ import annotations + +from typing import NamedTuple import torch @@ -61,12 +63,11 @@ def scale_by_rms( [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) Args: - alpha: (default: :const:`0.9`) - Decay rate for the exponentially weighted average of squared grads. - eps: (default: :const:`1e-8`) - Term added to the denominator to improve numerical stability. - initial_scale: (default: :const:`0.0`) - Initial value for second moment + alpha (float, optional): Decay rate for the exponentially weighted average of squared grads. + (default: :const:`0.9`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-8`) + initial_scale (float, optional): Initial value for second moment. (default: :const:`0.0`) Returns: An (init_fn, update_fn) tuple. @@ -121,9 +122,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: nu = update_moment.impl( # type: ignore[attr-defined] updates, state.nu, alpha, order=2, inplace=inplace, already_flattened=already_flattened ) diff --git a/torchopt/transform/scale_by_schedule.py b/torchopt/transform/scale_by_schedule.py index 5556d111..d6e3b0fa 100644 --- a/torchopt/transform/scale_by_schedule.py +++ b/torchopt/transform/scale_by_schedule.py @@ -31,7 +31,9 @@ # ============================================================================== """Preset transformation for scaling updates by learning rate schedules.""" -from typing import NamedTuple, Optional, Tuple +from __future__ import annotations + +from typing import NamedTuple import torch @@ -54,9 +56,8 @@ def scale_by_schedule(step_size_fn: Schedule) -> GradientTransformation: """Scale updates using a custom schedule for the ``step_size``. Args: - step_size_fn: - A function that takes an update count as input and proposes the ``step_size`` to - multiply the updates by. + step_size_fn (callable): A function that takes an update count as input and proposes the + ``step_size`` to multiply the updates by. Returns: An ``(init_fn, update_fn)`` tuple. @@ -90,9 +91,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: if inplace: def f(g, c): # pylint: disable=invalid-name diff --git a/torchopt/transform/scale_by_stddev.py b/torchopt/transform/scale_by_stddev.py index c15a0d6c..228ed707 100644 --- a/torchopt/transform/scale_by_stddev.py +++ b/torchopt/transform/scale_by_stddev.py @@ -33,7 +33,9 @@ # pylint: disable=invalid-name -from typing import NamedTuple, Optional, Tuple +from __future__ import annotations + +from typing import NamedTuple import torch @@ -64,12 +66,11 @@ def scale_by_stddev( [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) Args: - alpha: (default: :const:`0.9`) - Decay rate for the exponentially weighted average of squared grads. - eps: (default: :const:`1e-8`) - Term added to the denominator to improve numerical stability. - initial_scale: (default: :const:`0.0`) - Initial value for second moment + alpha (float, optional): Decay rate for the exponentially weighted average of squared grads. + (default: :const:`0.9`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-8`) + initial_scale (float, optional): Initial value for second moment. (default: :const:`0.0`) Returns: An (init_fn, update_fn) tuple. @@ -125,9 +126,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: mu = update_moment.impl( # type: ignore[attr-defined] updates, state.mu, alpha, order=1, inplace=inplace, already_flattened=already_flattened ) diff --git a/torchopt/transform/trace.py b/torchopt/transform/trace.py index 45e043f0..03d2441d 100644 --- a/torchopt/transform/trace.py +++ b/torchopt/transform/trace.py @@ -33,7 +33,9 @@ # pylint: disable=invalid-name -from typing import NamedTuple, Optional, Tuple +from __future__ import annotations + +from typing import NamedTuple import torch @@ -65,14 +67,12 @@ def trace( Both are frequently found in the optimization literature. Args: - momentum: (default: :const:`0.9`) - The decay rate for the trace of past updates. - dampening: (default: :const:`0.0`) - Dampening for momentum. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - moment_requires_grad: (default: :data:`False`) - If :data:`True`, states will be created with flag `requires_grad = True`. + momentum (float, optional): The decay rate for the trace of past updates. + (default: :const:`0.9`) + dampening (float, optional): Dampening for momentum. (default: :const:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag + ``requires_grad = True``. (default: :data:`False`) Returns: An (init_fn, update_fn) tuple. @@ -139,9 +139,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: nonlocal first_call if nesterov: diff --git a/torchopt/transform/utils.py b/torchopt/transform/utils.py index a9f02295..77ba58ca 100644 --- a/torchopt/transform/utils.py +++ b/torchopt/transform/utils.py @@ -31,6 +31,8 @@ # ============================================================================== """Utilities for the preset transformations.""" +from __future__ import annotations + from collections import deque from typing import Any, Callable, Sequence diff --git a/torchopt/update.py b/torchopt/update.py index 3fdd38e1..9485896b 100644 --- a/torchopt/update.py +++ b/torchopt/update.py @@ -48,11 +48,11 @@ def apply_updates(params: Params, updates: Updates, *, inplace: bool = True) -> :func:`tree_map` (e.g. if you want to manipulate updates in custom ways before applying them). Args: - params: A tree of parameters. - updates: - A tree of updates, the tree structure and the shape of the leaf nodes must match that - of ``params``. - inplace: If :data:`True`, will update params in a inplace manner. + params (tree of Tensor): A tree of parameters. + updates (tree of Tensor): A tree of updates, the tree structure and the shape of the leaf + nodes must match that of ``params``. + inplace (bool, optional): If :data:`True`, will update params in a inplace manner. + (default: :data:`True`) Returns: Updated parameters, with same structure, shape and type as ``params``. diff --git a/torchopt/utils.py b/torchopt/utils.py index 4deaba8b..12adb214 100644 --- a/torchopt/utils.py +++ b/torchopt/utils.py @@ -14,21 +14,11 @@ # ============================================================================== """Utilities for TorchOpt.""" +from __future__ import annotations + import copy import itertools -from typing import ( - TYPE_CHECKING, - Dict, - List, - NamedTuple, - Optional, - Sequence, - Set, - Tuple, - Union, - cast, - overload, -) +from typing import TYPE_CHECKING, NamedTuple, Sequence, cast, overload from typing_extensions import Literal # Python 3.8+ from typing_extensions import TypeAlias # Python 3.10+ @@ -56,32 +46,30 @@ class ModuleState(NamedTuple): """Container for module state.""" - params: Tuple[Dict[str, torch.Tensor], ...] - buffers: Tuple[Dict[str, torch.Tensor], ...] - visual_contents: Optional[Dict] = None + params: tuple[dict[str, torch.Tensor], ...] + buffers: tuple[dict[str, torch.Tensor], ...] + visual_contents: dict | None = None detach_buffers: bool = False CopyMode: TypeAlias = Literal['reference', 'copy', 'deepcopy', 'ref', 'clone', 'deepclone'] -def stop_gradient(target: Union[ModuleState, nn.Module, 'MetaOptimizer', TensorTree]) -> None: +def stop_gradient(target: ModuleState | nn.Module | MetaOptimizer | TensorTree) -> None: """Stop the gradient for the input object. - Since a tensor use :attr:`grad_fn` to connect itself with the previous computation graph, the + Since a tensor use ``grad_fn`` to connect itself with the previous computation graph, the backpropagated gradient will flow over the tensor and continue flow to the tensors that is - connected by :attr:`grad_fn`. Some algorithms requires manually detaching tensors from the + connected by ``grad_fn``. Some algorithms requires manually detaching tensors from the computation graph. Note that the :func:`stop_gradient` operation is in-place. Args: - target: The target that to be detached from the computation graph, it could be a - :class:`nn.Module`, :class:`torchopt.MetaOptimizer`, state of the - :class:`torchopt.MetaOptimizer`, or just a plain list of tensors. - inplace: If :data:`True`, the target will be detached in-place. if :data:`Frue`, this - function will return a detached copy of the target. The in-place operation is fast and - memory efficient but may raise backpropagation error. + target (ModuleState, nn.Module, MetaOptimizer, or tree of Tensor): The target that to be + detached from the computation graph, it could be a :class:`nn.Module`, + :class:`torchopt.MetaOptimizer`, state of the :class:`torchopt.MetaOptimizer`, or just + a plain list of tensors. """ # pylint: disable-next=import-outside-toplevel from torchopt.optim.meta.base import MetaOptimizer @@ -108,7 +96,7 @@ def extract_state_dict( target: nn.Module, *, by: CopyMode = 'reference', - device: Optional[Device] = None, + device: Device | None = None, with_buffers: bool = True, enable_visual: bool = False, visual_prefix: str = '', @@ -118,57 +106,62 @@ def extract_state_dict( @overload def extract_state_dict( - target: 'MetaOptimizer', + target: MetaOptimizer, *, by: CopyMode = 'reference', - device: Optional[Device] = None, + device: Device | None = None, with_buffers: bool = True, enable_visual: bool = False, visual_prefix: str = '', -) -> Tuple[OptState, ...]: # pragma: no cover +) -> tuple[OptState, ...]: # pragma: no cover ... # pylint: disable-next=too-many-branches,too-many-locals def extract_state_dict( - target: Union[nn.Module, 'MetaOptimizer'], + target: nn.Module | MetaOptimizer, *, by: CopyMode = 'reference', - device: Optional[Device] = None, + device: Device | None = None, with_buffers: bool = True, detach_buffers: bool = False, enable_visual: bool = False, visual_prefix: str = '', -) -> Union[ModuleState, Tuple[OptState, ...]]: +) -> ModuleState | tuple[OptState, ...]: """Extract target state. - Since a tensor use :attr:`grad_fn` to connect itself with the previous computation graph, the + Since a tensor use ``grad_fn`` to connect itself with the previous computation graph, the backpropagated gradient will flow over the tensor and continue flow to the tensors that is - connected by :attr:`grad_fn`. Some algorithms requires manually detaching tensors from the + connected by ``grad_fn``. Some algorithms requires manually detaching tensors from the computation graph. Note that the extracted state is a reference, which means any in-place operator will affect the target that the state is extracted from. Args: - target: It could be a :class:`nn.Module` or :class:`torchopt.MetaOptimizer`. - by: The extract policy of tensors in the target. + target (nn.Module or MetaOptimizer): It could be a :class:`nn.Module` or + :class:`torchopt.MetaOptimizer`. + by (str, optional): The extract policy of tensors in the target. (default: :const:`'reference'`) - :const:`'reference'`: The extracted tensors will be references to the original tensors. - :const:`'copy'`: The extracted tensors will be clones of the original tensors. This - makes the copied tensors have :attr:`grad_fn` to be a ```` function - points to the original tensors. + makes the copied tensors have ``grad_fn`` to be a ```` function points + to the original tensors. - :const:`'deepcopy'`: The extracted tensors will be deep-copied from the original tensors. The deep-copied tensors will detach from the original computation graph. - device: If specified, move the extracted state to the specified device. - with_buffers: Extract buffer together with parameters, this argument is only used if the - input target is :class:`nn.Module`. - detach_buffers: Whether to detach the reference to the buffers, this argument is only used - if the input target is :class:`nn.Module` and ``by='reference'``. - enable_visual: Add additional annotations, which could be used in computation graph - visualization. Currently, this flag only has effect on :class:`nn.Module` but we will - support :class:`torchopt.MetaOptimizer` later. - visual_prefix: Prefix for the visualization annotations. + device (Device or None, optional): If specified, move the extracted state to the specified + device. (default: :const:`None`) + with_buffers (bool, optional): Extract buffer together with parameters, this argument is + only used if the input target is :class:`nn.Module`. (default: :const:`True`) + detach_buffers (bool, optional): Whether to detach the reference to the buffers, this + argument is only used if the input target is :class:`nn.Module` and ``by='reference'``. + (default: :const:`False`) + enable_visual (bool, optional): Add additional annotations, which could be used in + computation graph visualization. Currently, this flag only has effect on + :class:`nn.Module` but we will support :class:`torchopt.MetaOptimizer` later. + (default: :const:`False`) + visual_prefix (str, optional): Prefix for the visualization annotations. + (default: :const:`''`) Returns: State extracted of the input object. @@ -228,9 +221,9 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor: else: visual_contents = None - params: List[Dict[str, torch.Tensor]] = [] - buffers: List[Dict[str, torch.Tensor]] = [] - memo: Set[nn.Module] = set() + params: list[dict[str, torch.Tensor]] = [] + buffers: list[dict[str, torch.Tensor]] = [] + memo: set[nn.Module] = set() def update_params(container): if len(container) > 0: @@ -287,12 +280,12 @@ def get_variable(t): def extract_module_containers( module: nn.Module, with_buffers: bool = True -) -> Tuple[ModuleTensorContainers, ModuleTensorContainers]: +) -> tuple[ModuleTensorContainers, ModuleTensorContainers]: """Extract the references to the containers of parameters and buffers from a module.""" if isinstance(module, nn.Module): - params: List[TensorContainer] = [] - buffers: List[TensorContainer] = [] - memo: Set[nn.Module] = set() + params: list[TensorContainer] = [] + buffers: list[TensorContainer] = [] + memo: set[nn.Module] = set() def update_container(container, items): if len(items) > 0: @@ -316,8 +309,8 @@ def update_container(container, items): def recover_state_dict( - target: Union[nn.Module, 'MetaOptimizer'], - state: Union[ModuleState, Sequence[OptState]], + target: nn.Module | MetaOptimizer, + state: ModuleState | Sequence[OptState], ) -> None: """Recover state. @@ -327,8 +320,8 @@ def recover_state_dict( modified. Args: - target: Target that need to recover. - state: The recovering state. + target (nn.Module or MetaOptimizer): Target that need to recover. + state (ModuleState or sequence of tree of Tensor): The recovering state. """ # pylint: disable-next=import-outside-toplevel from torchopt.optim.meta.base import MetaOptimizer @@ -344,10 +337,7 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor: return nn.Parameter(t.clone().detach_(), requires_grad=t.requires_grad) return t.clone().detach_().requires_grad_(t.requires_grad) - buffers = cast( - Tuple[Dict[str, torch.Tensor], ...], - pytree.tree_map(clone_detach_, buffers), # type: ignore[arg-type] - ) + buffers = pytree.tree_map(clone_detach_, buffers) # type: ignore[assignment,arg-type] for tgt, src in itertools.chain( zip(params_containers, params), @@ -367,19 +357,19 @@ def module_clone( *, by: CopyMode = 'reference', detach_buffers: bool = False, - device: Optional[Device] = None, + device: Device | None = None, ) -> nn.Module: # pragma: no cover ... @overload def module_clone( - target: 'MetaOptimizer', + target: MetaOptimizer, *, by: CopyMode = 'reference', detach_buffers: bool = False, - device: Optional[Device] = None, -) -> 'MetaOptimizer': # pragma: no cover + device: Device | None = None, +) -> MetaOptimizer: # pragma: no cover ... @@ -389,34 +379,36 @@ def module_clone( *, by: CopyMode = 'reference', detach_buffers: bool = False, - device: Optional[Device] = None, + device: Device | None = None, ) -> TensorTree: # pragma: no cover ... # pylint: disable-next=too-many-locals def module_clone( - target: Union[nn.Module, 'MetaOptimizer', TensorTree], + target: nn.Module | MetaOptimizer | TensorTree, *, by: CopyMode = 'reference', detach_buffers: bool = False, - device: Optional[Device] = None, -) -> Union[nn.Module, 'MetaOptimizer', TensorTree]: + device: Device | None = None, +) -> nn.Module | MetaOptimizer | TensorTree: """Clone a module. Args: - target: The target to be cloned. - by: The extract policy of tensors in the target. + target (nn.Module, MetaOptimizer, or tree of Tensor): The target to be cloned. + by (str, optional): The extract policy of tensors in the target. (default: :const:`'reference'`) - :const:`'reference'`: The extracted tensors will be references to the original tensors. - :const:`'copy'`: The extracted tensors will be clones of the original tensors. This - makes the copied tensors have :attr:`grad_fn` to be a ```` function - points to the original tensors. + makes the copied tensors have ``grad_fn`` to be a ```` function points + to the original tensors. - :const:`'deepcopy'`: The extracted tensors will be deep-copied from the original tensors. The deep-copied tensors will detach from the original computation graph. - detach_buffers: Whether to detach the reference to the buffers, this argument is only used - if the input target is :class:`nn.Module` and ``by='reference'``. - device: If specified, move the cloned module to the specified device. + detach_buffers (bool, optional): Whether to detach the reference to the buffers, this + argument is only used if the input target is :class:`nn.Module` and ``by='reference'``. + (default: :const:`False`) + device (Device or None, optional): If specified, move the cloned module to the specified + device. (default: :const:`None`) Returns: The cloned module. @@ -499,7 +491,7 @@ def module_detach_(target: nn.Module) -> nn.Module: # pragma: no cover @overload -def module_detach_(target: 'MetaOptimizer') -> 'MetaOptimizer': # pragma: no cover +def module_detach_(target: MetaOptimizer) -> MetaOptimizer: # pragma: no cover ... @@ -509,12 +501,13 @@ def module_detach_(target: TensorTree) -> TensorTree: # pragma: no cover def module_detach_( - target: Union[ModuleState, nn.Module, 'MetaOptimizer', TensorTree] -) -> Union[ModuleState, nn.Module, 'MetaOptimizer', TensorTree]: + target: ModuleState | nn.Module | MetaOptimizer | TensorTree, +) -> ModuleState | nn.Module | MetaOptimizer | TensorTree: """Detach a module from the computation graph. Args: - target: The target to be detached. + target (ModuleState, nn.Module, MetaOptimizer, or tree of Tensor): The + target to be detached. Returns: The detached module. diff --git a/torchopt/visual.py b/torchopt/visual.py index e8145240..7afe65a4 100644 --- a/torchopt/visual.py +++ b/torchopt/visual.py @@ -17,8 +17,10 @@ # ============================================================================== """Computation graph visualization.""" +from __future__ import annotations + from collections import namedtuple -from typing import Generator, Iterable, Mapping, Optional, Union, cast +from typing import Generator, Iterable, Mapping, cast import torch from graphviz import Digraph @@ -71,14 +73,13 @@ def truncate(s): # pylint: disable=invalid-name # pylint: disable-next=too-many-branches,too-many-statements,too-many-locals def make_dot( var: TensorOrTensors, - params: Optional[ - Union[ - Mapping[str, torch.Tensor], - ModuleState, - Generator, - Iterable[Union[Mapping[str, torch.Tensor], ModuleState, Generator]], - ] - ] = None, + params: ( + Mapping[str, torch.Tensor] + | ModuleState + | Generator + | Iterable[Mapping[str, torch.Tensor] | ModuleState | Generator] + | None + ) = None, show_attrs: bool = False, show_saved: bool = False, max_attr_chars: int = 50, @@ -89,7 +90,7 @@ def make_dot( and is either blue, orange, or green: - **Blue** - Reachable leaf tensors that requires grad (tensors whose :attr:`grad` fields will be + Reachable leaf tensors that requires grad (tensors whose ``grad`` fields will be populated during :meth:`backward`). - **Orange** Saved tensors of custom autograd functions as well as those saved by built-in backward @@ -100,16 +101,16 @@ def make_dot( If any output is a view, we represent its base tensor with a dark green node. Args: - var: Output tensor. - params: ([dict of (name, tensor) or state_dict]) - Parameters to add names to node that requires grad. - show_attrs: Whether to display non-tensor attributes of backward nodes - (Requires PyTorch version >= 1.9) - show_saved: Whether to display saved tensor nodes that are not by custom autograd - functions. Saved tensor nodes for custom functions, if present, are always displayed. - (Requires PyTorch version >= 1.9) - max_attr_chars: If ``show_attrs`` is :data:`True`, sets max number of characters to display - for any given attribute. + var (Tensor or sequence of Tensor): Output tensor. + params: (dict[str, Tensor], ModuleState, iterable of tuple[str, Tensor], or None, optional): + Parameters to add names to node that requires grad. (default: :data:`None`) + show_attrs (bool, optional): Whether to display non-tensor attributes of backward nodes. + (default: :data:`False`) + show_saved (bool, optional): Whether to display saved tensor nodes that are not by custom + autograd functions. Saved tensor nodes for custom functions, if present, are always + displayed. (default: :data:`False`) + max_attr_chars (int, optional): If ``show_attrs`` is :data:`True`, sets max number of + characters to display for any given attribute. (default: :const:`50`) """ param_map = {} pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy