[Bug/functorch] Cannot use tensor.detach().numpy() for GradTrackingTensor: Cannot access data pointer of Tensor that doesn't have storage #91810

Description

@XuehaiPan

🐛 Describe the bug

I'm trying to convert some intermediate tensors to numpy arrays, which is a common use case in Reinforcement Learning (RL) tasks.

I put the sampling process in my objective function: the PyTorch module takes an observation tensor and produces an action, and the action is then converted to a numpy array and sent to the RL environment.

The pipeline is:

import functorch
import numpy as np
import torch

# `fmodel`, `env`, `horizon`, and `params` come from the surrounding training setup.

def fpolicy(params, observation):
    # forward model
    logits = fmodel(params, observation)
    dist = torch.distributions.Categorical(logits=logits)
    action = dist.sample()
    logprob = dist.log_prob(action)
    return action, logprob

def sample(params):
    batch = []
    observation = env.reset()
    for _ in range(horizon):
        action, action_logprob = fpolicy(params, observation)
        action_numpy = action.detach().cpu().numpy()  # <== has side effect here: `.numpy()` accesses the data storage
        next_observation, reward, *_ = env.step(action_numpy)
        batch.append((observation, action_logprob, reward, next_observation))
        observation = next_observation  # advance the environment state

    observations, action_logprobs, rewards, next_observations = tuple(zip(*batch))
    observations = torch.from_numpy(np.stack(observations))
    action_logprobs = torch.stack(action_logprobs)
    rewards = torch.from_numpy(np.stack(rewards))
    next_observations = torch.from_numpy(np.stack(next_observations))
    return observations, action_logprobs, rewards, next_observations

def objective(params):
    observations, action_logprobs, rewards, next_observations = sample(params)
    return (rewards * action_logprobs).mean()

grad_fn = functorch.grad(objective)
grads = grad_fn(params)

While interacting with the environment, only the tensor's value is used; the computation graph has already been detached.

In functorch.grad / functorch.vjp, the input params tensors are wrapped as GradTrackingTensors, and all intermediate tensors (e.g., the action tensor) become GradTrackingTensors as well. They can be .detach()-ed, but they cannot be converted to numpy arrays with .numpy().
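
To see the wrapping directly, here is a small illustration (not from the original report; the exact repr is version-dependent): printing an intermediate tensor inside the transform reveals the wrapper rather than a plain tensor.

import functorch
import torch

def show(x):
    y = x.sin()
    # Inside functorch.grad the repr reveals the wrapper,
    # e.g. GradTrackingTensor(lvl=1, value=tensor([...])).
    print(y)
    return y.sum()

functorch.grad(show)(torch.randn(3))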

A minimal script to reproduce this:

import functorch
import torch

def mean(x):
    mu = x.mean()
    mu.detach()                # OK
    mu.detach().cpu()          # OK
    mu.detach().cpu().numpy()  # FAIL
    return mu

grad_fn = functorch.grad(mean)
grads = grad_fn(torch.randn(8))
╭── Traceback (most recent call last) ──────────────────────────────────────────────────╮
│ /home/PanXuehai/test.py:12 in <module>
│
│    9 │   return mu
│   10
│   11 grad_fn = functorch.grad(mean)
│ ❱ 12 grads = grad_fn(torch.randn(8))
│   13
│
│ /home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/functorch/_src/eager_transforms.py:1241 in wrapper
│
│   1238 │   """
│   1239 │   @wraps(func)
│   1240 │   def wrapper(*args, **kwargs):
│ ❱ 1241 │   │   results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
│   1242 │   │   if has_aux:
│   1243 │   │   │   grad, (_, aux) = results
│   1244 │   │   │   return grad, aux
│
│ /home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/functorch/_src/vmap.py:35 in fn
│
│    32 │   @functools.wraps(f)
│    33 │   def fn(*args, **kwargs):
│    34 │   │   with torch.autograd.graph.disable_saved_tensors_hooks(message):
│ ❱  35 │   │   │   return f(*args, **kwargs)
│    36 │   return fn
│    37
│    38
│
│ /home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/functorch/_src/eager_transforms.py:1111 in wrapper
│
│   1108 │   │   │   │   diff_args = _slice_argnums(args, argnums, as_tuple=False)
│   1109 │   │   │   │   tree_map_(partial(_create_differentiable, level=level), diff_args)
│   1110 │   │   │   │
│ ❱ 1111 │   │   │   │   output = func(*args, **kwargs)
│   1112 │   │   │   │   if has_aux:
│   1113 │   │   │   │   │   if not (isinstance(output, tuple) and len(output) == 2):
│   1114 │   │   │   │   │   │   raise RuntimeError(
│
│ /home/PanXuehai/test.py:8 in mean
│
│    5 │   mu = x.mean()
│    6 │   mu.detach()                # OK
│    7 │   mu.detach().cpu()          # OK
│ ❱  8 │   mu.detach().cpu().numpy()  # FAIL
│    9 │   return mu
╰───────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
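
For comparison (a sketch, not part of the original report): the same .detach().cpu().numpy() chain succeeds under plain autograd, because a regular autograd tensor has real storage. One interim workaround is therefore to fall back to torch.autograd.grad for objectives that need such side effects:

import torch

def mean(x):
    mu = x.mean()
    mu.detach().cpu().numpy()  # OK under plain autograd: this tensor has storage
    return mu

x = torch.randn(8, requires_grad=True)
(grad,) = torch.autograd.grad(mean(x), x)  # gradient of the mean: 1/8 everywhere

For the RL pipeline above, another possible restructuring (also a sketch, assuming the same env, fmodel, horizon, and params as in the report) is to keep env.step outside the transform entirely: collect a trajectory with plain tensors first, then recompute the log-probabilities inside the objective so that functorch.grad only ever sees pure tensor computation:

import functorch
import numpy as np
import torch

def collect(params):
    # Runs outside functorch.grad, so .numpy() is safe here.
    observations, actions, rewards = [], [], []
    observation = env.reset()
    for _ in range(horizon):
        with torch.no_grad():
            logits = fmodel(params, torch.as_tensor(observation))
        action = torch.distributions.Categorical(logits=logits).sample()
        next_observation, reward, *_ = env.step(action.cpu().numpy())
        observations.append(observation)
        actions.append(action)
        rewards.append(reward)
        observation = next_observation
    return (torch.as_tensor(np.stack(observations)),
            torch.stack(actions),
            torch.as_tensor(np.stack(rewards), dtype=torch.float32))

def objective(params, observations, actions, rewards):
    # Pure tensor computation: recompute log-probs under the current params.
    logits = fmodel(params, observations)
    logprobs = torch.distributions.Categorical(logits=logits).log_prob(actions)
    return (rewards * logprobs).mean()

observations, actions, rewards = collect(params)
grads = functorch.grad(objective)(params, observations, actions, rewards)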

Versions

Collecting environment information...
PyTorch version: 1.13.1
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (conda-forge gcc 10.4.0-19) 10.4.0
Clang version: 10.0.1 
CMake version: version 3.22.1
Libc version: glibc-2.31

Python version: 3.9.15 (main, Nov 24 2022, 14:31:59)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.79.1-microsoft-standard-WSL2-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 527.56
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy==0.991
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.5
[pip3] torch==1.13.1
[pip3] torchopt==0.6.1.dev7+g6cb2b4f
[pip3] torchvision==0.14.1
[pip3] torchviz==0.0.2
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               11.3.1               h2bc3f7f_2  
[conda] mkl                       2021.4.0           h06a4308_640  
[conda] mkl-service               2.4.0            py39h7f8727e_0  
[conda] mkl_fft                   1.3.1            py39hd3c417c_0  
[conda] mkl_random                1.2.2            py39h51133e4_0  
[conda] numpy                     1.23.5           py39h14f4228_0  
[conda] numpy-base                1.23.5           py39h31eccc5_0  
[conda] pytorch                   1.13.1          py3.9_cuda11.7_cudnn8.5.0_0  
[conda] pytorch-cuda              11.7                 h67b0de4_1  
[conda] pytorch-mutex             1.0                        cuda  
[conda] torchopt                  0.6.1.dev7+g6cb2b4f          pypi_0  
[conda] torchvision               0.14.1               py39_cu117  
[conda] torchviz                  0.0.2                    pypi_0

cc @zou3519 @Chillee @samdow @soumith

Metadata

    Labels

    module: functorch (Pertaining to torch.func or pytorch/functorch)
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
