torch.tensor_split

torch.tensor_split(input, indices_or_sections, dim=0) → List of Tensors

Splits a tensor into multiple sub-tensors, all of which are views of input, along dimension dim according to the indices or number of sections specified by indices_or_sections. This function is based on NumPy’s numpy.array_split().

Parameters
  • input (Tensor) – the tensor to split

  • indices_or_sections (Tensor, int or list or tuple of ints) –

    If indices_or_sections is an integer n or a zero-dimensional long tensor with value n, input is split into n sections along dimension dim. If input.size(dim) is divisible by n, each section has equal size, input.size(dim) / n. Otherwise, the first int(input.size(dim) % n) sections have size int(input.size(dim) / n) + 1, and the rest have size int(input.size(dim) / n).

    If indices_or_sections is a list or tuple of ints, or a one-dimensional long tensor, then input is split along dimension dim at each of the indices in the list, tuple or tensor. For instance, indices_or_sections=[2, 3] and dim=0 would result in the tensors input[:2], input[2:3], and input[3:].

    If indices_or_sections is a tensor, it must be a zero-dimensional or one-dimensional long tensor on the CPU.

  • dim (int, optional) – dimension along which to split the tensor. Default: 0
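
The two splitting rules above can be sketched in plain Python, without torch. The helper names section_sizes and index_bounds are hypothetical and are not part of the PyTorch API. The sketch only covers the arithmetic for in-range, non-negative indices, not how the library handles out-of-range or negative values.

```python
def section_sizes(length, n):
    """Sizes of the n sections when a dimension of `length` is split by an integer n."""
    base, extra = divmod(length, n)
    # The first `extra` sections get one additional element.
    return [base + 1] * extra + [base] * (n - extra)


def index_bounds(length, indices):
    """(start, stop) slice bounds when splitting at the given indices."""
    starts = [0] + list(indices)
    stops = list(indices) + [length]
    return list(zip(starts, stops))


print(section_sizes(8, 3))      # [3, 3, 2]
print(section_sizes(7, 3))      # [3, 2, 2]
print(index_bounds(7, (1, 6)))  # [(0, 1), (1, 6), (6, 7)]
```

These values match the section lengths and slice boundaries in the examples that follow.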

Example:

>>> x = torch.arange(8)
>>> torch.tensor_split(x, 3)
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7]))

>>> x = torch.arange(7)
>>> torch.tensor_split(x, 3)
(tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))
>>> torch.tensor_split(x, (1, 6))
(tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6]))

>>> x = torch.arange(14).reshape(2, 7)
>>> x
tensor([[ 0,  1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12, 13]])
>>> torch.tensor_split(x, 3, dim=1)
(tensor([[0, 1, 2],
        [7, 8, 9]]),
 tensor([[ 3,  4],
        [10, 11]]),
 tensor([[ 5,  6],
        [12, 13]]))
>>> torch.tensor_split(x, (1, 6), dim=1)
(tensor([[0],
        [7]]),
 tensor([[ 1,  2,  3,  4,  5],
        [ 8,  9, 10, 11, 12]]),
 tensor([[ 6],
        [13]]))
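
Because the returned sub-tensors are views of input, an in-place write to one of them is visible in the original tensor. The following is a minimal sketch of that behavior. The variable names are illustrative, and data_ptr() is used only as a rough check that the first section starts at the same memory address as x.

```python
import torch

x = torch.arange(7)
parts = torch.tensor_split(x, 3)  # sections of sizes 3, 2, 2

# Write into the second section; it covers x[3:5].
parts[1][0] = 100

print(x.tolist())  # [0, 1, 2, 100, 4, 5, 6]

# The first section begins at the same memory address as x.
shares_memory = parts[0].data_ptr() == x.data_ptr()
print(shares_memory)  # True
```

If independent copies are needed, calling clone() on each returned tensor avoids this aliasing.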
