🐛 Describe the bug
I'm trying to convert some intermediate tensors to numpy arrays. This is a common use case for Reinforcement Learning (RL) tasks.
I put the sampling process in my objective function. The PyTorch module takes an `observation` tensor and produces an `action`; the `action` is then converted to a numpy array and sent to the RL environment.
The pipeline is:
```python
import functorch
import numpy as np
import torch

# `fmodel`, `env`, `horizon`, and `params` are defined elsewhere.

def fpolicy(params, observation):
    # forward model
    logits = fmodel(params, observation)
    dist = torch.distributions.Categorical(logits=logits)
    action = dist.sample()
    logprob = dist.log_prob(action)
    return action, logprob

def sample(params):
    batch = []
    observation = env.reset()
    for i in range(horizon):
        action, action_logprob = fpolicy(params, observation)
        action_numpy = action.detach().cpu().numpy()  # <== has side effect here: `.numpy()` accesses the data storage
        next_observation, reward, *_ = env.step(action_numpy)
        batch.append((observation, action_logprob, reward, next_observation))
        observation = next_observation  # advance to the next state
    observations, action_logprobs, rewards, next_observations = tuple(zip(*batch))
    observations = torch.from_numpy(np.stack(observations))
    action_logprobs = torch.stack(action_logprobs)
    rewards = torch.from_numpy(np.stack(rewards))
    next_observations = torch.from_numpy(np.stack(next_observations))
    return observations, action_logprobs, rewards, next_observations

def objective(params):
    observations, action_logprobs, rewards, next_observations = sample(params)
    return (rewards * action_logprobs).mean()

grad_fn = functorch.grad(objective)
grads = grad_fn(params)
```
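For contrast, the same conversion works outside a functorch transform, where the sampled action is an ordinary tensor. A minimal check (hypothetical, not part of the original pipeline):

```python
import torch

logits = torch.randn(4, requires_grad=True)
dist = torch.distributions.Categorical(logits=logits)
action = dist.sample()
action_numpy = action.detach().cpu().numpy()  # OK: a plain tensor owns real storage
```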
While interacting with the environment, only the value of the tensor is used; the computation graph has already been detached.
In `functorch.grad` / `functorch.vjp`, the input `params` tensors are wrapped as `GradTrackingTensor`, so all intermediate tensors (e.g., the `action` tensor) are also `GradTrackingTensor`s. They can be `.detach()`-ed, but they cannot be converted to numpy arrays with `.numpy()`, because the wrapper tensor does not own real storage and `.numpy()` needs a data pointer.
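A quick way to observe the wrapping (a hypothetical probe, not from the original report; I believe printing inside the transform shows the wrapper repr):

```python
import functorch
import torch

def f(x):
    y = x * 2
    # Inside functorch.grad, printing `y` shows a GradTrackingTensor(...)
    # wrapper around the underlying value, not a plain tensor.
    print(y)
    return y.sum()

functorch.grad(f)(torch.randn(3))
```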
A minimal script to reproduce this:
```python
import functorch
import torch

def mean(x):
    mu = x.mean()
    mu.detach()                # OK
    mu.detach().cpu()          # OK
    mu.detach().cpu().numpy()  # FAIL
    return mu

grad_fn = functorch.grad(mean)
grads = grad_fn(torch.randn(8))
```
╭───────────────────────────────────────────── Traceback (most recent call last) ─────────────────────────────────────────────╮
│ /home/PanXuehai/test.py:12 in <module> │
│ │
│ 9 │ return mu │
│ 10 │
│ 11 grad_fn = functorch.grad(mean) │
│ ❱ 12 grads = grad_fn(torch.randn(8)) │
│ 13 │
│ │
│ /home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/functorch/_src/eager_transforms.py:1241 in wrapper │
│ │
│ 1238 │ """ │
│ 1239 │ @wraps(func) │
│ 1240 │ def wrapper(*args, **kwargs): │
│ ❱ 1241 │ │ results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs) │
│ 1242 │ │ if has_aux: │
│ 1243 │ │ │ grad, (_, aux) = results │
│ 1244 │ │ │ return grad, aux │
│ │
│ /home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/functorch/_src/vmap.py:35 in fn │
│ │
│ 32 │ @functools.wraps(f) │
│ 33 │ def fn(*args, **kwargs): │
│ 34 │ │ with torch.autograd.graph.disable_saved_tensors_hooks(message): │
│ ❱ 35 │ │ │ return f(*args, **kwargs) │
│ 36 │ return fn │
│ 37 │
│ 38 │
│ │
│ /home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/functorch/_src/eager_transforms.py:1111 in wrapper │
│ │
│ 1108 │ │ │ │ diff_args = _slice_argnums(args, argnums, as_tuple=False) │
│ 1109 │ │ │ │ tree_map_(partial(_create_differentiable, level=level), diff_args) │
│ 1110 │ │ │ │ │
│ ❱ 1111 │ │ │ │ output = func(*args, **kwargs) │
│ 1112 │ │ │ │ if has_aux: │
│ 1113 │ │ │ │ │ if not (isinstance(output, tuple) and len(output) == 2): │
│ 1114 │ │ │ │ │ │ raise RuntimeError( │
│ │
│ /home/PanXuehai/test.py:8 in mean │
│ │
│ 5 │ mu = x.mean() │
│ 6 │ mu.detach() # OK │
│ 7 │ mu.detach().cpu() # OK │
│ ❱ 8 │ mu.detach().cpu().numpy() # FAIL │
│ 9 │ return mu │
│ 10 │
│ 11 grad_fn = functorch.grad(mean) │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
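For comparison, a sketch using plain `torch.autograd` (my own addition, not from the report): outside the functorch transform, `mu` is an ordinary tensor with storage, so the same conversion succeeds.

```python
import torch

x = torch.randn(8, requires_grad=True)
mu = x.mean()
mu_np = mu.detach().cpu().numpy()  # OK: ordinary tensors have storage
(grad,) = torch.autograd.grad(mu, x)
```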
Versions
Collecting environment information...
PyTorch version: 1.13.1
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (conda-forge gcc 10.4.0-19) 10.4.0
Clang version: 10.0.1
CMake version: version 3.22.1
Libc version: glibc-2.31
Python version: 3.9.15 (main, Nov 24 2022, 14:31:59) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.79.1-microsoft-standard-WSL2-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 527.56
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] mypy==0.991
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.5
[pip3] torch==1.13.1
[pip3] torchopt==0.6.1.dev7+g6cb2b4f
[pip3] torchvision==0.14.1
[pip3] torchviz==0.0.2
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 h2bc3f7f_2
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py39h7f8727e_0
[conda] mkl_fft 1.3.1 py39hd3c417c_0
[conda] mkl_random 1.2.2 py39h51133e4_0
[conda] numpy 1.23.5 py39h14f4228_0
[conda] numpy-base 1.23.5 py39h31eccc5_0
[conda] pytorch 1.13.1 py3.9_cuda11.7_cudnn8.5.0_0
[conda] pytorch-cuda 11.7 h67b0de4_1
[conda] pytorch-mutex 1.0 cuda
[conda] torchopt 0.6.1.dev7+g6cb2b4f pypi_0
[conda] torchvision 0.14.1 py39_cu117
[conda] torchviz 0.0.2 pypi_0