Skip to content

style: use postponed evaluation of annotations and update doctring style #135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Use postponed evaluation of annotations and update doctring style by [@XuehaiPan](https://github.com/XuehaiPan) in [#135](https://github.com/metaopt/torchopt/pull/135).
- Rewrite setup CUDA Toolkit logic by [@XuehaiPan](https://github.com/XuehaiPan) in [#133](https://github.com/metaopt/torchopt/pull/133).

### Fixed
Expand Down
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
<a href="https://codecov.io/gh/metaopt/torchopt">![CodeCov](https://img.shields.io/codecov/c/gh/metaopt/torchopt)</a>
<a href="https://torchopt.readthedocs.io">![Documentation Status](https://img.shields.io/readthedocs/torchopt?logo=readthedocs)</a>
<a href="https://pepy.tech/project/torchopt">![Downloads](https://static.pepy.tech/personalized-badge/torchopt?period=total&left_color=grey&right_color=blue&left_text=downloads)</a>
<a href="https://github.com/metaopt/torchopt/stargazers">![GitHub Repo Stars](https://img.shields.io/github/stars/metaopt/torchopt?color=brightgreen&logo=github)</a>
<a href="https://github.com/metaopt/torchopt/blob/HEAD/LICENSE">![License](https://img.shields.io/github/license/metaopt/torchopt?label=license&logo=data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgMCAyNCAyNCIgd2lkdGg9IjI0IiBoZWlnaHQ9IjI0IiBmaWxsPSIjZmZmZmZmIj48cGF0aCBmaWxsLXJ1bGU9ImV2ZW5vZGQiIGQ9Ik0xMi43NSAyLjc1YS43NS43NSAwIDAwLTEuNSAwVjQuNUg5LjI3NmExLjc1IDEuNzUgMCAwMC0uOTg1LjMwM0w2LjU5NiA1Ljk1N0EuMjUuMjUgMCAwMTYuNDU1IDZIMi4zNTNhLjc1Ljc1IDAgMTAwIDEuNUgzLjkzTC41NjMgMTUuMThhLjc2Mi43NjIgMCAwMC4yMS44OGMuMDguMDY0LjE2MS4xMjUuMzA5LjIyMS4xODYuMTIxLjQ1Mi4yNzguNzkyLjQzMy42OC4zMTEgMS42NjIuNjIgMi44NzYuNjJhNi45MTkgNi45MTkgMCAwMDIuODc2LS42MmMuMzQtLjE1NS42MDYtLjMxMi43OTItLjQzMy4xNS0uMDk3LjIzLS4xNTguMzEtLjIyM2EuNzUuNzUgMCAwMC4yMDktLjg3OEw1LjU2OSA3LjVoLjg4NmMuMzUxIDAgLjY5NC0uMTA2Ljk4NC0uMzAzbDEuNjk2LTEuMTU0QS4yNS4yNSAwIDAxOS4yNzUgNmgxLjk3NXYxNC41SDYuNzYzYS43NS43NSAwIDAwMCAxLjVoMTAuNDc0YS43NS43NSAwIDAwMC0xLjVIMTIuNzVWNmgxLjk3NGMuMDUgMCAuMS4wMTUuMTQuMDQzbDEuNjk3IDEuMTU0Yy4yOS4xOTcuNjMzLjMwMy45ODQuMzAzaC44ODZsLTMuMzY4IDcuNjhhLjc1Ljc1IDAgMDAuMjMuODk2Yy4wMTIuMDA5IDAgMCAuMDAyIDBhMy4xNTQgMy4xNTQgMCAwMC4zMS4yMDZjLjE4NS4xMTIuNDUuMjU2Ljc5LjRhNy4zNDMgNy4zNDMgMCAwMDIuODU1LjU2OCA3LjM0MyA3LjM0MyAwIDAwMi44NTYtLjU2OWMuMzM4LS4xNDMuNjA0LS4yODcuNzktLjM5OWEzLjUgMy41IDAgMDAuMzEtLjIwNi43NS43NSAwIDAwLjIzLS44OTZMMjAuMDcgNy41aDEuNTc4YS43NS43NSAwIDAwMC0xLjVoLTQuMTAyYS4yNS4yNSAwIDAxLS4xNC0uMDQzbC0xLjY5Ny0xLjE1NGExLjc1IDEuNzUgMCAwMC0uOTg0LS4zMDNIMTIuNzVWMi43NXpNMi4xOTMgMTUuMTk4YTUuNDE4IDUuNDE4IDAgMDAyLjU1Ny42MzUgNS40MTggNS40MTggMCAwMDIuNTU3LS42MzVMNC43NSA5LjM2OGwtMi41NTcgNS44M3ptMTQuNTEtLjAyNGMuMDgyLjA0LjE3NC4wODMuMjc1LjEyNi41My4yMjMgMS4zMDUuNDUgMi4yNzIuNDVhNS44NDYgNS44NDYgMCAwMDIuNTQ3LS41NzZMMTkuMjUgOS4zNjdsLTIuNTQ3IDUuODA3eiI+PC9wYXRoPjwvc3ZnPgo=)</a>
</div>

Expand Down
109 changes: 109 additions & 0 deletions docs/source/api/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,115 @@ Chain
.. autofunction:: chain


Distributed Utilities
=====================

.. currentmodule:: torchopt.distributed

Initialization and Synchronization
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autosummary::

auto_init_rpc
barrier

.. autofunction:: auto_init_rpc
.. autofunction:: barrier

Process group information
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autosummary::

get_world_info
get_world_rank
get_rank
get_world_size
get_local_rank
get_local_world_size
get_worker_id

.. autofunction:: get_world_info
.. autofunction:: get_world_rank
.. autofunction:: get_rank
.. autofunction:: get_world_size
.. autofunction:: get_local_rank
.. autofunction:: get_local_world_size
.. autofunction:: get_worker_id

Worker selection
~~~~~~~~~~~~~~~~

.. autosummary::

on_rank
not_on_rank
rank_zero_only
rank_non_zero_only

.. autofunction:: on_rank
.. autofunction:: not_on_rank
.. autofunction:: rank_zero_only
.. autofunction:: rank_non_zero_only

Remote Procedure Call (RPC)
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autosummary::

remote_async_call
remote_sync_call

.. autofunction:: remote_async_call
.. autofunction:: remote_sync_call

Predefined partitioners and reducers
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autosummary::

dim_partitioner
batch_partitioner
mean_reducer
sum_reducer

.. autofunction:: dim_partitioner
.. autofunction:: batch_partitioner
.. autofunction:: mean_reducer
.. autofunction:: sum_reducer

Function parallelization wrappers
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autosummary::

parallelize
parallelize_async
parallelize_sync

.. autofunction:: parallelize
.. autofunction:: parallelize_async
.. autofunction:: parallelize_sync

Distributed Autograd
~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torchopt.distributed.autograd

.. autosummary::

context
get_gradients
backward
grad

.. autofunction:: context
.. autofunction:: get_gradients
.. autofunction:: backward
.. autofunction:: grad


General Utilities
=================

Expand Down
7 changes: 0 additions & 7 deletions docs/source/distributed/distributed.rst
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ Initialization and Synchronization

.. autosummary::


torchopt.distributed.auto_init_rpc
torchopt.distributed.barrier

Expand Down Expand Up @@ -197,7 +196,6 @@ Process group information

.. autosummary::


torchopt.distributed.get_world_info
torchopt.distributed.get_world_rank
torchopt.distributed.get_rank
Expand Down Expand Up @@ -228,7 +226,6 @@ Worker selection

.. autosummary::


torchopt.distributed.on_rank
torchopt.distributed.not_on_rank
torchopt.distributed.rank_zero_only
Expand Down Expand Up @@ -275,7 +272,6 @@ Remote Procedure Call (RPC)

.. autosummary::


torchopt.distributed.remote_async_call
torchopt.distributed.remote_sync_call

Expand Down Expand Up @@ -354,7 +350,6 @@ Predefined partitioners and reducers

.. autosummary::


torchopt.distributed.dim_partitioner
torchopt.distributed.batch_partitioner
torchopt.distributed.mean_reducer
Expand Down Expand Up @@ -439,7 +434,6 @@ Function parallelization wrappers

.. autosummary::


torchopt.distributed.parallelize
torchopt.distributed.parallelize_async
torchopt.distributed.parallelize_sync
Expand Down Expand Up @@ -490,7 +484,6 @@ Distributed Autograd

.. autosummary::


torchopt.distributed.autograd.context
torchopt.distributed.autograd.get_gradients
torchopt.distributed.autograd.backward
Expand Down
1 change: 1 addition & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,4 @@ issubclass
abc
ABCMeta
subclasscheck
ctx
22 changes: 12 additions & 10 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# limitations under the License.
# ==============================================================================

from __future__ import annotations

import copy
import itertools
import os
import random
from typing import Iterable, Optional, Tuple, Union
from typing import Iterable

import numpy as np
import pytest
Expand Down Expand Up @@ -137,7 +139,7 @@ def get_model():
@torch.no_grad()
def get_models(
device: torch.types.Device = None, dtype: torch.dtype = torch.float32
) -> Tuple[nn.Module, nn.Module, nn.Module, data.DataLoader]:
) -> tuple[nn.Module, nn.Module, nn.Module, data.DataLoader]:
seed_everything(seed=42)

model_base = get_model().to(dtype=dtype)
Expand Down Expand Up @@ -166,12 +168,12 @@ def get_models(

@torch.no_grad()
def assert_model_all_close(
model: Union[nn.Module, Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]],
model: nn.Module | tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]],
model_ref: nn.Module,
model_base: nn.Module,
dtype: torch.dtype = torch.float32,
rtol: Optional[float] = None,
atol: Optional[float] = None,
rtol: float | None = None,
atol: float | None = None,
equal_nan: bool = False,
) -> None:
if isinstance(model, tuple):
Expand All @@ -194,8 +196,8 @@ def assert_all_close(
actual: torch.Tensor,
expected: torch.Tensor,
base: torch.Tensor = None,
rtol: Optional[float] = None,
atol: Optional[float] = None,
rtol: float | None = None,
atol: float | None = None,
equal_nan: bool = False,
) -> None:
if base is not None:
Expand Down Expand Up @@ -223,9 +225,9 @@ def assert_all_close(
def assert_pytree_all_close(
actual: TensorTree,
expected: TensorTree,
base: Optional[TensorTree] = None,
rtol: Optional[float] = None,
atol: Optional[float] = None,
base: TensorTree | None = None,
rtol: float | None = None,
atol: float | None = None,
equal_nan: bool = False,
) -> None:
actual_leaves, actual_treespec = pytree.tree_flatten(actual)
Expand Down
14 changes: 8 additions & 6 deletions tests/test_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
# limitations under the License.
# ==============================================================================

from typing import Callable, Tuple
from __future__ import annotations

from typing import Callable

import functorch
import pytest
Expand Down Expand Up @@ -107,7 +109,7 @@ def test_sgd(
def test_adam(
dtype: torch.dtype,
lr: float,
betas: Tuple[float, float],
betas: tuple[float, float],
eps: float,
inplace: bool,
weight_decay: float,
Expand Down Expand Up @@ -177,7 +179,7 @@ def test_maml_adam(
outer_lr: float,
inner_lr: float,
inner_update: int,
betas: Tuple[float, float],
betas: tuple[float, float],
eps: float,
inplace: bool,
weight_decay: float,
Expand Down Expand Up @@ -263,7 +265,7 @@ def maml_inner_solver_torchopt(params, data, use_accelerated_op):
def test_adamw(
dtype: torch.dtype,
lr: float,
betas: Tuple[float, float],
betas: tuple[float, float],
eps: float,
inplace: bool,
weight_decay: float,
Expand Down Expand Up @@ -333,8 +335,8 @@ def test_adamw(
def test_adam_accelerated_cuda(
dtype: torch.dtype,
lr: float,
optimizers: Tuple[Callable, torch.optim.Optimizer],
betas: Tuple[float, float],
optimizers: tuple[Callable, torch.optim.Optimizer],
betas: tuple[float, float],
eps: float,
inplace: bool,
weight_decay: float,
Expand Down
7 changes: 4 additions & 3 deletions tests/test_implicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
# limitations under the License.
# ==============================================================================

from __future__ import annotations

import copy
from collections import OrderedDict
from types import FunctionType
from typing import Tuple

import functorch
import jax
Expand Down Expand Up @@ -55,7 +56,7 @@ def forward(self, x):
return self.fc(x)


def get_model_jax(dtype: np.dtype = np.float32) -> Tuple[FunctionType, OrderedDict]:
def get_model_jax(dtype: np.dtype = np.float32) -> tuple[FunctionType, OrderedDict]:
helpers.seed_everything(seed=42)

def func(params, x):
Expand All @@ -73,7 +74,7 @@ def func(params, x):
@torch.no_grad()
def get_model_torch(
device: torch.types.Device = None, dtype: torch.dtype = torch.float32
) -> Tuple[nn.Module, data.DataLoader]:
) -> tuple[nn.Module, data.DataLoader]:
helpers.seed_everything(seed=42)

model = FcNet(MODEL_NUM_INPUTS, MODEL_NUM_CLASSES).to(dtype=dtype)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_meta_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================

from typing import Tuple
from __future__ import annotations

import torch
import torch.nn.functional as F
Expand All @@ -40,7 +40,7 @@ def test_maml_meta_adam(
outer_lr: float,
inner_lr: float,
inner_update: int,
betas: Tuple[float, float],
betas: tuple[float, float],
eps: float,
eps_root: float,
weight_decay: float,
Expand Down
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy