import torchopt.pytree as pytree
from torchopt.distributed.world import get_worker_id, get_world_rank, get_world_size
-from torchopt.typing import Future, PyTree
+from torchopt.typing import Future


__all__ = [
@@ -116,11 +116,12 @@ def __call__(
        workers = list(map(get_worker_id, self.workers))
        num_workers = len(workers)

-        args_tree = cast(PyTree[Any], (args, kwargs))
-        flattened_args, treedef = pytree.tree_flatten(args_tree)
+        args_tree = (args, kwargs)
+        flat_args: List[Any]
+        flat_args, treedef = pytree.tree_flatten(args_tree)  # type: ignore[arg-type]

        batch_size = None
-        for arg in flattened_args:
+        for arg in flat_args:
            if isinstance(arg, torch.Tensor):
                if batch_size is None:
                    batch_size = arg.shape[self.dim]
@@ -134,7 +135,6 @@ def __call__(
            return [(get_world_rank(), args, kwargs.copy())]

        dim_slices: List[Union[int, slice]]
-        # pylint: disable-next=line-too-long
        batch_slices: List[Tuple[Union[int, slice, Ellipsis.__class__], ...]]  # type: ignore[name-defined]
        if self.exclusive:
            num_replicas = batch_size
@@ -169,18 +169,18 @@ def __call__(
            for dim_slice in dim_slices
        ]

-        flattened_args_replicas: List[List[Any]] = [[] for _ in range(num_replicas)]
-        for arg in flattened_args:
+        flat_args_replicas: List[List[Any]] = [[] for _ in range(num_replicas)]
+        for arg in flat_args:
            if isinstance(arg, torch.Tensor):
                for i, batch_slice in enumerate(batch_slices):
-                    flattened_args_replicas[i].append(arg[batch_slice])
+                    flat_args_replicas[i].append(arg[batch_slice])
            else:
                for i in range(num_replicas):
-                    flattened_args_replicas[i].append(arg)
+                    flat_args_replicas[i].append(arg)

        args_replicas: List[Tuple[Args, KwArgs]] = [
            pytree.tree_unflatten(treedef, args_replica)  # type: ignore[misc]
-            for args_replica in flattened_args_replicas
+            for args_replica in flat_args_replicas
        ]

        return [
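For reference, the flatten/slice/unflatten pattern that the hunks above implement can be sketched in isolation. This is an illustrative simplification rather than the library code: the helper name split_args is made up, and torch.tensor_split stands in for the explicit dim_slices/batch_slices bookkeeping in __call__.

    from typing import Any, Dict, List, Tuple

    import torch
    import torchopt.pytree as pytree


    def split_args(args: Tuple[Any, ...], kwargs: Dict[str, Any], num_replicas: int, dim: int = 0):
        """Split every tensor leaf of (args, kwargs) along `dim`; broadcast the rest."""
        # Flatten the whole (args, kwargs) structure into a list of leaves plus a treedef.
        flat_args, treedef = pytree.tree_flatten((args, kwargs))

        flat_replicas: List[List[Any]] = [[] for _ in range(num_replicas)]
        for arg in flat_args:
            if isinstance(arg, torch.Tensor):
                # Tensor leaves are partitioned along `dim`, one chunk per replica.
                for i, chunk in enumerate(torch.tensor_split(arg, num_replicas, dim=dim)):
                    flat_replicas[i].append(chunk)
            else:
                # Non-tensor leaves are replicated (broadcast) to every replica.
                for i in range(num_replicas):
                    flat_replicas[i].append(arg)

        # Rebuild one (args, kwargs) pair per replica from the per-replica leaves.
        return [pytree.tree_unflatten(treedef, flat) for flat in flat_replicas]

The real __call__ additionally handles the keepdim and exclusive options and pairs each replica with a worker rank, as shown in the hunks above.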
@@ -237,8 +237,6 @@ def dim_partitioner(
    return TensorDimensionPartitioner(dim, exclusive=exclusive, keepdim=keepdim, workers=workers)


-# fmt: off
-# pylint: disable=line-too-long
batch_partitioner: PartitionFunction = dim_partitioner(dim=0, keepdim=True, exclusive=False)
"""Partitioner for batch dimension. Divides and replicates the arguments to all workers along the first dimension.

@@ -249,16 +247,14 @@ def dim_partitioner(
All tensors in the ``args`` and ``kwargs`` will be partitioned along the dimension ``dim``,
while the non-tensor values will be broadcasted to partitions.
"""
-exclusive_batch_partitioner: PartitionFunction = dim_partitioner(dim=0, keepdim=True, exclusive=True)
+exclusive_batch_partitioner: PartitionFunction = dim_partitioner(dim=0, keepdim=True, exclusive=True)  # fmt: skip
"""Partitioner for batch dimension. Divides and replicates the arguments to all workers along the first dimension.

Each batch sample will be assigned to a separate RPC call.

All tensors in the ``args`` and ``kwargs`` will be partitioned along the dimension ``dim``,
while the non-tensor values will be broadcasted to partitions.
"""
-# pylint: enable=line-too-long
-# fmt: on


def mean_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor:
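To make the difference between the two partitioners concrete, here is a plain-torch illustration (not the torchopt API) of what the docstrings above describe, for a batch of six samples split across three workers:

    import torch

    batch = torch.arange(6).reshape(6, 1)

    # batch_partitioner-style: one chunk of the batch per worker along dim 0.
    per_worker = torch.tensor_split(batch, 3, dim=0)   # 3 chunks of shape (2, 1)

    # exclusive_batch_partitioner-style: one RPC call per sample (keepdim=True).
    per_sample = [batch[i : i + 1] for i in range(6)]  # 6 slices of shape (1, 1)

Non-tensor arguments would be passed unchanged to every partition in both cases.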
@@ -280,7 +276,6 @@ def remote_async_call(
    reducer: Optional[Callable[[Iterable[T]], U]] = None,
    timeout: Optional[float] = UNSET_RPC_TIMEOUT,
) -> Union[Future[List[T]], Future[U]]:
-    # pylint: disable=line-too-long
    """Asynchronously do an RPC on remote workers and return a :class:`torch.Future` instance at the current worker.

    Args:
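A hedged sketch of how remote_async_call is typically combined with the partitioners and reducers above. Only reducer and timeout are visible in this hunk; the remaining keyword names and the torchopt.distributed import path are assumptions, so the call itself is shown commented out:

    import torch
    import torchopt.distributed as todist  # assumed import path


    def local_loss(batch: torch.Tensor) -> torch.Tensor:
        # Runs on each RPC worker with its slice of the batch.
        return batch.square().mean()


    # Inside a process where torch.distributed.rpc has been initialized:
    # future = todist.remote_async_call(
    #     local_loss,
    #     args=(torch.randn(64, 3),),            # split along dim 0 by the partitioner
    #     partitioner=todist.batch_partitioner,  # assumed keyword name
    #     reducer=todist.mean_reducer,           # average the per-worker results
    # )
    # loss = future.wait()  # the returned torch.Future resolves to the reduced value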