
feat(zero_order): implemented the zero order feature #93

Merged: 14 commits, Oct 11, 2022
.pylintrc (1 addition, 1 deletion)
@@ -444,7 +444,7 @@ indent-after-paren=4
indent-string=' '

# Maximum number of characters on a single line.
max-line-length=100
max-line-length=120

# Maximum number of lines in a module.
max-module-lines=1000
CHANGELOG.md (1 addition, 0 deletions)
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add zero-order gradient estimation by [@JieRen98](https://github.com/JieRen98) in [#93](https://github.com/metaopt/torchopt/pull/93).
- Add RPC-based distributed training support and add distributed MAML example by [@XuehaiPan](https://github.com/XuehaiPan) in [#83](https://github.com/metaopt/torchopt/pull/83).
- Add full type hints by [@XuehaiPan](https://github.com/XuehaiPan) in [#92](https://github.com/metaopt/torchopt/pull/92).
- Add API documentation and tutorial for implicit gradients by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@JieRen98](https://github.com/JieRen98) and [@XuehaiPan](https://github.com/XuehaiPan) in [#73](https://github.com/metaopt/torchopt/pull/73).
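The changelog entry above is the headline change in this PR: zero-order (gradient-free) gradient estimation. The diff shown here only touches the changelog, so as a rough, generic illustration of the idea rather than the torchopt API, the sketch below estimates a gradient purely from function evaluations using antithetic Gaussian perturbations; the function name, `sigma`, and `num_samples` are illustrative assumptions.

```python
import torch


def zero_order_grad(f, theta, sigma=0.01, num_samples=16):
    """Estimate d f(theta) / d theta using only evaluations of f.

    Two-sided (antithetic) Gaussian smoothing estimator:
        grad ~= E_u[(f(theta + sigma * u) - f(theta - sigma * u)) / (2 * sigma) * u]
    Generic sketch of zero-order estimation, not the torchopt implementation.
    """
    grad = torch.zeros_like(theta)
    for _ in range(num_samples):
        u = torch.randn_like(theta)                     # random perturbation direction
        f_plus = f(theta + sigma * u)                   # evaluation on one side
        f_minus = f(theta - sigma * u)                  # evaluation on the other side
        grad += (f_plus - f_minus) / (2.0 * sigma) * u  # directional finite difference
    return grad / num_samples


# Usage: for a smooth objective the estimate approximates the analytic gradient.
theta = torch.randn(4)
estimate = zero_order_grad(lambda x: (x ** 2).sum(), theta)
print(estimate)  # ~= 2 * theta, up to Monte-Carlo noise
```

This kind of estimator lets an outer loop "differentiate" through objectives where backpropagation is unavailable or too expensive, which is presumably the use case the new `zero_order_diff` module targets.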
pyproject.toml (1 addition, 0 deletions)
@@ -174,6 +174,7 @@ target-version = ["py37", "py38", "py39", "py310"]
atomic = true
profile = "black"
src_paths = ["torchopt", "examples", "tests"]
extra_standard_library = ["typing_extensions"]
indent = 4
line_length = 100
lines_after_imports = 2
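The new `extra_standard_library` entry tells isort to sort `typing_extensions` together with the standard library, which keeps the import block in `torchopt/base.py` (changed further down) in one stable group. A small illustration of the grouping this configuration produces, using the imports visible in this diff:

```python
# isort with extra_standard_library = ["typing_extensions"] places the
# typing_extensions import in the standard-library section:
import itertools
from abc import abstractmethod
from typing import TYPE_CHECKING, Callable, NamedTuple, Optional, Tuple
from typing_extensions import Protocol  # grouped with the stdlib, not third-party

import torch  # third-party imports keep their own section below
```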
torchopt/__init__.py (1 addition, 0 deletions)
@@ -26,6 +26,7 @@
schedule,
typing,
visual,
zero_order_diff,
)
from torchopt.accelerated_op import is_available as accelerated_op_available
from torchopt.alias import adam, adamw, rmsprop, sgd
torchopt/accelerated_op/adam_op.py (0 additions, 3 deletions)
@@ -31,7 +31,6 @@ class MuOp(torch.autograd.Function): # pylint: disable=abstract-method

@staticmethod
def jvp(ctx: Any, *grad_inputs: Any) -> Any:
# pylint: disable-next=line-too-long
"""Defines a formula for differentiating the operation with forward mode automatic differentiation."""

@staticmethod
@@ -58,7 +57,6 @@ class NuOp(torch.autograd.Function): # pylint: disable=abstract-method

@staticmethod
def jvp(ctx: Any, *grad_inputs: Any) -> Any:
# pylint: disable-next=line-too-long
"""Defines a formula for differentiating the operation with forward mode automatic differentiation."""

@staticmethod
@@ -85,7 +83,6 @@ class UpdatesOp(torch.autograd.Function): # pylint: disable=abstract-method

@staticmethod
def jvp(ctx: Any, *grad_inputs: Any) -> Any:
# pylint: disable-next=line-too-long
"""Defines a formula for differentiating the operation with forward mode automatic differentiation."""

@staticmethod
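The deleted pragmas sit above `jvp` docstrings of custom `torch.autograd.Function` subclasses (presumably no longer needed because the line-length limit was raised to 120 in this PR). For context, here is a minimal, generic sketch of such a Function with hand-written reverse-mode (`backward`) and forward-mode (`jvp`) rules; it is not the actual Adam kernel behind `MuOp`/`NuOp`/`UpdatesOp`.

```python
import torch


class Scale(torch.autograd.Function):  # pylint: disable=abstract-method
    """y = alpha * x, with explicit reverse-mode and forward-mode rules."""

    ALPHA = 3.0

    @staticmethod
    def forward(ctx, x):
        return Scale.ALPHA * x

    @staticmethod
    def backward(ctx, grad_output):
        """Reverse mode (vjp): dL/dx = alpha * dL/dy."""
        return Scale.ALPHA * grad_output

    @staticmethod
    def jvp(ctx, x_tangent):
        """Forward mode: the output tangent is alpha times the input tangent."""
        return Scale.ALPHA * x_tangent


# Forward-mode usage: push a tangent through the custom op with dual tensors.
x = torch.randn(3)
with torch.autograd.forward_ad.dual_level():
    dual_x = torch.autograd.forward_ad.make_dual(x, torch.ones_like(x))
    y = Scale.apply(dual_x)
    print(torch.autograd.forward_ad.unpack_dual(y).tangent)  # == 3.0 * ones(3)
```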
torchopt/alias.py (1 addition, 1 deletion)
@@ -130,7 +130,7 @@ def _scale_by_neg_lr(lr: ScalarOrSchedule):
if callable(lr):

def schedule_wrapper(count):
return -lr(count)
return -lr(count) # type: ignore[operator]

# pylint: disable-next=protected-access
return transform._scale_by_schedule(schedule_wrapper, already_flattened=True)
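The `type: ignore[operator]` is needed because `lr` is a `ScalarOrSchedule`: either a number or a callable mapping the step count to a learning rate. When it is callable, negation has to be deferred by wrapping the schedule instead of negating a value eagerly. A standalone sketch of that pattern, with illustrative names rather than the torchopt internals:

```python
from typing import Callable, Union

ScalarOrSchedule = Union[float, Callable[[int], float]]


def negate_lr(lr: ScalarOrSchedule) -> ScalarOrSchedule:
    """Return the negated learning rate, preserving schedule-ness.

    Gradient descent scales updates by -lr, so a schedule must be wrapped
    and evaluated lazily rather than negated once up front.
    """
    if callable(lr):

        def schedule_wrapper(count: int) -> float:
            return -lr(count)  # evaluate the schedule at this step, then flip the sign

        return schedule_wrapper
    return -lr


# Usage: a linear warm-up schedule, negated for descent-style updates.
def warmup(step: int) -> float:
    return 1e-3 * min(1.0, step / 100)


negated = negate_lr(warmup)
print(negated(50))  # -> -0.0005
```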
torchopt/base.py (2 additions, 7 deletions)
@@ -34,12 +34,7 @@
import itertools
from abc import abstractmethod
from typing import TYPE_CHECKING, Callable, NamedTuple, Optional, Tuple


try:
from typing import Protocol # pylint: disable=unused-import
except ImportError:
from typing_extensions import Protocol # type: ignore[assignment]
from typing_extensions import Protocol # Python 3.8+


if TYPE_CHECKING:
@@ -229,7 +224,7 @@ def __new__(cls):
@staticmethod
def init_fn(params: 'Params') -> 'OptState': # pylint: disable=unused-argument
"""Returns empty state."""
return EmptyState()
return EmptyState() # type: ignore[return-value]

@staticmethod
def update_fn(
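`Protocol` is now imported unconditionally from `typing_extensions`, which provides it on Python 3.7 and re-exports `typing.Protocol` on 3.8+, so the try/except fallback and its pragmas can go. As a reminder of what `Protocol` is used for, a small structural-typing example; it is purely illustrative, not the actual torchopt interfaces:

```python
from typing_extensions import Protocol  # typing.Protocol on 3.8+, backported below that


class UpdateFn(Protocol):
    """Structural type: anything callable as update_fn(updates, state) -> (updates, state)."""

    def __call__(self, updates: dict, state: dict) -> tuple:
        ...


def apply(fn: UpdateFn, updates: dict, state: dict) -> tuple:
    # A static checker accepts any callable matching the protocol's signature,
    # without requiring it to inherit from UpdateFn.
    return fn(updates, state)


def scale_by_neg_lr(updates: dict, state: dict) -> tuple:
    return {name: -0.01 * g for name, g in updates.items()}, state


print(apply(scale_by_neg_lr, {'w': 1.0}, {}))  # ({'w': -0.01}, {})
```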
torchopt/distributed/api.py (11 additions, 16 deletions)
@@ -35,7 +35,7 @@

import torchopt.pytree as pytree
from torchopt.distributed.world import get_worker_id, get_world_rank, get_world_size
from torchopt.typing import Future, PyTree
from torchopt.typing import Future


__all__ = [
@@ -116,11 +116,12 @@ def __call__(
workers = list(map(get_worker_id, self.workers))
num_workers = len(workers)

args_tree = cast(PyTree[Any], (args, kwargs))
flattened_args, treedef = pytree.tree_flatten(args_tree)
args_tree = (args, kwargs)
flat_args: List[Any]
flat_args, treedef = pytree.tree_flatten(args_tree) # type: ignore[arg-type]

batch_size = None
for arg in flattened_args:
for arg in flat_args:
if isinstance(arg, torch.Tensor):
if batch_size is None:
batch_size = arg.shape[self.dim]
@@ -134,7 +135,6 @@ def __call__(
return [(get_world_rank(), args, kwargs.copy())]

dim_slices: List[Union[int, slice]]
# pylint: disable-next=line-too-long
batch_slices: List[Tuple[Union[int, slice, Ellipsis.__class__], ...]] # type: ignore[name-defined]
if self.exclusive:
num_replicas = batch_size
@@ -169,18 +169,18 @@
for dim_slice in dim_slices
]

flattened_args_replicas: List[List[Any]] = [[] for _ in range(num_replicas)]
for arg in flattened_args:
flat_args_replicas: List[List[Any]] = [[] for _ in range(num_replicas)]
for arg in flat_args:
if isinstance(arg, torch.Tensor):
for i, batch_slice in enumerate(batch_slices):
flattened_args_replicas[i].append(arg[batch_slice])
flat_args_replicas[i].append(arg[batch_slice])
else:
for i in range(num_replicas):
flattened_args_replicas[i].append(arg)
flat_args_replicas[i].append(arg)

args_replicas: List[Tuple[Args, KwArgs]] = [
pytree.tree_unflatten(treedef, args_replica) # type: ignore[misc]
for args_replica in flattened_args_replicas
for args_replica in flat_args_replicas
]

return [
@@ -237,8 +237,6 @@ def dim_partitioner(
return TensorDimensionPartitioner(dim, exclusive=exclusive, keepdim=keepdim, workers=workers)


# fmt: off
# pylint: disable=line-too-long
batch_partitioner: PartitionFunction = dim_partitioner(dim=0, keepdim=True, exclusive=False)
"""Partitioner for batch dimension. Divide and replicates the arguments to all workers along the first dimension.

@@ -249,16 +247,14 @@ def dim_partitioner(
All tensors in the ``args`` and ``kwargs`` will be partitioned along the dimension ``dim``,
while the non-tensor values will be broadcasted to partitions.
"""
exclusive_batch_partitioner: PartitionFunction = dim_partitioner(dim=0, keepdim=True, exclusive=True)
exclusive_batch_partitioner: PartitionFunction = dim_partitioner(dim=0, keepdim=True, exclusive=True) # fmt: skip
"""Partitioner for batch dimension. Divide and replicates the arguments to all workers along the first dimension.

Each batch sample will be assigned to a separate RPC call.

All tensors in the ``args`` and ``kwargs`` will be partitioned along the dimension ``dim``,
while the non-tensor values will be broadcasted to partitions.
"""
# pylint: enable=line-too-long
# fmt: on


def mean_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor:
@@ -280,7 +276,6 @@ def remote_async_call(
reducer: Optional[Callable[[Iterable[T]], U]] = None,
timeout: Optional[float] = UNSET_RPC_TIMEOUT,
) -> Union[Future[List[T]], Future[U]]:
# pylint: disable=line-too-long
"""Asynchronously do an RPC on remote workers and return the a :class:`torch.Future` instance at the current worker.

Args:
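The partitioner changes above flatten `(args, kwargs)` into a pytree, slice every tensor leaf along `self.dim` into one chunk per worker (or one chunk per sample when `exclusive=True`), replicate the non-tensor leaves, and unflatten one argument set per replica for the RPC calls. A simplified sketch of just that slicing logic, without the pytree and RPC machinery (names are illustrative):

```python
from typing import Any, List

import torch


def partition_batch(args: List[Any], num_workers: int, dim: int = 0) -> List[List[Any]]:
    """Split every tensor argument into num_workers contiguous chunks along dim;
    non-tensor arguments are replicated to every partition (keepdim=True, exclusive=False)."""
    batch_size = next(a for a in args if isinstance(a, torch.Tensor)).shape[dim]
    # Distribute samples as evenly as possible: the first `remainder` workers get one extra.
    base, remainder = divmod(batch_size, num_workers)
    sizes = [base + (1 if i < remainder else 0) for i in range(num_workers)]

    replicas: List[List[Any]] = []
    offset = 0
    for size in sizes:
        replica = [
            arg.narrow(dim, offset, size) if isinstance(arg, torch.Tensor) else arg
            for arg in args
        ]
        replicas.append(replica)
        offset += size
    return replicas


# Usage: a batch of 10 samples split across 4 workers -> chunk sizes [3, 3, 2, 2];
# the non-tensor argument is broadcast to every replica.
x = torch.randn(10, 5)
parts = partition_batch([x, 'sum-reduction'], num_workers=4)
print([replica[0].shape[0] for replica in parts])  # [3, 3, 2, 2]
```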