[Distributed: RPC] Sending nn.Parameter as RPC argument automatically detaches from the computation graph #86525

Open
@XuehaiPan

Description

🐛 Describe the bug

When training a model over RPC, if the model's parameters() are nn.Parameters, they are automatically detached from the original computation graph when the remote worker calls model_rref.to_here(). This results in wrong gradients in the distributed backward pass.

The serialization policies for torch.Tensor and torch.nn.Parameter differ and are undocumented. The behavior is neither intuitive nor visible to the user.
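To show the asymmetry in isolation, here is a minimal sketch (assuming the same two-worker torchrun launch as the full repro below; check_received is a hypothetical helper, not a PyTorch API) that sends one nn.Parameter and one plain torch.Tensor as RPC arguments and reports what arrives on the callee:

import os

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.nn as nn

RANK = int(os.getenv('RANK'))
WORLD_SIZE = int(os.getenv('WORLD_SIZE'))


def check_received(name, tensor):
    # Runs on the callee: report the received type and whether the tensor
    # is still connected to the sender's graph (i.e. grad_fn is set).
    return f'{name}: type={type(tensor).__name__}, grad_fn={tensor.grad_fn}'


def main():
    rpc.init_rpc(name=f'worker{RANK}', rank=RANK, world_size=WORLD_SIZE)
    if RANK == 0:
        param = nn.Parameter(torch.ones(1))        # sent as nn.Parameter
        plain = torch.ones(1, requires_grad=True)  # sent as torch.Tensor
        with dist_autograd.context():
            print(rpc.rpc_sync('worker1', check_received, args=('parameter', param)))
            print(rpc.rpc_sync('worker1', check_received, args=('tensor', plain)))
    rpc.shutdown()


if __name__ == '__main__':
    main()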

1. Sending an nn.Module that contains nn.Parameters as an argument to an RPC call:

import os
import random
from threading import Lock
import atexit

import numpy as np
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.nn as nn


WORLD_RANK = RANK = int(os.getenv('RANK'))
WORLD_SIZE = int(os.getenv('WORLD_SIZE'))

AUTOGRAD_LOCK = Lock()


fmt = '{name} ({tensor.__class__.__module__}.{tensor.__class__.__qualname__}): param={tensor.data}, grad={tensor.grad}, grad_fn={tensor.grad_fn}'.format


def get_model():
    model = nn.Linear(1, 1)
    nn.init.ones_(model.weight)
    nn.init.zeros_(model.bias)
    return model


def worker_init():
    random.seed(RANK)
    np.random.seed(RANK)
    torch.manual_seed(RANK)

    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    print(f'Worker init => Rank {RANK}')


def main():
    rpc.init_rpc(name=f'worker{RANK}', rank=RANK, world_size=WORLD_SIZE)
    atexit.register(rpc.shutdown, graceful=True)

    rpc.api._barrier(worker_names={f'worker{i}' for i in range(WORLD_SIZE)})
    worker_init()
    rpc.api._barrier(worker_names={f'worker{i}' for i in range(WORLD_SIZE)})

    if RANK == 0:
        model = get_model()
        train(model)
    else:
        pass  # listening for RPC calls


def compute_loss(model_rref, x):
    if isinstance(model_rref, nn.Module):
        model = model_rref
    else:
        model = model_rref.to_here()

    print()
    print('Before RPC forward:')
    for name, param in model.named_parameters():
        print(f'    worker{RANK} => {fmt(name=name, tensor=param)}')

    out = model(x)
    return out.mean()


def convert_param_type(model):
    # Workaround: make the RPC layer treat the parameters as plain tensors.
    for param in model.parameters():
        param.__class__ = torch.Tensor


def train(model):
    x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])

    print()
    print('Before forward:')
    for name, param in model.named_parameters():
        print(f'    {fmt(name=name, tensor=param)}')

    # convert_param_type(model)
    model_rref = rpc.RRef(model)
    with dist_autograd.context() as ctx:
        loss0 = rpc.rpc_sync('worker0', compute_loss, args=(model_rref, x[: len(x) // 2]))
        loss1 = rpc.rpc_sync('worker1', compute_loss, args=(model_rref, x[len(x) // 2 :]))
        loss = torch.mean(torch.stack([loss0, loss1]))
        print()
        print(f'Loss: {loss}')
        dist_autograd.backward(ctx, [loss])

        with AUTOGRAD_LOCK:
            all_local_grads = dist_autograd.get_gradients(ctx)
            for p, g in all_local_grads.items():
                if p.grad is not None:
                    p.grad = p.grad.add(g)
                else:
                    p.grad = g

    print()
    print('After backward:')
    for name, param in model.named_parameters():
        print(f'    {fmt(name=name, tensor=param)}')


if __name__ == '__main__':
    main()

Result: only the gradient contribution from the local worker (worker0) is collected, because model_rref.to_here() on the remote worker (worker1) automatically detaches the nn.Parameters.

As we can see, the type of param is nn.Parameter and param.grad_fn is None on worker1:

$ torchrun --nnode 1 --nproc_per_node 2 rpc_test.py
WARNING:torch.distributed.run:
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
Worker init => Rank 1
Worker init => Rank 0

Before forward:
    weight (torch.nn.parameter.Parameter): param=tensor([[1.]]), grad=None, grad_fn=None
    bias (torch.nn.parameter.Parameter): param=tensor([0.]), grad=None, grad_fn=None

Before RPC forward:
    worker0 => weight (torch.nn.parameter.Parameter): param=tensor([[1.]]), grad=None, grad_fn=None
    worker0 => bias (torch.nn.parameter.Parameter): param=tensor([0.]), grad=None, grad_fn=None

Before RPC forward:
    worker1 => weight (torch.nn.parameter.Parameter): param=tensor([[1.]]), grad=None, grad_fn=None
    worker1 => bias (torch.nn.parameter.Parameter): param=tensor([0.]), grad=None, grad_fn=None

Loss: 2.5

After backward:
    weight (torch.nn.parameter.Parameter): param=tensor([[1.]]), grad=tensor([[0.7500]]), grad_fn=None
    bias (torch.nn.parameter.Parameter): param=tensor([0.]), grad=tensor([0.5000]), grad_fn=None

We get the wrong gradients:

weight.grad = tensor([[0.7500]])
bias.grad = tensor([0.5000])
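For reference, a single-process sanity check (no RPC involved, just local autograd on the same model and inputs) shows what the full gradients should be, and that the values above are exactly worker0's share of them:

import torch
import torch.nn as nn

# Same model and data as in the repro script: w = 1, b = 0, x = [1, 2, 3, 4].
model = nn.Linear(1, 1)
nn.init.ones_(model.weight)
nn.init.zeros_(model.bias)

x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
loss0 = model(x[:2]).mean()   # worker0's half: 1.5 * w + b
loss1 = model(x[2:]).mean()   # worker1's half: 3.5 * w + b
loss = torch.stack([loss0, loss1]).mean()
loss.backward()

print(model.weight.grad)  # tensor([[2.5000]]) -- the full gradient
print(model.bias.grad)    # tensor([1.])
# The RPC run above only recovers worker0's contribution:
# 0.5 * 1.5 = 0.75 for the weight and 0.5 * 1.0 = 0.5 for the bias.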

2. Converting all nn.Parameters to torch.Tensors before the RPC calls:

The script is the same as the one above, but with the # convert_param_type(model) line uncommented:

    convert_param_type(model)  # change __class__ of all `nn.Parameter`s
    model_rref = rpc.RRef(model)
    with dist_autograd.context() as ctx:
        loss0 = rpc.rpc_sync('worker0', compute_loss, args=(model_rref, x[: len(x) // 2]))
        loss1 = rpc.rpc_sync('worker1', compute_loss, args=(model_rref, x[len(x) // 2 :]))
        loss = torch.mean(torch.stack([loss0, loss1]))
        print()
        print(f'Loss: {loss}')
        dist_autograd.backward(ctx, [loss])

        with AUTOGRAD_LOCK:
            all_local_grads = dist_autograd.get_gradients(ctx)
            for p, g in all_local_grads.items():
                if p.grad is not None:
                    p.grad = p.grad.add(g)
                else:
                    p.grad = g

Result:

As we can see, the type of param is torch.Tensor and param.grad_fn is not None on worker1; it points back to the original module on worker0 and can be traced by distributed autograd.

$ torchrun --nnode 1 --nproc_per_node 2 rpc_test.py
WARNING:torch.distributed.run:
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
Worker init => Rank 1
Worker init => Rank 0

Before forward:
    weight (torch.nn.parameter.Parameter): param=tensor([[1.]]), grad=None, grad_fn=None
    bias (torch.nn.parameter.Parameter): param=tensor([0.]), grad=None, grad_fn=None

Before RPC forward:
    worker0 => weight (torch.Tensor): param=tensor([[1.]]), grad=None, grad_fn=None
    worker0 => bias (torch.Tensor): param=tensor([0.]), grad=None, grad_fn=None

Before RPC forward:
    worker1 => weight (torch.Tensor): param=tensor([[1.]]), grad=None, grad_fn=<CppFunction object at 0x7fbc145feca0>
    worker1 => bias (torch.Tensor): param=tensor([0.]), grad=None, grad_fn=<CppFunction object at 0x7fbc145feca0>

Loss: 2.5

After backward:
    weight (torch.Tensor): param=tensor([[1.]]), grad=tensor([[2.5000]]), grad_fn=None
    bias (torch.Tensor): param=tensor([0.]), grad=tensor([1.]), grad_fn=None

We get the correct gradients:

weight.grad = tensor([[2.5000]])
bias.grad = tensor([1.])
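The convert_param_type() hack can be wrapped in a small context manager so the nn.Parameter classes are restored after the RPC calls. This is a hedged sketch of a user-side workaround, not a PyTorch API; params_as_tensors is a hypothetical name:

import contextlib

import torch
import torch.nn as nn


@contextlib.contextmanager
def params_as_tensors(model: nn.Module):
    params = list(model.parameters())
    original_classes = [p.__class__ for p in params]
    try:
        for p in params:
            p.__class__ = torch.Tensor  # same trick as convert_param_type()
        yield model
    finally:
        # Restore the original classes so optimizers, state_dict(), etc.
        # still see nn.Parameter objects after the RPC calls.
        for p, cls in zip(params, original_classes):
            p.__class__ = cls

In train(), the RRef creation, the rpc_sync calls, and dist_autograd.backward() would go inside with params_as_tensors(model): so that the parameters are still seen as plain tensors when worker1 fetches the RRef.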

Versions

Collecting environment information...
PyTorch version: 1.12.1
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (conda-forge gcc 10.4.0-16) 10.4.0
Clang version: 10.0.1 
CMake version: version 3.22.1
Libc version: glibc-2.31

Python version: 3.8.12 | packaged by conda-forge | (default, Jan 30 2022, 23:42:07)  [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.10.102.1-microsoft-standard-WSL2-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 11.6.124
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3070
Nvidia driver version: 516.94
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] functorch==0.2.1
[pip3] mypy==0.990+dev.589ad1c17eeb220a41ba41425b61b8593f8bc42d
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.1
[pip3] torch==1.12.1
[pip3] torchopt==0.5.1.dev66+gd738101
[pip3] torchvision==0.13.1
[pip3] torchviz==0.0.2
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               11.6.0               habf752d_9  
[conda] functorch                 0.2.1                    pypi_0  
[conda] libblas                   3.9.0            12_linux64_mkl  
[conda] libcblas                  3.9.0            12_linux64_mkl  
[conda] liblapack                 3.9.0            12_linux64_mkl  
[conda] mkl                       2021.4.0           h06a4308_640  
[conda] mkl-service               2.4.0            py38h7f8727e_0  
[conda] mkl_fft                   1.3.1            py38hd3c417c_0  
[conda] mkl_random                1.2.2            py38h51133e4_0  
[conda] numpy                     1.23.1           py38h6c91a56_0  
[conda] numpy-base                1.23.1           py38ha15fc14_0  
[conda] pytorch                   1.12.1          py3.8_cuda11.6_cudnn8.3.2_0  
[conda] pytorch-mutex             1.0                        cuda  
[conda] torchopt                  0.5.1.dev66+gd738101          pypi_0  
[conda] torchvision               0.13.1               py38_cu116  
[conda] torchviz                  0.0.2                    pypi_0

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu
