Description
Motivation
Task-level parallelization for multi-host multi-process optimization.
Batch-level parallelization can be implemented easily by wrapping the network (`nn.Module`) with one of the following (a minimal sketch follows the list):
- `torch.nn.DataParallel` (single-host multi-GPU) (SPMD)
- `torch.nn.parallel.DistributedDataParallel` (multi-host multi-GPU)
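A minimal sketch of this batch-level wrapping, assuming a toy `nn.Linear` model as a placeholder (`DistributedDataParallel` would additionally require a process group initialized in every spawned process):

```python
import torch
import torch.nn as nn

# Toy network; any nn.Module can be wrapped the same way.
model = nn.Linear(10, 2)

# Single-host multi-GPU: nn.DataParallel replicates the module across the
# visible GPUs, splits the input batch along dim 0, and gathers the outputs.
# (With no GPUs available it falls back to running the module as-is.)
model = nn.DataParallel(model)

x = torch.randn(32, 10)
y = model(x)  # shape (32, 2); the batch was scattered across replicas

# Multi-host multi-GPU would instead wrap the module with
# nn.parallel.DistributedDataParallel inside each process, after
# torch.distributed.init_process_group(...) has been called.
```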
However, none of the above solutions works for algorithms that require task-level parallelization. `torch.nn.DataParallel` and `torch.nn.parallel.DistributedDataParallel` provide module-level parallelization: the wrapper replicates the user's module into multiple copies and runs the forward pass in parallel. For task-level parallelization, each task needs to maintain its own model parameters and (optionally) its own training data, and the module parameters may differ across tasks.
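For illustration, task-level work looks like the following sequential baseline (the per-task models, data, and training step are placeholders), which module replication cannot express because every task holds its own distinct parameters:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

num_tasks = 4

# Each task owns its own parameters and (optionally) its own training data.
models = [nn.Linear(3, 1) for _ in range(num_tasks)]
datasets = [(torch.randn(8, 3), torch.randn(8, 1)) for _ in range(num_tasks)]
optimizers = [torch.optim.SGD(m.parameters(), lr=0.05) for m in models]

# Sequential baseline: one independent optimization step per task.
for model, (x, y), opt in zip(models, datasets, optimizers):
    loss = F.mse_loss(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
```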
Solution
`functorch.vmap` + distributed data parallel optimization.
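A minimal sketch of the `vmap` half of this idea, following the functorch `make_functional`/`vmap` ensembling pattern (the toy model, task count, and shapes are placeholders; the distributed-optimizer half is shown in the example below):

```python
import torch
import torch.nn as nn
from functorch import make_functional, vmap

num_tasks = 4

# One functional version of the architecture; parameters become explicit inputs.
fmodel, params = make_functional(nn.Linear(3, 1))

# Stack an independent copy of the parameters per task along a new leading dim.
stacked_params = tuple(
    torch.stack([p.detach().clone() for _ in range(num_tasks)])
    for p in params
)

# Per-task inputs: (num_tasks, batch_size, in_features).
x = torch.randn(num_tasks, 8, 3)

# vmap maps over the leading task dimension of both the parameters and the
# inputs, so all task-specific forward passes run as one vectorized call.
out = vmap(fmodel)(stacked_params, x)  # (num_tasks, 8, 1)
```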
Example
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer

with dist_autograd.context() as context_id:
    # Forward pass.
    rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
    rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
    loss = rref1.to_here() + rref2.to_here()
    # Backward pass.
    dist_autograd.backward(context_id, [loss.sum()])
    # Optimizer.
    dist_optim = DistributedOptimizer(
        optim.SGD,
        [rref1, rref2],
        lr=0.05,
    )
    dist_optim.step(context_id)
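The example assumes the RPC framework is already up on every process; a minimal setup sketch (worker name, rank, world size, and rendezvous address are placeholders) would be:

```python
import os
import torch.distributed.rpc as rpc

# Placeholder rendezvous settings; set these appropriately for your cluster.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")

# Run once per process before using rpc.remote / DistributedOptimizer.
rpc.init_rpc("worker0", rank=0, world_size=2)

# ... forward/backward/optimizer code from the example above ...

rpc.shutdown()  # blocks until all outstanding RPC work has completed
```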
Additional context
Resources:
PyTorch:
- Tutorial: PyTorch Distributed Overview
- Tutorial: Distributed Data Parallel
- API: Module level Data Parallel `torch.nn.DataParallel` (SPMD)
- API: Module level Distributed Data Parallel `torch.nn.parallel.DistributedDataParallel`
- API: PyTorch Distributed Optimizers `torch.distributed.optim`
- API: Vectorization map `functorch.vmap`
- API: NVIDIA apex.parallel

JAX:
- Tutorial: Named axes and easy-to-revise parallelism
- API: Vectorization map `jax.vmap`
- API: Parallel map `jax.pmap` (SPMD)
- API (Experimental): `jax.experimental.maps.xmap`
- Tutorial: Using JAX in multi-host and multi-process environments
Checklist
- I have checked that there is no similar issue in the repo (required)