Commit 6ba8eae

refactor: refactor example with new partitioner
1 parent 1643a30 commit 6ba8eae

File tree: 3 files changed (+213 -53 lines changed)

examples/distributed/few-shot/maml_omniglot.py

Lines changed: 8 additions & 25 deletions
@@ -186,15 +186,18 @@ def partitioner(net_rref, x_spt, y_spt, x_qry, y_qry, n_inner_iter, task_num):
     return partitions


-def reducer(results):
+def transpose_mean_reducer(results):
     qry_losses, qry_accs = tuple(zip(*results))
     qry_loss = torch.mean(torch.stack(qry_losses))
     qry_acc = np.mean(qry_accs)
     return qry_loss, qry_acc


-@todist.parallelize(partitioner=partitioner, reducer=reducer)
-def inner_loop(net_rref, x_spt, y_spt, x_qry, y_qry, n_inner_iter, task_num):
+@todist.parallelize(
+    partitioner=todist.dim_partitioner(dim=0, exclusive=True, keepdim=False),
+    reducer=transpose_mean_reducer,
+)
+def inner_loop(net_rref, x_spt, y_spt, x_qry, y_qry, n_inner_iter):
     if torch.cuda.is_available():
         device = torch.device(f'cuda:{todist.get_local_rank() % torch.cuda.device_count()}')
         torch.cuda.set_device(device)
@@ -235,8 +238,6 @@ def train(db: OmniglotNShot, net: nn.Module, meta_opt: optim.Adam, epoch: int, l
         # Sample a batch of support and query images and labels.
         x_spt, y_spt, x_qry, y_qry = db.next()

-        task_num = x_spt.size(0)
-
         # TODO: Maybe pull this out into a separate module so it
         # doesn't have to be duplicated between `train` and `test`?

@@ -246,15 +247,7 @@ def train(db: OmniglotNShot, net: nn.Module, meta_opt: optim.Adam, epoch: int, l

         meta_opt.zero_grad()
         with todist.autograd.context() as context_id:
-            qry_loss, qry_acc = inner_loop(
-                net_rref,
-                x_spt,
-                y_spt,
-                x_qry,
-                y_qry,
-                n_inner_iter,
-                task_num,
-            )
+            qry_loss, qry_acc = inner_loop(net_rref, x_spt, y_spt, x_qry, y_qry, n_inner_iter)
             todist.autograd.backward(context_id, qry_loss)
         meta_opt.step()

@@ -295,21 +288,11 @@ def test(db, net, epoch, log):
     for _ in range(n_test_iter):
         x_spt, y_spt, x_qry, y_qry = db.next('test')

-        task_num = x_spt.size(0)
-
         # TODO: Maybe pull this out into a separate module so it
         # doesn't have to be duplicated between `train` and `test`?
         n_inner_iter = 5

-        qry_loss, qry_acc = inner_loop(
-            net_rref,
-            x_spt,
-            y_spt,
-            x_qry,
-            y_qry,
-            n_inner_iter,
-            task_num,
-        )
+        qry_loss, qry_acc = inner_loop(net_rref, x_spt, y_spt, x_qry, y_qry, n_inner_iter)
         qry_losses.append(qry_loss.item())
         qry_accs.append(qry_acc)
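
The refactor above replaces the example's hand-written partitioner and the explicit task_num plumbing with the library's dim_partitioner. For orientation, here is a minimal, self-contained sketch (not part of the commit; the shapes are made-up placeholders) of what dim_partitioner(dim=0, exclusive=True, keepdim=False) does to the task batch and how transpose_mean_reducer recombines the per-task results:

import numpy as np
import torch

# Hypothetical Omniglot-like shapes, for illustration only.
task_num, n_way, k_shot = 4, 5, 1
x_spt = torch.randn(task_num, n_way * k_shot, 1, 28, 28)
y_spt = torch.randint(n_way, (task_num, n_way * k_shot))

# exclusive=True, keepdim=False: one partition per task, selected by index,
# so each remote call sees tensors without the leading task dimension.
partitions = [(x_spt[i], y_spt[i]) for i in range(task_num)]
assert partitions[0][0].shape == (n_way * k_shot, 1, 28, 28)

# Suppose each worker returns (qry_loss, qry_acc) for its task ...
results = [(torch.tensor(0.5 + 0.1 * i), 0.80 + 0.01 * i) for i in range(task_num)]

# ... then transpose_mean_reducer transposes the list of pairs and averages
# each field, exactly as in the example code above.
qry_losses, qry_accs = tuple(zip(*results))
qry_loss = torch.mean(torch.stack(qry_losses))
qry_acc = np.mean(qry_accs)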

torchopt/distributed/api.py

Lines changed: 195 additions & 27 deletions
@@ -33,12 +33,17 @@
 import torch
 import torch.distributed.rpc as rpc

-from torchopt.distributed.world import get_world_size
-from torchopt.typing import Future
+import torchopt.pytree as pytree
+from torchopt.distributed.world import get_worker_id, get_world_rank, get_world_size
+from torchopt.typing import Future, PyTree


 __all__ = [
-    'default_partitioner',
+    'DimPartitioner',
+    'dim_partitioner',
+    'batch_partitioner',
+    'mean_reducer',
+    'sum_reducer',
     'remote_async_call',
     'remote_sync_call',
     'parallelize',
@@ -47,32 +52,197 @@
 ]


+UNSET_RPC_TIMEOUT = rpc.api.UNSET_RPC_TIMEOUT
+
+
 T = TypeVar('T')
 U = TypeVar('U')
 Args = Tuple[Any, ...]
 KwArgs = Dict[str, Any]
-Partitioner = Union[int, str, Callable[..., Sequence[Tuple[int, Optional[Args], Optional[KwArgs]]]]]
+PartitionFunction = Callable[..., Sequence[Tuple[int, Optional[Args], Optional[KwArgs]]]]
+Partitioner = Union[int, str, PartitionFunction]


-def default_partitioner(
-    *args: Any,
-    **kwargs: Any,
-) -> List[Tuple[int, Optional[Args], Optional[KwArgs]]]:
-    """Default partitioner.
+class DimPartitioner:
+    """Partitioner class that partitions a batch of inputs along a given dimension.

-    Replicates the arguments to all workers.
+    Args:
+        dim: The dimension to partition.
+        exclusive: Whether to partition the batch exclusively.
+            If ``exclusive=True``, the batch will be partitioned into ``batch_size`` partitions,
+            where ``batch_size`` is the size of the batch along the given dimension.
+            If ``exclusive=False``, the batch will be partitioned into
+            ``min(batch_size, num_workers)`` partitions, where ``num_workers`` is the number of
+            workers in the world.
+        keepdim: Whether to keep the partitioned dimension. Defaults to :data:`True`, i.e., keep the
+            batch dimension. If :data:`False`, use select instead of slicing. This functionality
+            should be used with ``exclusive=True``.
+        workers: The workers to partition the batch to. If :data:`None`, the batch will be
+            partitioned to all workers in the world.
     """
-    return [(rank, args, kwargs.copy()) for rank in range(get_world_size())]
+
+    def __init__(
+        self,
+        dim: int,
+        *,
+        exclusive: bool = False,
+        keepdim: bool = False,
+        workers: Optional[Sequence[Union[int, str]]] = None,
+    ) -> None:
+        if not keepdim and not exclusive:
+            raise ValueError('keepdim=False should be used with exclusive=True.')
+
+        self.dim = dim
+        self.exclusive = exclusive
+        self.keepdim = keepdim
+        self.workers = workers
+
+    # pylint: disable-next=too-many-branches,too-many-locals
+    def __call__(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> List[Tuple[int, Optional[Args], Optional[KwArgs]]]:
+        if self.workers is None:
+            workers = list(range(get_world_size()))
+        else:
+            workers = self.workers
+        workers: List[int] = list(map(get_worker_id, workers))
+        num_workers = len(workers)
+
+        args_tree: PyTree[Any] = (args, kwargs)
+        flattened_args, treedef = pytree.tree_flatten(args_tree)
+
+        batch_size = None
+        for arg in flattened_args:
+            if isinstance(arg, torch.Tensor):
+                if batch_size is None:
+                    batch_size = arg.shape[self.dim]
+                elif batch_size != arg.shape[self.dim]:
+                    raise ValueError(
+                        f'Batch size mismatch on dim={self.dim}. '
+                        f'Expected {batch_size}, got {arg.shape[self.dim]} (shape: {arg.shape}).'
+                    )
+
+        if batch_size is None:
+            return [(get_world_rank(), args, kwargs.copy())]
+
+        if self.exclusive:
+            num_replicas = batch_size
+            if self.keepdim:
+                batch_slices = [slice(i, i + 1) for i in range(num_replicas)]
+            else:
+                batch_slices = list(range(num_replicas))
+        else:
+            if batch_size <= num_workers:
+                num_replicas = batch_size
+                batch_slices = [slice(i, i + 1) for i in range(batch_size)]  # keepdim=True
+            else:
+                num_replicas = num_workers
+                local_size = batch_size // num_workers
+                local_batch_indices = [i * local_size for i in range(num_workers)] + [batch_size]
+                batch_slices = [
+                    slice(local_batch_indices[i], local_batch_indices[i + 1])
+                    for i in range(num_workers)
+                ]
+
+        if self.dim >= 0:
+            batch_slices = [(slice(),) * self.dim + (batch_slice,) for batch_slice in batch_slices]
+        elif self.dim < 0:
+            batch_slices = [
+                (
+                    ...,
+                    batch_slice,
+                )
+                + (slice(),) * (-self.dim - 1)
+                for batch_slice in batch_slices
+            ]
+
+        flattened_args_replicas = [[] for _ in range(num_replicas)]
+        for arg in flattened_args:
+            if isinstance(arg, torch.Tensor):
+                for i, batch_slice in enumerate(batch_slices):
+                    flattened_args_replicas[i].append(arg[batch_slice])
+            else:
+                for i in range(num_replicas):
+                    flattened_args_replicas[i].append(arg)
+
+        args_replicas = [
+            pytree.tree_unflatten(treedef, args_replica) for args_replica in flattened_args_replicas
+        ]
+
+        return [
+            (workers[i % num_workers], worker_args, worker_kwargs)
+            for i, (worker_args, worker_kwargs) in enumerate(args_replicas)
+        ]
+
+    def __reduce__(
+        self,
+    ) -> Tuple[
+        Callable[..., 'DimPartitioner'],
+        Tuple[int],
+        Dict[str, Union[bool, Optional[Sequence[Union[int, str]]]]],
+    ]:
+        return (
+            DimPartitioner,
+            (self.dim,),
+            dict(exclusive=self.exclusive, keepdim=self.keepdim, workers=self.workers),
+        )
+
+
+def dim_partitioner(
+    dim: int = 0,
+    *,
+    exclusive: bool = False,
+    keepdim: bool = True,
+    workers: Optional[Sequence[Union[int, str]]] = None,
+) -> PartitionFunction:
+    """Partition a batch of inputs along a given dimension.
+
+    Args:
+        dim: The dimension to partition.
+        exclusive: Whether to partition the batch exclusively.
+            If ``exclusive=True``, the batch will be partitioned into ``batch_size`` partitions,
+            where ``batch_size`` is the size of the batch along the given dimension.
+            If ``exclusive=False``, the batch will be partitioned into
+            ``min(batch_size, num_workers)`` partitions, where ``num_workers`` is the number of
+            workers in the world.
+        keepdim: Whether to keep the partitioned dimension. Defaults to :data:`True`, i.e., keep the
+            batch dimension. If :data:`False`, use select instead of slicing. This functionality
+            should be used with ``exclusive=True``.
+        workers: The workers to partition the batch to. If :data:`None`, the batch will be
+            partitioned to all workers in the world.
+
+    Returns:
+        A partition function.
+    """
+    return DimPartitioner(dim, exclusive=exclusive, keepdim=keepdim, workers=workers)
+
+
+# pylint: disable=line-too-long
+batch_partitioner: PartitionFunction = dim_partitioner(dim=0, keepdim=True, exclusive=False)
+"""Partitioner for batch dimension. Divide and replicates the arguments to all workers along the first dimension."""
+# pylint: enable=line-too-long
+
+
+def mean_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor:
+    """Reduce the results by averaging them."""
+    return torch.mean(torch.stack(tuple(results), dim=0), dim=0)
+
+
+def sum_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor:
+    """Reduce the results by summing them."""
+    return torch.sum(torch.stack(tuple(results), dim=0), dim=0)


 def remote_async_call(
     func: Callable[..., T],
     *,
     args: Optional[Args] = None,
     kwargs: Optional[KwArgs] = None,
-    partitioner: Partitioner = default_partitioner,
+    partitioner: Partitioner = batch_partitioner,
     reducer: Optional[Callable[[Iterable[T]], U]] = None,
-    timeout: Optional[float] = rpc.api.UNSET_RPC_TIMEOUT,
+    timeout: Optional[float] = UNSET_RPC_TIMEOUT,
 ) -> Union[Future[List[T]], Future[U]]:
     # pylint: disable=line-too-long
     """Asynchronously do an RPC on remote workers and return the a :class:`torch.Future` instance at the current worker.
@@ -84,7 +254,7 @@ def remote_async_call(
         kwargs (Optional[KwArgs], optional): The keyword arguments to pass to the function. Defaults
             to :data:`None`.
         partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple
-            workers. Defaults to :func:`default_partitioner`.
+            workers. Defaults to :func:`batch_partitioner`.
         reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from
             multiple workers. Defaults to :data:`None`.
         timeout (float, optional): The timeout for the RPC call. Defaults to
@@ -97,10 +267,8 @@
         args = ()
     if kwargs is None:
         kwargs = {}
-    if isinstance(partitioner, int):
-        partitions = [(partitioner, args, kwargs)]
-    elif isinstance(partitioner, str):
-        partitions = [(rpc.get_worker_info(worker_name=partitioner).id, args, kwargs)]
+    if isinstance(partitioner, (int, str)):
+        partitions = [(get_worker_id(id=partitioner), args, kwargs)]
     elif callable(partitioner):
         partitions = partitioner(*args, **kwargs)  # type: ignore[assignment]
     else:
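
With the unified branch above, passing an int rank or a str worker name as the partitioner routes the whole, unpartitioned call to that single worker via get_worker_id. A hedged usage sketch (the worker name and the summing function are made-up, remote_sync_call is assumed to be re-exported under torchopt.distributed like the other api members, and an RPC worker group is assumed to be initialized already):

import torchopt.distributed as todist

def remote_sum(values):
    return sum(values)

# Either form targets exactly one worker; without a reducer the gathered
# results come back as a list.
results_by_rank = todist.remote_sync_call(remote_sum, args=([1, 2, 3],), partitioner=1)
results_by_name = todist.remote_sync_call(remote_sum, args=([1, 2, 3],), partitioner='worker1')
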
@@ -128,9 +296,9 @@ def remote_sync_call(
     *,
     args: Optional[Args] = None,
     kwargs: Optional[KwArgs] = None,
-    partitioner: Partitioner = default_partitioner,
+    partitioner: Partitioner = batch_partitioner,
     reducer: Optional[Callable[[Iterable[T]], U]] = None,
-    timeout: Optional[float] = rpc.api.UNSET_RPC_TIMEOUT,
+    timeout: Optional[float] = UNSET_RPC_TIMEOUT,
 ) -> Union[List[T], U]:
     """Synchronously do an RPC on remote workers and return the result to the current worker.

@@ -141,7 +309,7 @@ def remote_sync_call(
         kwargs (Optional[KwArgs], optional): The keyword arguments to pass to the function. Defaults
             to :data:`None`.
         partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple
-            workers. Defaults to :func:`default_partitioner`.
+            workers. Defaults to :func:`batch_partitioner`.
         reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from
             multiple workers. Defaults to :data:`None`.
         timeout (float, optional): The timeout for the RPC call. Defaults to
@@ -161,9 +329,9 @@

 
 def parallelize_async(
-    partitioner: Partitioner = default_partitioner,
+    partitioner: Partitioner = batch_partitioner,
     reducer: Optional[Callable[[Iterable[T]], U]] = None,
-    timeout: Optional[float] = rpc.api.UNSET_RPC_TIMEOUT,
+    timeout: Optional[float] = UNSET_RPC_TIMEOUT,
 ) -> Callable[[Callable[..., T]], Callable[..., Union[Future[List[T]], Future[U]]]]:
     """Decorator for parallelizing a function.

@@ -173,7 +341,7 @@ def parallelize_async(

     Args:
         partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple
-            workers. Defaults to :func:`default_partitioner`.
+            workers. Defaults to :func:`batch_partitioner`.
         reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from
             multiple workers. Defaults to :data:`None`.
         timeout (float, optional): The timeout for the RPC call. Defaults to
@@ -214,17 +382,17 @@ def wrapped(*args: Any, **kwargs: Any) -> Union[Future[List[T]], Future[U]]:


 def parallelize(
-    partitioner: Partitioner = default_partitioner,
+    partitioner: Partitioner = batch_partitioner,
     reducer: Optional[Callable[[Iterable[T]], U]] = None,
-    timeout: Optional[float] = rpc.api.UNSET_RPC_TIMEOUT,
+    timeout: Optional[float] = UNSET_RPC_TIMEOUT,
 ) -> Callable[[Callable[..., T]], Callable[..., Union[List[T], U]]]:
     """Decorator for parallelizing a function.

     This decorator can be used to parallelize a function call across multiple workers.

     Args:
         partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple
-            workers. Defaults to :func:`default_partitioner`.
+            workers. Defaults to :func:`batch_partitioner`.
         reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from
             multiple workers. Defaults to :data:`None`.
         timeout (float, optional): The timeout for the RPC call. Defaults to
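
Putting the new pieces together, a hedged end-to-end sketch of the decorator with the new defaults (not taken from the repository; it assumes batch_partitioner and mean_reducer are re-exported under torchopt.distributed like dim_partitioner is in the example above, that the loss function and tensors are placeholders, and that an RPC worker group has been initialized):

import torch
import torch.nn.functional as F
import torchopt.distributed as todist


@todist.parallelize(
    partitioner=todist.batch_partitioner,  # split tensor args along dim 0 across workers
    reducer=todist.mean_reducer,           # average the per-worker scalar losses
)
def mse_loss(inputs, targets):
    return F.mse_loss(inputs, targets)


# Each worker computes the loss on its shard of the batch; mean_reducer then
# averages the returned tensors on the caller:
# loss = mse_loss(torch.randn(64, 10), torch.randn(64, 10))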

0 commit comments
