torch.func.grad

torch.func.grad(func, argnums=0, has_aux=False)

The grad operator computes gradients of func with respect to the input(s) specified by argnums. It can be nested to compute higher-order gradients.

Parameters
  • func (Callable) – A Python function that takes one or more arguments. It must return a single-element Tensor. If has_aux is True, func may instead return a tuple of a single-element Tensor and other auxiliary objects: (output, aux).

  • argnums (int or Tuple[int]) – Specifies the arguments to compute gradients with respect to. argnums can be a single integer or a tuple of integers. Default: 0.

  • has_aux (bool) – Flag indicating that func returns a tensor and other auxiliary objects: (output, aux). Default: False.

Returns

A function that computes gradients with respect to its inputs. By default, it returns the gradient tensor(s) with respect to the first argument. If has_aux is True, it returns a tuple of the gradients and the auxiliary objects. If argnums is a tuple of integers, it returns a tuple of gradients, one for each index in argnums.

Return type

Callable

Example of using grad:

>>> import torch
>>> from torch.func import grad
>>> x = torch.randn([])
>>> cos_x = grad(lambda x: torch.sin(x))(x)
>>> assert torch.allclose(cos_x, x.cos())
>>>
>>> # Second-order gradients
>>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
>>> assert torch.allclose(neg_sin_x, -x.sin())

When composed with vmap, grad can be used to compute per-sample-gradients:

>>> from torch.func import grad, vmap
>>> batch_size, feature_size = 3, 5
>>>
>>> def model(weights, feature_vec):
>>>     # Very simple linear model with activation
>>>     assert feature_vec.dim() == 1
>>>     return feature_vec.dot(weights).relu()
>>>
>>> def compute_loss(weights, example, target):
>>>     y = model(weights, example)
>>>     return ((y - target) ** 2).mean()  # MSELoss
>>>
>>> weights = torch.randn(feature_size, requires_grad=True)
>>> examples = torch.randn(batch_size, feature_size)
>>> targets = torch.randn(batch_size)
>>> inputs = (weights, examples, targets)
>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
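
As a sanity check on the per-sample-gradient example above, the following sketch compares the vmap result against a plain Python loop that calls grad once per example. The seed, the names per_sample and looped, and the tolerance are illustrative choices rather than part of the original example.

```python
import torch
from torch.func import grad, vmap

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
batch_size, feature_size = 3, 5

def model(weights, feature_vec):
    # Same simple linear model with activation as above
    return feature_vec.dot(weights).relu()

def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()

weights = torch.randn(feature_size)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)

# Vectorized: one gradient w.r.t. weights per (example, target) pair
per_sample = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, examples, targets)

# Reference: call grad once per example and stack the results
looped = torch.stack([
    grad(compute_loss)(weights, examples[i], targets[i])
    for i in range(batch_size)
])

assert per_sample.shape == (batch_size, feature_size)
assert torch.allclose(per_sample, looped, atol=1e-6)
```

The leading dimension of the result is the batch dimension: row i holds the gradient of the loss for example i with respect to weights.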

Example of using grad with has_aux and argnums:

>>> from torch.func import grad
>>> def my_loss_func(y, y_pred):
>>>     loss_per_sample = (0.5 * y_pred - y) ** 2
>>>     loss = loss_per_sample.mean()
>>>     return loss, (y_pred, loss_per_sample)
>>>
>>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
>>> y_true = torch.rand(4)
>>> y_preds = torch.rand(4, requires_grad=True)
>>> out = fn(y_true, y_preds)
>>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))
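
To make the nested return structure concrete, here is a minimal sketch that unpacks the output of the same loss function and checks both gradients against the closed-form derivatives. The names grad_y, grad_y_pred, residual, and n are introduced here for illustration only.

```python
import torch
from torch.func import grad

def my_loss_func(y, y_pred):
    loss_per_sample = (0.5 * y_pred - y) ** 2
    loss = loss_per_sample.mean()
    return loss, (y_pred, loss_per_sample)

fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
y_true = torch.rand(4)
y_preds = torch.rand(4)

# Outer tuple: (gradients, aux); the gradients follow the order of argnums
(grad_y, grad_y_pred), (aux_pred, aux_loss_per_sample) = fn(y_true, y_preds)

# Closed-form derivatives of mean((0.5 * y_pred - y) ** 2)
residual = 0.5 * y_preds - y_true
n = y_true.numel()
assert torch.allclose(grad_y, -2.0 * residual / n)
assert torch.allclose(grad_y_pred, residual / n)
assert torch.allclose(aux_loss_per_sample, residual ** 2)
```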

Note

Using torch.no_grad together with grad:

Case 1: Using torch.no_grad inside a function:

>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c

In this case, grad(f)(x) will respect the inner torch.no_grad.
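
A small sketch of what "respect" means in Case 1: because c is computed under torch.no_grad, it acts as a constant, so the gradient of x - c is 1 rather than 1 - 2x. The comparison function f_tracked and the input value are hypothetical additions for contrast.

```python
import torch
from torch.func import grad

def f(x):
    with torch.no_grad():
        c = x ** 2  # not tracked: treated as a constant by grad
    return x - c

def f_tracked(x):
    # Same expression without no_grad, for comparison
    return x - x ** 2

x = torch.tensor(3.0)
g_inner_no_grad = grad(f)(x)      # d/dx (x - const) = 1
g_tracked = grad(f_tracked)(x)    # d/dx (x - x**2) = 1 - 2x = -5

assert torch.allclose(g_inner_no_grad, torch.tensor(1.0))
assert torch.allclose(g_tracked, torch.tensor(-5.0))
```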

Case 2: Using grad inside torch.no_grad context manager:

>>> with torch.no_grad():
>>>     grad(f)(x)

In this case, grad will respect the inner torch.no_grad, but not the outer one. This is because grad is a “function transform”: its result should not depend on a context manager outside of f.
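
As a minimal illustration of Case 2, the sketch below calls grad inside an outer torch.no_grad block and still obtains the usual derivative. The function g and the input value are hypothetical and chosen only to keep the expected gradient easy to verify by hand.

```python
import torch
from torch.func import grad

def g(x):
    return x ** 3

x = torch.tensor(2.0)

# The outer no_grad does not disable the gradient computed by grad
with torch.no_grad():
    dg = grad(g)(x)  # d/dx x**3 = 3 * x**2 = 12

assert torch.allclose(dg, torch.tensor(12.0))
```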
