import torch
import torch.distributed.rpc as rpc

- from torchopt.distributed.world import get_world_size
- from torchopt.typing import Future
+ import torchopt.pytree as pytree
+ from torchopt.distributed.world import get_worker_id, get_world_rank, get_world_size
+ from torchopt.typing import Future, PyTree


__all__ = [
-     'default_partitioner',
+     'DimPartitioner',
+     'dim_partitioner',
+     'batch_partitioner',
+     'mean_reducer',
+     'sum_reducer',
    'remote_async_call',
    'remote_sync_call',
    'parallelize',
]


+ UNSET_RPC_TIMEOUT = rpc.api.UNSET_RPC_TIMEOUT
+
+
T = TypeVar('T')
U = TypeVar('U')
Args = Tuple[Any, ...]
KwArgs = Dict[str, Any]
- Partitioner = Union[int, str, Callable[..., Sequence[Tuple[int, Optional[Args], Optional[KwArgs]]]]]
+ PartitionFunction = Callable[..., Sequence[Tuple[int, Optional[Args], Optional[KwArgs]]]]
+ Partitioner = Union[int, str, PartitionFunction]


- def default_partitioner(
-     *args: Any,
-     **kwargs: Any,
- ) -> List[Tuple[int, Optional[Args], Optional[KwArgs]]]:
-     """Default partitioner.
+ class DimPartitioner:
+     """Partitioner class that partitions a batch of inputs along a given dimension.

-     Replicates the arguments to all workers.
+     Args:
+         dim: The dimension to partition.
+         exclusive: Whether to partition the batch exclusively.
+             If ``exclusive=True``, the batch will be partitioned into ``batch_size`` partitions,
+             where ``batch_size`` is the size of the batch along the given dimension.
+             If ``exclusive=False``, the batch will be partitioned into
+             ``min(batch_size, num_workers)`` partitions, where ``num_workers`` is the number of
+             workers in the world.
+         keepdim: Whether to keep the partitioned dimension. Defaults to :data:`True`, i.e., keep the
+             batch dimension. If :data:`False`, use select instead of slicing. This functionality
+             should be used with ``exclusive=True``.
+         workers: The workers to partition the batch to. If :data:`None`, the batch will be
+             partitioned to all workers in the world.
    """
-     return [(rank, args, kwargs.copy()) for rank in range(get_world_size())]
+
+     def __init__(
+         self,
+         dim: int,
+         *,
+         exclusive: bool = False,
+         keepdim: bool = False,
+         workers: Optional[Sequence[Union[int, str]]] = None,
+     ) -> None:
+         if not keepdim and not exclusive:
+             raise ValueError('keepdim=False should be used with exclusive=True.')
+
+         self.dim = dim
+         self.exclusive = exclusive
+         self.keepdim = keepdim
+         self.workers = workers
+
+     # pylint: disable-next=too-many-branches,too-many-locals
+     def __call__(
+         self,
+         *args: Any,
+         **kwargs: Any,
+     ) -> List[Tuple[int, Optional[Args], Optional[KwArgs]]]:
+         if self.workers is None:
+             workers = list(range(get_world_size()))
+         else:
+             workers = self.workers
+         workers: List[int] = list(map(get_worker_id, workers))
+         num_workers = len(workers)
+
+         args_tree: PyTree[Any] = (args, kwargs)
+         flattened_args, treedef = pytree.tree_flatten(args_tree)
+
+         batch_size = None
+         for arg in flattened_args:
+             if isinstance(arg, torch.Tensor):
+                 if batch_size is None:
+                     batch_size = arg.shape[self.dim]
+                 elif batch_size != arg.shape[self.dim]:
+                     raise ValueError(
+                         f'Batch size mismatch on dim={self.dim}. '
+                         f'Expected {batch_size}, got {arg.shape[self.dim]} (shape: {arg.shape}).'
+                     )
+
+         if batch_size is None:
+             return [(get_world_rank(), args, kwargs.copy())]
+
+         if self.exclusive:
+             num_replicas = batch_size
+             if self.keepdim:
+                 batch_slices = [slice(i, i + 1) for i in range(num_replicas)]
+             else:
+                 batch_slices = list(range(num_replicas))
+         else:
+             if batch_size <= num_workers:
+                 num_replicas = batch_size
+                 batch_slices = [slice(i, i + 1) for i in range(batch_size)]  # keepdim=True
+             else:
+                 num_replicas = num_workers
+                 local_size = batch_size // num_workers
+                 local_batch_indices = [i * local_size for i in range(num_workers)] + [batch_size]
+                 batch_slices = [
+                     slice(local_batch_indices[i], local_batch_indices[i + 1])
+                     for i in range(num_workers)
+                 ]
+
+         if self.dim >= 0:
+             batch_slices = [(slice(None),) * self.dim + (batch_slice,) for batch_slice in batch_slices]
+         elif self.dim < 0:
+             batch_slices = [
+                 (
+                     ...,
+                     batch_slice,
+                 )
+                 + (slice(None),) * (-self.dim - 1)
+                 for batch_slice in batch_slices
+             ]
+
+         flattened_args_replicas = [[] for _ in range(num_replicas)]
+         for arg in flattened_args:
+             if isinstance(arg, torch.Tensor):
+                 for i, batch_slice in enumerate(batch_slices):
+                     flattened_args_replicas[i].append(arg[batch_slice])
+             else:
+                 for i in range(num_replicas):
+                     flattened_args_replicas[i].append(arg)
+
+         args_replicas = [
+             pytree.tree_unflatten(treedef, args_replica) for args_replica in flattened_args_replicas
+         ]
+
+         return [
+             (workers[i % num_workers], worker_args, worker_kwargs)
+             for i, (worker_args, worker_kwargs) in enumerate(args_replicas)
+         ]
+
+     def __reduce__(
+         self,
+     ) -> Tuple[
+         Callable[..., 'DimPartitioner'],
+         Tuple[int],
+         Dict[str, Union[bool, Optional[Sequence[Union[int, str]]]]],
+     ]:
+         return (
+             DimPartitioner,
+             (self.dim,),
+             dict(exclusive=self.exclusive, keepdim=self.keepdim, workers=self.workers),
+         )
+
+
+ def dim_partitioner(
+     dim: int = 0,
+     *,
+     exclusive: bool = False,
+     keepdim: bool = True,
+     workers: Optional[Sequence[Union[int, str]]] = None,
+ ) -> PartitionFunction:
+     """Partition a batch of inputs along a given dimension.
+
+     Args:
+         dim: The dimension to partition.
+         exclusive: Whether to partition the batch exclusively.
+             If ``exclusive=True``, the batch will be partitioned into ``batch_size`` partitions,
+             where ``batch_size`` is the size of the batch along the given dimension.
+             If ``exclusive=False``, the batch will be partitioned into
+             ``min(batch_size, num_workers)`` partitions, where ``num_workers`` is the number of
+             workers in the world.
+         keepdim: Whether to keep the partitioned dimension. Defaults to :data:`True`, i.e., keep the
+             batch dimension. If :data:`False`, use select instead of slicing. This functionality
+             should be used with ``exclusive=True``.
+         workers: The workers to partition the batch to. If :data:`None`, the batch will be
+             partitioned to all workers in the world.
+
+     Returns:
+         A partition function.
+     """
+     return DimPartitioner(dim, exclusive=exclusive, keepdim=keepdim, workers=workers)
+
+
+ # pylint: disable=line-too-long
+ batch_partitioner: PartitionFunction = dim_partitioner(dim=0, keepdim=True, exclusive=False)
+ """Partitioner for the batch dimension. Divides and replicates the arguments to all workers along the first dimension."""
+ # pylint: enable=line-too-long
+
+
+ def mean_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor:
+     """Reduce the results by averaging them."""
+     return torch.mean(torch.stack(tuple(results), dim=0), dim=0)
+
+
+ def sum_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor:
+     """Reduce the results by summing them."""
+     return torch.sum(torch.stack(tuple(results), dim=0), dim=0)


def remote_async_call(
    func: Callable[..., T],
    *,
    args: Optional[Args] = None,
    kwargs: Optional[KwArgs] = None,
-     partitioner: Partitioner = default_partitioner,
+     partitioner: Partitioner = batch_partitioner,
    reducer: Optional[Callable[[Iterable[T]], U]] = None,
-     timeout: Optional[float] = rpc.api.UNSET_RPC_TIMEOUT,
+     timeout: Optional[float] = UNSET_RPC_TIMEOUT,
) -> Union[Future[List[T]], Future[U]]:
    # pylint: disable=line-too-long
    """Asynchronously do an RPC on remote workers and return a :class:`torch.Future` instance at the current worker.
@@ -84,7 +254,7 @@ def remote_async_call(
        kwargs (Optional[KwArgs], optional): The keyword arguments to pass to the function. Defaults
            to :data:`None`.
        partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple
-             workers. Defaults to :func:`default_partitioner`.
+             workers. Defaults to :func:`batch_partitioner`.
        reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from
            multiple workers. Defaults to :data:`None`.
        timeout (float, optional): The timeout for the RPC call. Defaults to
@@ -97,10 +267,8 @@ def remote_async_call(
        args = ()
    if kwargs is None:
        kwargs = {}
-     if isinstance(partitioner, int):
-         partitions = [(partitioner, args, kwargs)]
-     elif isinstance(partitioner, str):
-         partitions = [(rpc.get_worker_info(worker_name=partitioner).id, args, kwargs)]
+     if isinstance(partitioner, (int, str)):
+         partitions = [(get_worker_id(id=partitioner), args, kwargs)]
    elif callable(partitioner):
        partitions = partitioner(*args, **kwargs)  # type: ignore[assignment]
    else:
@@ -128,9 +296,9 @@ def remote_sync_call(
    *,
    args: Optional[Args] = None,
    kwargs: Optional[KwArgs] = None,
-     partitioner: Partitioner = default_partitioner,
+     partitioner: Partitioner = batch_partitioner,
    reducer: Optional[Callable[[Iterable[T]], U]] = None,
-     timeout: Optional[float] = rpc.api.UNSET_RPC_TIMEOUT,
+     timeout: Optional[float] = UNSET_RPC_TIMEOUT,
) -> Union[List[T], U]:
    """Synchronously do an RPC on remote workers and return the result to the current worker.

@@ -141,7 +309,7 @@ def remote_sync_call(
        kwargs (Optional[KwArgs], optional): The keyword arguments to pass to the function. Defaults
            to :data:`None`.
        partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple
-             workers. Defaults to :func:`default_partitioner`.
+             workers. Defaults to :func:`batch_partitioner`.
        reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from
            multiple workers. Defaults to :data:`None`.
        timeout (float, optional): The timeout for the RPC call. Defaults to
@@ -161,9 +329,9 @@ def remote_sync_call(


def parallelize_async(
-     partitioner: Partitioner = default_partitioner,
+     partitioner: Partitioner = batch_partitioner,
    reducer: Optional[Callable[[Iterable[T]], U]] = None,
-     timeout: Optional[float] = rpc.api.UNSET_RPC_TIMEOUT,
+     timeout: Optional[float] = UNSET_RPC_TIMEOUT,
) -> Callable[[Callable[..., T]], Callable[..., Union[Future[List[T]], Future[U]]]]:
    """Decorator for parallelizing a function.

@@ -173,7 +341,7 @@ def parallelize_async(

    Args:
        partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple
-             workers. Defaults to :func:`default_partitioner`.
+             workers. Defaults to :func:`batch_partitioner`.
        reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from
            multiple workers. Defaults to :data:`None`.
        timeout (float, optional): The timeout for the RPC call. Defaults to
@@ -214,17 +382,17 @@ def wrapped(*args: Any, **kwargs: Any) -> Union[Future[List[T]], Future[U]]:


def parallelize(
-     partitioner: Partitioner = default_partitioner,
+     partitioner: Partitioner = batch_partitioner,
    reducer: Optional[Callable[[Iterable[T]], U]] = None,
-     timeout: Optional[float] = rpc.api.UNSET_RPC_TIMEOUT,
+     timeout: Optional[float] = UNSET_RPC_TIMEOUT,
) -> Callable[[Callable[..., T]], Callable[..., Union[List[T], U]]]:
    """Decorator for parallelizing a function.

    This decorator can be used to parallelize a function call across multiple workers.

    Args:
        partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple
-             workers. Defaults to :func:`default_partitioner`.
+             workers. Defaults to :func:`batch_partitioner`.
        reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from
            multiple workers. Defaults to :data:`None`.
        timeout (float, optional): The timeout for the RPC call. Defaults to
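A minimal usage sketch of the API changed above, assuming an already-initialized torch.distributed.rpc worker group and that `parallelize`, `dim_partitioner`, `batch_partitioner`, and `mean_reducer` are re-exported from `torchopt.distributed` (otherwise import them from the module in this diff). The function `mean_squared_error` and the tensor shapes are hypothetical illustrations, not part of the change.

import torch

import torchopt.distributed as todist


@todist.parallelize(
    # Split every tensor argument into shards along dim 0, one shard per worker.
    partitioner=todist.dim_partitioner(dim=0, keepdim=True, exclusive=False),
    # Average the per-worker results back on the calling worker.
    reducer=todist.mean_reducer,
)
def mean_squared_error(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Each worker only sees its slice of `prediction` and `target`.
    return torch.mean((prediction - target) ** 2)


# On the caller, once RPC is up:
#   loss = mean_squared_error(torch.randn(64, 8), torch.zeros(64, 8))
# `batch_partitioner` is defined above as dim_partitioner(dim=0, keepdim=True, exclusive=False),
# so passing `partitioner=todist.batch_partitioner` would behave the same here.

Note that passing an int or str as `partitioner` routes the whole call to a single worker instead of splitting the batch, per the `isinstance(partitioner, (int, str))` branch introduced in this diff.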