Motivation
Task-level parallelization for multi-host multi-process optimization.
Batch-level parallelization can be implemented easily by wrapping the network (nn.Module) with one of the following (a minimal sketch follows this list):
- torch.nn.DataParallel (single-host multi-GPU) (SPMD)
- torch.nn.parallel.DistributedDataParallel (multi-host multi-GPU)
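For reference, a minimal batch-level (module-level) wrapping sketch; the NCCL backend, toy linear model, and one-process-per-GPU launch (e.g. via torchrun) are assumptions for illustration only:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumed setup: one process per GPU, launched e.g. with torchrun.
dist.init_process_group(backend="nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()
model = nn.Linear(4, 2).to(local_rank)
ddp_model = DDP(model, device_ids=[local_rank])
# ddp_model behaves like the wrapped nn.Module; every rank sees its own
# batch slice and gradients are all-reduced across ranks during backward.

Every replica shares the same parameters, which is exactly what makes this unsuitable for the task-level setting described next.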
However, for algorithms that require task-level parallelization, none of the above solutions work. torch.nn.DataParallel and torch.nn.parallel.DistributedDataParallel provide module-level parallelization: the wrapper replicates the user's module into multiple copies and then runs the forward pass in parallel on split batches. For task-level parallelization, each task needs to maintain its own model parameters and (optionally) its own training data, and the module parameters may differ across tasks.
Solution
functorch.vmap + distributed data parallel optimization.
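A minimal single-process sketch of the vmap half of this idea; the linear model, task count, and use of functorch.make_functional are assumptions for illustration, not part of the proposal. Each task keeps an independent copy of the parameters, and all tasks are evaluated in one vectorized forward pass:

import torch
import torch.nn as nn
import functorch

num_tasks = 8
model = nn.Linear(4, 2)

# Turn the module into a pure function of (params, input).
fmodel, params = functorch.make_functional(model)

# Stack an independent copy of every parameter along a leading "task" axis,
# so each task can be updated separately later.
stacked_params = tuple(
    torch.stack([p.clone() for _ in range(num_tasks)]) for p in params
)

# One mini-batch per task.
inputs = torch.randn(num_tasks, 16, 4)

# Vectorize the forward pass over the task axis of both parameters and data.
outputs = functorch.vmap(fmodel)(stacked_params, inputs)
print(outputs.shape)  # torch.Size([8, 16, 2])

The missing piece is distributing these per-task parameters and their optimizer states across hosts and processes, as in the example below.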
Example
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer

with dist_autograd.context() as context_id:
    # Forward pass.
    rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
    rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
    loss = rref1.to_here() + rref2.to_here()

    # Backward pass.
    dist_autograd.backward(context_id, [loss.sum()])

    # Optimizer.
    dist_optim = DistributedOptimizer(
        optim.SGD,
        [rref1, rref2],
        lr=0.05,
    )
    dist_optim.step(context_id)
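The snippet above assumes the RPC framework is already initialized and that a peer named "worker1" exists. A hypothetical two-process launch could look like the following (the worker names, world size, and environment-based rendezvous are illustrative assumptions):

import os
import torch.distributed.rpc as rpc

def run(rank, world_size=2):
    # MASTER_ADDR and MASTER_PORT must be set in the environment
    # (e.g. by torchrun or by hand) before init_rpc is called.
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    if rank == 0:
        ...  # the dist_autograd / DistributedOptimizer block shown above
    rpc.shutdown()  # blocks until all outstanding RPC work is done

if __name__ == "__main__":
    run(int(os.environ["RANK"]))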
Additional context
Resources:
PyTorch:
- Tutorial: PyTorch Distributed Overview
- Tutorial: Distributed Data Parallel
- API: Module-level Data Parallel: torch.nn.DataParallel (SPMD)
- API: Module-level Distributed Data Parallel: torch.nn.parallel.DistributedDataParallel
- API: PyTorch Distributed Optimizers: torch.distributed.optim
- API: Vectorization map: functorch.vmap
- API: NVIDIA apex.parallel
JAX:
- Tutorial: Named axes and easy-to-revise parallelism
- API: Vectorization map: jax.vmap
- API: Parallel map: jax.pmap (SPMD)
- API (Experimental): jax.experimental.maps.xmap
- Tutorial: Using JAX in multi-host and multi-process environments
Checklist
- I have checked that there is no similar issue in the repo (required)