Description
🐛 Describe the bug
When using RPC calls to train a model in a distributed fashion, if the remote model's `parameters()` are `nn.Parameter`s, they are automatically detached from the original computation graph when `model_rref.to_here()` is called. This results in wrong gradients in the distributed backward pass.
The serialization policies for `torch.Tensor` and `torch.nn.Parameter` differ and are undocumented. The behavior is neither intuitive nor visible to the user.
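For context, re-wrapping a tensor in `nn.Parameter` always produces a detached leaf, which matches what we observe on the remote worker after `to_here()`; presumably the deserialization path rebuilds the parameters in a similar way. A minimal local sketch (no RPC involved) of that detaching behavior:

# Local illustration only (no RPC): constructing an nn.Parameter from a
# non-leaf tensor drops its grad_fn, i.e. it detaches from the graph.
import torch
import torch.nn as nn

t = torch.ones(1, requires_grad=True) * 2.0  # non-leaf tensor with a grad_fn
print(t.grad_fn)     # <MulBackward0 object at ...>
p = nn.Parameter(t)  # becomes a leaf: grad_fn is gone
print(p.grad_fn)     # None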
1. Send an `nn.Module` that contains `nn.Parameter`s as an argument to an RPC call:
import os
import random
from threading import Lock
import atexit
import numpy as np
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.nn as nn
WORLD_RANK = RANK = int(os.getenv('RANK'))
WORLD_SIZE = int(os.getenv('WORLD_SIZE'))
AUTOGRAD_LOCK = Lock()
fmt = '{name} ({tensor.__class__.__module__}.{tensor.__class__.__qualname__}): param={tensor.data}, grad={tensor.grad}, grad_fn={tensor.grad_fn}'.format
def get_model():
    model = nn.Linear(1, 1)
    nn.init.ones_(model.weight)
    nn.init.zeros_(model.bias)
    return model
def worker_init():
    random.seed(RANK)
    np.random.seed(RANK)
    torch.manual_seed(RANK)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    print(f'Worker init => Rank {RANK}')
def main():
    rpc.init_rpc(name=f'worker{RANK}', rank=RANK, world_size=WORLD_SIZE)
    atexit.register(rpc.shutdown, graceful=True)

    rpc.api._barrier(worker_names={f'worker{i}' for i in range(WORLD_SIZE)})
    worker_init()
    rpc.api._barrier(worker_names={f'worker{i}' for i in range(WORLD_SIZE)})

    if RANK == 0:
        model = get_model()
        train(model)
    else:
        pass  # listening for RPC calls
def compute_loss(model_rref, x):
    if isinstance(model_rref, nn.Module):
        model = model_rref  # the argument arrived as the module itself
    else:
        model = model_rref.to_here()  # materialize the remote reference

    print()
    print('Before RPC forward:')
    for name, param in model.named_parameters():
        print(f' worker{RANK} => {fmt(name=name, tensor=param)}')

    out = model(x)
    return out.mean()
def convert_param_type(model):
    for param in model.parameters():
        param.__class__ = torch.Tensor
def train(model):
    x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])

    print()
    print('Before forward:')
    for name, param in model.named_parameters():
        print(f' {fmt(name=name, tensor=param)}')

    # convert_param_type(model)
    model_rref = rpc.RRef(model)
    with dist_autograd.context() as ctx:
        loss0 = rpc.rpc_sync('worker0', compute_loss, args=(model_rref, x[: len(x) // 2]))
        loss1 = rpc.rpc_sync('worker1', compute_loss, args=(model_rref, x[len(x) // 2 :]))
        loss = torch.mean(torch.stack([loss0, loss1]))

        print()
        print(f'Loss: {loss}')

        dist_autograd.backward(ctx, [loss])
        with AUTOGRAD_LOCK:
            all_local_grads = dist_autograd.get_gradients(ctx)
            for p, g in all_local_grads.items():
                if p.grad is not None:
                    p.grad = p.grad.add(g)
                else:
                    p.grad = g

    print()
    print('After backward:')
    for name, param in model.named_parameters():
        print(f' {fmt(name=name, tensor=param)}')
if __name__ == '__main__':
    main()
Result: only the gradient contribution from the local worker (worker0) is collected, because `model_rref.to_here()` on the remote worker (worker1) automatically detaches the `nn.Parameter`s.
As we can see, on worker1 the type of `param` is `nn.Parameter` and `param.grad_fn` is `None`:
$ torchrun --nnode 1 --nproc_per_node 2 rpc_test.py
WARNING:torch.distributed.run:
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
Worker init => Rank 1
Worker init => Rank 0
Before forward:
weight (torch.nn.parameter.Parameter): param=tensor([[1.]]), grad=None, grad_fn=None
bias (torch.nn.parameter.Parameter): param=tensor([0.]), grad=None, grad_fn=None
Before RPC forward:
worker0 => weight (torch.nn.parameter.Parameter): param=tensor([[1.]]), grad=None, grad_fn=None
worker0 => bias (torch.nn.parameter.Parameter): param=tensor([0.]), grad=None, grad_fn=None
Before RPC forward:
worker1 => weight (torch.nn.parameter.Parameter): param=tensor([[1.]]), grad=None, grad_fn=None
worker1 => bias (torch.nn.parameter.Parameter): param=tensor([0.]), grad=None, grad_fn=None
Loss: 2.5
After backward:
weight (torch.nn.parameter.Parameter): param=tensor([[1.]]), grad=tensor([[0.7500]]), grad_fn=None
bias (torch.nn.parameter.Parameter): param=tensor([0.]), grad=tensor([0.5000]), grad_fn=None
We get the wrong gradients (only worker0's contribution):
weight.grad = tensor([[0.7500]])
bias.grad = tensor([0.5000])
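For comparison, the expected full-batch gradients can be checked locally with plain (non-distributed) autograd. The following sanity check uses the same model and data as the script above; it is only an illustration, not part of the repro:

# Single-process sanity check of the expected gradients (same model and data as above).
import torch
import torch.nn as nn

model = nn.Linear(1, 1)
nn.init.ones_(model.weight)
nn.init.zeros_(model.bias)

x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
loss0 = model(x[:2]).mean()  # worker0's shard
loss1 = model(x[2:]).mean()  # worker1's shard
loss = torch.stack([loss0, loss1]).mean()
loss.backward()

print(model.weight.grad)  # tensor([[2.5000]])
print(model.bias.grad)    # tensor([1.])

The values reported in the buggy run (0.75 and 0.5) are exactly the gradients of 0.5 * loss0 alone, i.e. only worker0's contribution survives the distributed backward pass.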
2. Convert all `nn.Parameter`s to `torch.Tensor`s before the RPC calls:
The script is the same as above, but with the `# convert_param_type(model)` line uncommented:
convert_param_type(model)  # change __class__ of all `nn.Parameter`s
model_rref = rpc.RRef(model)
with dist_autograd.context() as ctx:
    loss0 = rpc.rpc_sync('worker0', compute_loss, args=(model_rref, x[: len(x) // 2]))
    loss1 = rpc.rpc_sync('worker1', compute_loss, args=(model_rref, x[len(x) // 2 :]))
    loss = torch.mean(torch.stack([loss0, loss1]))

    print()
    print(f'Loss: {loss}')

    dist_autograd.backward(ctx, [loss])
    with AUTOGRAD_LOCK:
        all_local_grads = dist_autograd.get_gradients(ctx)
        for p, g in all_local_grads.items():
            if p.grad is not None:
                p.grad = p.grad.add(g)
            else:
                p.grad = g
Result: as we can see, on worker1 the type of `param` is now `torch.Tensor` and `param.grad_fn` is not `None`. The `grad_fn` points back to the original module on worker0, so the tensors can be traced by distributed autograd.
$ torchrun --nnode 1 --nproc_per_node 2 rpc_test.py
WARNING:torch.distributed.run:
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
Worker init => Rank 1
Worker init => Rank 0
Before forward:
weight (torch.nn.parameter.Parameter): param=tensor([[1.]]), grad=None, grad_fn=None
bias (torch.nn.parameter.Parameter): param=tensor([0.]), grad=None, grad_fn=None
Before RPC forward:
worker0 => weight (torch.Tensor): param=tensor([[1.]]), grad=None, grad_fn=None
worker0 => bias (torch.Tensor): param=tensor([0.]), grad=None, grad_fn=None
Before RPC forward:
worker1 => weight (torch.Tensor): param=tensor([[1.]]), grad=None, grad_fn=<CppFunction object at 0x7fbc145feca0>
worker1 => bias (torch.Tensor): param=tensor([0.]), grad=None, grad_fn=<CppFunction object at 0x7fbc145feca0>
Loss: 2.5
After backward:
weight (torch.Tensor): param=tensor([[1.]]), grad=tensor([[2.5000]]), grad_fn=None
bias (torch.Tensor): param=tensor([0.]), grad=tensor([1.]), grad_fn=None
We get the correct gradients:
weight.grad = tensor([[2.5000]])
bias.grad = tensor([1.])
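A possible refinement of the workaround (hypothetical, not part of the original script): permanently rewriting `param.__class__` means the tensors are no longer recognized by `isinstance(p, nn.Parameter)` checks elsewhere, so the conversion could instead be applied only for the duration of the RPC calls. A minimal sketch, assuming the in-place `__class__` swap is acceptable in the first place:

# Hypothetical helper: temporarily view the parameters as plain tensors while
# the module is serialized for RPC, then restore the nn.Parameter class.
import contextlib

import torch
import torch.nn as nn


@contextlib.contextmanager
def parameters_as_tensors(module: nn.Module):
    params = list(module.parameters())
    for p in params:
        p.__class__ = torch.Tensor  # same trick as convert_param_type() above
    try:
        yield module
    finally:
        for p in params:
            p.__class__ = nn.Parameter  # restore so isinstance checks keep working

With such a helper, `train()` could wrap the two `rpc_sync` calls in `with parameters_as_tensors(model): ...` instead of converting the parameters for the lifetime of the process. This is only a sketch of the idea; the underlying behavior (whether `to_here()` should detach `nn.Parameter`s at all) is still the bug being reported.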
Versions
Collecting environment information...
PyTorch version: 1.12.1
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (conda-forge gcc 10.4.0-16) 10.4.0
Clang version: 10.0.1
CMake version: version 3.22.1
Libc version: glibc-2.31
Python version: 3.8.12 | packaged by conda-forge | (default, Jan 30 2022, 23:42:07) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.10.102.1-microsoft-standard-WSL2-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 11.6.124
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3070
Nvidia driver version: 516.94
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] functorch==0.2.1
[pip3] mypy==0.990+dev.589ad1c17eeb220a41ba41425b61b8593f8bc42d
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.1
[pip3] torch==1.12.1
[pip3] torchopt==0.5.1.dev66+gd738101
[pip3] torchvision==0.13.1
[pip3] torchviz==0.0.2
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.6.0 habf752d_9
[conda] functorch 0.2.1 pypi_0
[conda] libblas 3.9.0 12_linux64_mkl
[conda] libcblas 3.9.0 12_linux64_mkl
[conda] liblapack 3.9.0 12_linux64_mkl
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py38h7f8727e_0
[conda] mkl_fft 1.3.1 py38hd3c417c_0
[conda] mkl_random 1.2.2 py38h51133e4_0
[conda] numpy 1.23.1 py38h6c91a56_0
[conda] numpy-base 1.23.1 py38ha15fc14_0
[conda] pytorch 1.12.1 py3.8_cuda11.6_cudnn8.3.2_0
[conda] pytorch-mutex 1.0 cuda
[conda] torchopt 0.5.1.dev66+gd738101 pypi_0
[conda] torchvision 0.13.1 py38_cu116
[conda] torchviz 0.0.2 pypi_0
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu