
Commit 673b949

chore: cleanup
1 parent 54e2dd2 commit 673b949

15 files changed: +387 -358 lines changed

.pylintrc

Lines changed: 1 addition & 1 deletion

@@ -444,7 +444,7 @@ indent-after-paren=4
 indent-string='    '

 # Maximum number of characters on a single line.
-max-line-length=100
+max-line-length=120

 # Maximum number of lines in a module.
 max-module-lines=1000

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -174,6 +174,7 @@ target-version = ["py37", "py38", "py39", "py310"]
 atomic = true
 profile = "black"
 src_paths = ["torchopt", "examples", "tests"]
+extra_standard_library = ["typing_extensions"]
 indent = 4
 line_length = 100
 lines_after_imports = 2
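The new extra_standard_library entry tells isort to sort typing_extensions into the standard-library import group, next to typing, instead of among third-party packages. This matches the direct import introduced in torchopt/base.py below. A minimal sketch of the resulting grouping (the surrounding imports are illustrative, not a file from this commit):

# With extra_standard_library = ["typing_extensions"], isort places the import
# in the standard-library section rather than the third-party section:

import itertools
from typing import TYPE_CHECKING, Callable, NamedTuple, Optional, Tuple
from typing_extensions import Protocol

import torch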

torchopt/accelerated_op/adam_op.py

Lines changed: 0 additions & 3 deletions

@@ -31,7 +31,6 @@ class MuOp(torch.autograd.Function):  # pylint: disable=abstract-method

     @staticmethod
     def jvp(ctx: Any, *grad_inputs: Any) -> Any:
-        # pylint: disable-next=line-too-long
         """Defines a formula for differentiating the operation with forward mode automatic differentiation."""

     @staticmethod
@@ -58,7 +57,6 @@ class NuOp(torch.autograd.Function):  # pylint: disable=abstract-method

     @staticmethod
     def jvp(ctx: Any, *grad_inputs: Any) -> Any:
-        # pylint: disable-next=line-too-long
         """Defines a formula for differentiating the operation with forward mode automatic differentiation."""

     @staticmethod
@@ -85,7 +83,6 @@ class UpdatesOp(torch.autograd.Function):  # pylint: disable=abstract-method

     @staticmethod
     def jvp(ctx: Any, *grad_inputs: Any) -> Any:
-        # pylint: disable-next=line-too-long
         """Defines a formula for differentiating the operation with forward mode automatic differentiation."""

     @staticmethod
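For readers unfamiliar with the jvp hook these docstrings refer to: torch.autograd.Function.jvp supplies the forward-mode differentiation rule for a custom operator, receiving the tangents of the inputs and returning the tangents of the outputs. A toy sketch of the pattern (illustration only, not TorchOpt's accelerated op):

import torch
from torch.autograd.forward_ad import dual_level, make_dual, unpack_dual


class Square(torch.autograd.Function):
    """Toy operator y = x * x with both reverse- and forward-mode rules."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # for reverse mode (``backward``)
        ctx.x = x                 # plain attribute so ``jvp`` can reuse it
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2.0 * x * grad_output

    @staticmethod
    def jvp(ctx, x_tangent):
        # Forward-mode rule: d(x * x) = 2 * x * dx
        return 2.0 * ctx.x * x_tangent


# Usage under forward-mode AD:
with dual_level():
    x = make_dual(torch.tensor(3.0), torch.tensor(1.0))
    y = Square.apply(x)
    print(unpack_dual(y).tangent)  # tensor(6.)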

torchopt/alias.py

Lines changed: 1 addition & 1 deletion

@@ -130,7 +130,7 @@ def _scale_by_neg_lr(lr: ScalarOrSchedule):
     if callable(lr):

         def schedule_wrapper(count):
-            return -lr(count)
+            return -lr(count)  # type: ignore[operator]

         # pylint: disable-next=protected-access
         return transform._scale_by_schedule(schedule_wrapper, already_flattened=True)
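Background on the added type: ignore[operator]: lr is a ScalarOrSchedule (a scalar-or-callable union), and mypy does not carry the callable(lr) narrowing into the body of the nested function, so -lr(count) is reported as possibly calling a non-callable. A minimal reproduction with simplified names (not the actual TorchOpt signatures):

from typing import Callable, Union

ScalarOrSchedule = Union[float, Callable[[int], float]]  # simplified stand-in


def scale_by_neg_lr(lr: ScalarOrSchedule):
    if callable(lr):
        # mypy does not propagate the callable() narrowing into the nested
        # function, so lr(count) is flagged as '"float" not callable' --
        # hence the targeted ignore on that line.
        def schedule_wrapper(count: int) -> float:
            return -lr(count)  # type: ignore[operator]

        return schedule_wrapper
    return -lr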

torchopt/base.py

Lines changed: 2 additions & 7 deletions

@@ -34,12 +34,7 @@
 import itertools
 from abc import abstractmethod
 from typing import TYPE_CHECKING, Callable, NamedTuple, Optional, Tuple
-
-
-try:
-    from typing import Protocol  # pylint: disable=unused-import
-except ImportError:
-    from typing_extensions import Protocol  # type: ignore[assignment]
+from typing_extensions import Protocol  # Python 3.8+

 if TYPE_CHECKING:
@@ -229,7 +224,7 @@ def __new__(cls):
     @staticmethod
     def init_fn(params: 'Params') -> 'OptState':  # pylint: disable=unused-argument
         """Returns empty state."""
-        return EmptyState()
+        return EmptyState()  # type: ignore[return-value]

     @staticmethod
     def update_fn(
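The removed try/except fallback was only needed while typing.Protocol (added in Python 3.8) might be unavailable; importing it unconditionally from typing_extensions works on every supported version because typing_extensions re-exports the class. A small sketch of how such a Protocol enables structural typing (names are illustrative, loosely modeled on the init_fn shown in this diff, not copied from torchopt/base.py):

from typing import Any

from typing_extensions import Protocol  # also re-exported by ``typing`` on Python 3.8+

# Illustrative aliases; TorchOpt defines richer types in torchopt.typing.
Params = Any
OptState = Any


class TransformInitFn(Protocol):
    """Structural type: any callable taking params and returning a state matches."""

    def __call__(self, params: Params) -> OptState:
        ...


def init_fn(params: Params) -> OptState:
    return ()


fn: TransformInitFn = init_fn  # accepted by a type checker via structural typing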

torchopt/distributed/api.py

Lines changed: 11 additions & 16 deletions

@@ -35,7 +35,7 @@

 import torchopt.pytree as pytree
 from torchopt.distributed.world import get_worker_id, get_world_rank, get_world_size
-from torchopt.typing import Future, PyTree
+from torchopt.typing import Future

 __all__ = [
@@ -116,11 +116,12 @@ def __call__(
         workers = list(map(get_worker_id, self.workers))
         num_workers = len(workers)

-        args_tree = cast(PyTree[Any], (args, kwargs))
-        flattened_args, treedef = pytree.tree_flatten(args_tree)
+        args_tree = (args, kwargs)
+        flat_args: List[Any]
+        flat_args, treedef = pytree.tree_flatten(args_tree)  # type: ignore[arg-type]

         batch_size = None
-        for arg in flattened_args:
+        for arg in flat_args:
             if isinstance(arg, torch.Tensor):
                 if batch_size is None:
                     batch_size = arg.shape[self.dim]
@@ -134,7 +135,6 @@ def __call__(
             return [(get_world_rank(), args, kwargs.copy())]

         dim_slices: List[Union[int, slice]]
-        # pylint: disable-next=line-too-long
         batch_slices: List[Tuple[Union[int, slice, Ellipsis.__class__], ...]]  # type: ignore[name-defined]
         if self.exclusive:
             num_replicas = batch_size
@@ -169,18 +169,18 @@ def __call__(
             for dim_slice in dim_slices
         ]

-        flattened_args_replicas: List[List[Any]] = [[] for _ in range(num_replicas)]
-        for arg in flattened_args:
+        flat_args_replicas: List[List[Any]] = [[] for _ in range(num_replicas)]
+        for arg in flat_args:
             if isinstance(arg, torch.Tensor):
                 for i, batch_slice in enumerate(batch_slices):
-                    flattened_args_replicas[i].append(arg[batch_slice])
+                    flat_args_replicas[i].append(arg[batch_slice])
             else:
                 for i in range(num_replicas):
-                    flattened_args_replicas[i].append(arg)
+                    flat_args_replicas[i].append(arg)

         args_replicas: List[Tuple[Args, KwArgs]] = [
             pytree.tree_unflatten(treedef, args_replica)  # type: ignore[misc]
-            for args_replica in flattened_args_replicas
+            for args_replica in flat_args_replicas
         ]

         return [
@@ -237,8 +237,6 @@ def dim_partitioner(
     return TensorDimensionPartitioner(dim, exclusive=exclusive, keepdim=keepdim, workers=workers)

-# fmt: off
-# pylint: disable=line-too-long
 batch_partitioner: PartitionFunction = dim_partitioner(dim=0, keepdim=True, exclusive=False)
 """Partitioner for batch dimension. Divide and replicates the arguments to all workers along the first dimension.

@@ -249,16 +247,14 @@ def dim_partitioner(
 All tensors in the ``args`` and ``kwargs`` will be partitioned along the dimension ``dim``,
 while the non-tensor values will be broadcasted to partitions.
 """
-exclusive_batch_partitioner: PartitionFunction = dim_partitioner(dim=0, keepdim=True, exclusive=True)
+exclusive_batch_partitioner: PartitionFunction = dim_partitioner(dim=0, keepdim=True, exclusive=True)  # fmt: skip
 """Partitioner for batch dimension. Divide and replicates the arguments to all workers along the first dimension.

 Each batch sample will be assigned to a separate RPC call.

 All tensors in the ``args`` and ``kwargs`` will be partitioned along the dimension ``dim``,
 while the non-tensor values will be broadcasted to partitions.
 """
-# pylint: enable=line-too-long
-# fmt: on

 def mean_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor:
@@ -280,7 +276,6 @@ def remote_async_call(
     reducer: Optional[Callable[[Iterable[T]], U]] = None,
     timeout: Optional[float] = UNSET_RPC_TIMEOUT,
 ) -> Union[Future[List[T]], Future[U]]:
-    # pylint: disable=line-too-long
     """Asynchronously do an RPC on remote workers and return the a :class:`torch.Future` instance at the current worker.

     Args:
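To make the batch_partitioner / exclusive_batch_partitioner docstrings above more concrete: the core idea is to slice every tensor argument along one dimension into per-worker chunks while replicating non-tensor arguments to every chunk. A standalone sketch of that slicing logic in plain PyTorch (the RPC dispatch, worker ranks, and kwargs handling of the real TensorDimensionPartitioner are omitted; split_batch is a hypothetical helper for illustration):

from typing import Any, List, Sequence, Tuple

import torch


def split_batch(args: Sequence[Any], num_workers: int, dim: int = 0) -> List[Tuple[Any, ...]]:
    """Slice tensor args along ``dim`` into one chunk per worker; replicate the rest."""
    chunked = [
        torch.tensor_split(arg, num_workers, dim=dim) if isinstance(arg, torch.Tensor) else arg
        for arg in args
    ]
    replicas = []
    for i in range(num_workers):
        replicas.append(
            tuple(c[i] if isinstance(arg, torch.Tensor) else arg for arg, c in zip(args, chunked))
        )
    return replicas


# Example: a batch of 8 samples split across 4 workers, a scalar flag replicated.
batch = torch.randn(8, 3)
for worker_args in split_batch((batch, 0.1), num_workers=4):
    print(worker_args[0].shape, worker_args[1])  # torch.Size([2, 3]) 0.1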
