torch.take_along_dim

torch.take_along_dim(input, indices, dim=None, *, out=None) → Tensor

Selects values from input along the given dim, at the one-dimensional positions given by indices.

If dim is None, input is treated as if it had been flattened to 1-D, and indices refer to positions in the flattened tensor.
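A minimal sketch of the flattened case follows. The tensor `t` and the index values are illustrative choices, not part of the original text; the indices address positions in the row-major flattening of `t`.

```python
import torch

t = torch.tensor([[10, 30, 20], [60, 40, 50]])

# Illustrative indices into the flattened values [10, 30, 20, 60, 40, 50].
flat_idx = torch.tensor([0, 4, 5])

# dim is omitted (None), so t is indexed as if it were 1-D.
picked = torch.take_along_dim(t, flat_idx)
print(picked)  # tensor([10, 40, 50])
```

The result here is a 1-D tensor with one entry per index.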

Functions that return indices along a dimension, like torch.argmax() and torch.argsort(), are designed to work with this function. See the examples below.
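As a sketch of how an index-returning function pairs with this one, the snippet below takes the per-row maximum via torch.argmax(). Passing `keepdim=True` is an assumption made here so that the index tensor keeps the same number of dimensions as the input; the variable names are illustrative.

```python
import torch

t = torch.tensor([[10, 30, 20], [60, 40, 50]])

# keepdim=True keeps the reduced dimension, so row_argmax has shape (2, 1)
# and the same number of dimensions as t.
row_argmax = torch.argmax(t, dim=1, keepdim=True)

# Pick the maximum of each row using the indices computed above.
row_max = torch.take_along_dim(t, row_argmax, dim=1)
print(row_max.tolist())  # [[30], [60]]
```

This should match the values returned by `t.max(dim=1, keepdim=True)`.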

Note

This function is similar to NumPy’s take_along_axis. See also torch.gather().
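The relationship to torch.gather() can be sketched as follows. When the index tensor has the same shape as the input, the two calls give the same result. When it does not, take_along_dim broadcasts the index against the input in the manner of NumPy's take_along_axis, whereas gather returns a result with the shape of the index. The `(1, 3)` index below is an illustrative choice.

```python
import torch

t = torch.tensor([[10, 30, 20], [60, 40, 50]])

# Same-shape indices: both functions return the same tensor.
sorted_idx = torch.argsort(t, dim=1)
same = torch.equal(
    torch.take_along_dim(t, sorted_idx, dim=1),
    torch.gather(t, 1, sorted_idx),
)
print(same)  # True

# Illustrative (1, 3) index: take_along_dim broadcasts it over both rows,
# while gather produces an output shaped like the index itself.
idx = torch.tensor([[2, 0, 1]])
broadcasted = torch.take_along_dim(t, idx, dim=1)
gathered = torch.gather(t, 1, idx)
print(broadcasted.tolist())  # [[20, 10, 30], [50, 60, 40]]
print(gathered.tolist())     # [[20, 10, 30]]
```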

Parameters
  • input (Tensor) – the input tensor.

  • indices (LongTensor) – the indices into input. Must have long dtype.

  • dim (int, optional) – dimension to select along. Default: None, in which case input is treated as flattened.
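Because indices must have long dtype, index tensors built with another integer dtype need converting first. A minimal sketch, assuming the indices arrive as int32 (for example from another library); the variable names are illustrative.

```python
import torch

t = torch.tensor([[10, 30, 20], [60, 40, 50]])

# Hypothetical indices that arrive as int32 rather than long (int64).
idx32 = torch.tensor([[2], [0]], dtype=torch.int32)

# Convert to long before indexing.
idx = idx32.to(torch.long)

selected = torch.take_along_dim(t, idx, dim=1)
print(selected.tolist())  # [[20], [60]]
```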

Keyword Arguments

  • out (Tensor, optional) – the output tensor.

Example:

>>> t = torch.tensor([[10, 30, 20], [60, 40, 50]])
>>> max_idx = torch.argmax(t)
>>> torch.take_along_dim(t, max_idx)
tensor([60])
>>> sorted_idx = torch.argsort(t, dim=1)
>>> torch.take_along_dim(t, sorted_idx, dim=1)
tensor([[10, 20, 30],
        [40, 50, 60]])
