Shortcuts

Source code for torch.nn.modules.flatten

# mypy: allow-untyped-defs
from typing import Union

from torch import Tensor
from torch.types import _size

from .module import Module


# Public names exported by this module (what ``from ... import *`` exposes).
__all__ = ["Flatten", "Unflatten"]


class Flatten(Module):
    r"""
    Collapse a contiguous range of dims of a tensor into a single dim.

    Meant to be used inside :class:`~nn.Sequential`; see :meth:`torch.flatten` for details.

    Shape:
        - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,
          where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
          number of dimensions including none.
        - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.

    Args:
        start_dim: first dim to flatten (default = 1).
        end_dim: last dim to flatten (default = -1).

    Examples::
        >>> input = torch.randn(32, 1, 5, 5)
        >>> # With default parameters
        >>> m = nn.Flatten()
        >>> output = m(input)
        >>> output.size()
        torch.Size([32, 25])
        >>> # With non-default parameters
        >>> m = nn.Flatten(0, 2)
        >>> output = m(input)
        >>> output.size()
        torch.Size([160, 5])
    """

    __constants__ = ["start_dim", "end_dim"]
    start_dim: int
    end_dim: int

    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        """Remember the inclusive range of dims that ``forward`` will flatten."""
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` with dims ``start_dim`` through ``end_dim`` flattened."""
        first = self.start_dim
        last = self.end_dim
        return input.flatten(first, last)

    def extra_repr(self) -> str:
        """Return the ``start_dim=..., end_dim=...`` fragment used in the module repr."""
        fields = (f"start_dim={self.start_dim}", f"end_dim={self.end_dim}")
        return ", ".join(fields)
class Unflatten(Module):
    r"""
    Unflattens a tensor dim expanding it to a desired shape. For use with :class:`~nn.Sequential`.

    * :attr:`dim` specifies the dimension of the input tensor to be unflattened, and it can
      be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively.

    * :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor
      and it can be a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input;
      a `NamedShape` (tuple of `(name, size)` tuples) for `NamedTensor` input.

    Shape:
        - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at
          dimension :attr:`dim` and :math:`*` means any number of dimensions including none.
        - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and
          :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.

    Args:
        dim (Union[int, str]): Dimension to be unflattened
        unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the
            unflattened dimension

    Raises:
        TypeError: if :attr:`dim` is neither an `int` nor a `str`, or if
            :attr:`unflattened_size` does not have the container type required for the
            given kind of :attr:`dim`.

    Examples:
        >>> input = torch.randn(2, 50)
        >>> # With tuple of ints
        >>> m = nn.Sequential(
        >>>     nn.Linear(50, 50),
        >>>     nn.Unflatten(1, (2, 5, 5))
        >>> )
        >>> output = m(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
        >>> # With torch.Size
        >>> m = nn.Sequential(
        >>>     nn.Linear(50, 50),
        >>>     nn.Unflatten(1, torch.Size([2, 5, 5]))
        >>> )
        >>> output = m(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
        >>> # With namedshape (tuple of tuples)
        >>> input = torch.randn(2, 50, names=('N', 'features'))
        >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
        >>> output = unflatten(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
    """

    # A named shape holds any number of ``(name, size)`` pairs, so the alias must
    # be the variadic ``tuple[X, ...]`` form; ``tuple[X]`` would describe a tuple
    # containing exactly one pair.
    NamedShape = tuple[tuple[str, int], ...]

    __constants__ = ["dim", "unflattened_size"]
    dim: Union[int, str]
    unflattened_size: Union[_size, NamedShape]

    def __init__(
        self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape]
    ) -> None:
        """Validate the argument types and store them for ``forward``."""
        super().__init__()

        # The kind of ``dim`` decides which shape format is acceptable:
        # positional (int) dims take ints, named (str) dims take tuples.
        if isinstance(dim, int):
            self._require_tuple_int(unflattened_size)
        elif isinstance(dim, str):
            self._require_tuple_tuple(unflattened_size)
        else:
            raise TypeError("invalid argument type for dim parameter")

        self.dim = dim
        self.unflattened_size = unflattened_size

    def _require_tuple_tuple(self, input):
        """Raise ``TypeError`` unless ``input`` is a tuple whose elements are all tuples.

        Only the container types are checked; the contents of each inner tuple
        are not inspected here.
        """
        if isinstance(input, tuple):
            for idx, elem in enumerate(input):
                if not isinstance(elem, tuple):
                    raise TypeError(
                        "unflattened_size must be tuple of tuples, "
                        + f"but found element of type {type(elem).__name__} at pos {idx}"
                    )
            return
        raise TypeError(
            "unflattened_size must be a tuple of tuples, "
            + f"but found type {type(input).__name__}"
        )

    def _require_tuple_int(self, input):
        """Raise ``TypeError`` unless ``input`` is a tuple or list whose elements are all ints."""
        if isinstance(input, (tuple, list)):
            for idx, elem in enumerate(input):
                if not isinstance(elem, int):
                    raise TypeError(
                        "unflattened_size must be tuple of ints, "
                        + f"but found element of type {type(elem).__name__} at pos {idx}"
                    )
            return
        raise TypeError(
            f"unflattened_size must be a tuple of ints, but found type {type(input).__name__}"
        )

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` with dimension ``dim`` expanded to ``unflattened_size``."""
        return input.unflatten(self.dim, self.unflattened_size)

    def extra_repr(self) -> str:
        """Return the ``dim=..., unflattened_size=...`` fragment used in the module repr."""
        return f"dim={self.dim}, unflattened_size={self.unflattened_size}"

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources
pFad - Phonifier reborn

pFad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy