Shortcuts

Source code for torch.distributions.distribution

# mypy: allow-untyped-defs
import warnings
from typing import Optional
from typing_extensions import deprecated

import torch
from torch import Tensor
from torch.distributions import constraints
from torch.distributions.utils import lazy_property
from torch.types import _size


__all__ = ["Distribution"]


class Distribution:
    r"""
    Distribution is the abstract base class for probability distributions.

    A distribution instance is described by its ``batch_shape`` (the shape
    over which parameters are batched) and its ``event_shape`` (the shape of
    a single sample). Most methods of this base class raise
    :class:`NotImplementedError` and are meant to be overridden by concrete
    subclasses; ``sample``, ``stddev`` and ``perplexity`` have default
    implementations built on ``rsample``, ``variance`` and ``entropy``.
    """

    has_rsample = False
    has_enumerate_support = False
    # Validation is on by default but off under ``python -O`` (where
    # ``__debug__`` is False), mirroring Python's ``assert`` statement.
    _validate_args = __debug__

    @staticmethod
    def set_default_validate_args(value: bool) -> None:
        """
        Sets whether validation is enabled or disabled.

        The default behavior mimics Python's ``assert`` statement: validation
        is on by default, but is disabled if Python is run in optimized mode
        (via ``python -O``). Validation may be expensive, so you may want to
        disable it once a model is working.

        Args:
            value (bool): Whether to enable validation.

        Raises:
            ValueError: if ``value`` does not compare equal to ``True`` or
                ``False``.
        """
        if value not in [True, False]:
            raise ValueError(
                f"Expected `value` to be True or False, but got {value!r}."
            )
        Distribution._validate_args = value

    def __init__(
        self,
        batch_shape: torch.Size = torch.Size(),
        event_shape: torch.Size = torch.Size(),
        validate_args: Optional[bool] = None,
    ):
        """
        Record the batch/event shapes and, if enabled, validate parameters.

        Parameters are read from the instance with ``getattr``, so subclasses
        must assign them as attributes before calling this initializer.

        Args:
            batch_shape (torch.Size): shape over which parameters are batched.
            event_shape (torch.Size): shape of a single sample.
            validate_args (bool, optional): overrides the class-level default
                set by :meth:`set_default_validate_args` for this instance.

        Raises:
            ValueError: if validation is enabled and a parameter listed in
                ``arg_constraints`` fails its constraint check.
        """
        self._batch_shape = batch_shape
        self._event_shape = event_shape
        if validate_args is not None:
            self._validate_args = validate_args
        if self._validate_args:
            try:
                arg_constraints = self.arg_constraints
            except NotImplementedError:
                arg_constraints = {}
                warnings.warn(
                    f"{self.__class__} does not define `arg_constraints`. "
                    + "Please set `arg_constraints = {}` or initialize the distribution "
                    + "with `validate_args=False` to turn off validation."
                )
            for param, constraint in arg_constraints.items():
                if constraints.is_dependent(constraint):
                    continue  # skip constraints that cannot be checked
                if param not in self.__dict__ and isinstance(
                    getattr(type(self), param), lazy_property
                ):
                    continue  # skip checking lazily-constructed args
                value = getattr(self, param)
                valid = constraint.check(value)
                if not torch._is_all_true(valid):
                    raise ValueError(
                        f"Expected parameter {param} "
                        f"({type(value).__name__} of shape {tuple(value.shape)}) "
                        f"of distribution {repr(self)} "
                        f"to satisfy the constraint {repr(constraint)}, "
                        f"but found invalid values:\n{value}"
                    )
        super().__init__()

    def expand(self, batch_shape: _size, _instance=None):
        """
        Returns a new distribution instance (or populates an existing instance
        provided by a derived class) with batch dimensions expanded to
        `batch_shape`. This method calls :class:`~torch.Tensor.expand` on
        the distribution's parameters. As such, this does not allocate new
        memory for the expanded distribution instance. Additionally,
        this does not repeat any args checking or parameter broadcasting in
        `__init__`, when an instance is first created.

        The base implementation raises :class:`NotImplementedError`;
        subclasses provide the actual behavior.

        Args:
            batch_shape (torch.Size): the desired expanded size.
            _instance: new instance provided by subclasses that
                need to override `.expand`.

        Returns:
            New distribution instance with batch dimensions expanded to
            `batch_shape`.
        """
        raise NotImplementedError

    @property
    def batch_shape(self) -> torch.Size:
        """
        Returns the shape over which parameters are batched.
        """
        return self._batch_shape

    @property
    def event_shape(self) -> torch.Size:
        """
        Returns the shape of a single sample (without batching).
        """
        return self._event_shape

    @property
    def arg_constraints(self) -> dict[str, constraints.Constraint]:
        """
        Returns a dictionary from argument names to
        :class:`~torch.distributions.constraints.Constraint` objects that
        should be satisfied by each argument of this distribution. Args that
        are not tensors need not appear in this dict.
        """
        raise NotImplementedError

    @property
    def support(self) -> Optional[constraints.Constraint]:
        """
        Returns a :class:`~torch.distributions.constraints.Constraint` object
        representing this distribution's support.
        """
        raise NotImplementedError

    @property
    def mean(self) -> Tensor:
        """
        Returns the mean of the distribution.
        """
        raise NotImplementedError

    @property
    def mode(self) -> Tensor:
        """
        Returns the mode of the distribution.
        """
        raise NotImplementedError(f"{self.__class__} does not implement mode")

    @property
    def variance(self) -> Tensor:
        """
        Returns the variance of the distribution.
        """
        raise NotImplementedError

    @property
    def stddev(self) -> Tensor:
        """
        Returns the standard deviation of the distribution, computed as the
        square root of :attr:`variance`.
        """
        return self.variance.sqrt()

    def sample(self, sample_shape: _size = torch.Size()) -> Tensor:
        """
        Generates a sample_shape shaped sample or sample_shape shaped batch of
        samples if the distribution parameters are batched.

        The default implementation calls :meth:`rsample` with gradient
        tracking disabled.
        """
        with torch.no_grad():
            return self.rsample(sample_shape)

    def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched.
        """
        raise NotImplementedError

    @deprecated(
        "`sample_n(n)` will be deprecated. Use `sample((n,))` instead.",
        category=FutureWarning,
    )
    def sample_n(self, n: int) -> Tensor:
        """
        Generates n samples or n batches of samples if the distribution
        parameters are batched. Deprecated alias for ``sample((n,))``.
        """
        return self.sample(torch.Size((n,)))

    def log_prob(self, value: Tensor) -> Tensor:
        """
        Returns the log of the probability density/mass function evaluated at
        `value`.

        Args:
            value (Tensor): the point(s) at which to evaluate.
        """
        raise NotImplementedError

    def cdf(self, value: Tensor) -> Tensor:
        """
        Returns the cumulative distribution function evaluated at `value`.

        Args:
            value (Tensor): the point(s) at which to evaluate.
        """
        raise NotImplementedError

    def icdf(self, value: Tensor) -> Tensor:
        """
        Returns the inverse cumulative distribution function evaluated at
        `value`.

        Args:
            value (Tensor): the probabilities at which to evaluate.
        """
        raise NotImplementedError

    def enumerate_support(self, expand: bool = True) -> Tensor:
        """
        Returns tensor containing all values supported by a discrete
        distribution. The result will enumerate over dimension 0, so the shape
        of the result will be `(cardinality,) + batch_shape + event_shape`
        (where `event_shape = ()` for univariate distributions).

        Note that this enumerates over all batched tensors in lock-step
        `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens
        along dim 0, but with the remaining batch dimensions being
        singleton dimensions, `[[0], [1], ..`.

        To iterate over the full Cartesian product use
        `itertools.product(m.enumerate_support())`.

        Args:
            expand (bool): whether to expand the support over the
                batch dims to match the distribution's `batch_shape`.

        Returns:
            Tensor iterating over dimension 0.
        """
        raise NotImplementedError

    def entropy(self) -> Tensor:
        """
        Returns entropy of distribution, batched over batch_shape.

        Returns:
            Tensor of shape batch_shape.
        """
        raise NotImplementedError

    def perplexity(self) -> Tensor:
        """
        Returns perplexity of distribution, batched over batch_shape, computed
        as ``exp(entropy())``.

        Returns:
            Tensor of shape batch_shape.
        """
        return torch.exp(self.entropy())

    def _extended_shape(self, sample_shape: _size = torch.Size()) -> torch.Size:
        """
        Returns the size of the sample returned by the distribution, given
        a `sample_shape`, i.e. ``sample_shape + batch_shape + event_shape``.
        Note, that the batch and event shapes of a distribution instance are
        fixed at the time of construction. When all three shapes are empty the
        result is the empty size ``torch.Size([])``.

        Args:
            sample_shape (torch.Size): the size of the sample to be drawn.
        """
        if not isinstance(sample_shape, torch.Size):
            sample_shape = torch.Size(sample_shape)
        return torch.Size(sample_shape + self._batch_shape + self._event_shape)

    def _validate_sample(self, value: Tensor) -> None:
        """
        Argument validation for distribution methods such as `log_prob`,
        `cdf` and `icdf`. The rightmost dimensions of a value to be
        scored via these methods must agree with the distribution's batch
        and event shapes, and the value must lie within :attr:`support`.

        Args:
            value (Tensor): the tensor whose log probability is to be
                computed by the `log_prob` method.

        Raises:
            ValueError: when `value` is not a Tensor, when the rightmost
                dimensions of `value` do not match the distribution's batch
                and event shapes, or when `value` lies outside the support.
        """
        if not isinstance(value, torch.Tensor):
            raise ValueError("The value argument to log_prob must be a Tensor")

        # The trailing dims of `value` must equal event_shape exactly.
        event_dim_start = len(value.size()) - len(self._event_shape)
        if value.size()[event_dim_start:] != self._event_shape:
            raise ValueError(
                f"The right-most size of value must match event_shape: {value.size()} vs {self._event_shape}."
            )

        # Remaining dims only need to broadcast against batch_shape + event_shape.
        actual_shape = value.size()
        expected_shape = self._batch_shape + self._event_shape
        for i, j in zip(reversed(actual_shape), reversed(expected_shape)):
            if i != 1 and j != 1 and i != j:
                raise ValueError(
                    f"Value is not broadcastable with batch_shape+event_shape: {actual_shape} vs {expected_shape}."
                )
        try:
            support = self.support
        except NotImplementedError:
            warnings.warn(
                f"{self.__class__} does not define `support` to enable "
                + "sample validation. Please initialize the distribution with "
                + "`validate_args=False` to turn off validation."
            )
            return
        assert support is not None
        valid = support.check(value)
        if not torch._is_all_true(valid):
            raise ValueError(
                "Expected value argument "
                f"({type(value).__name__} of shape {tuple(value.shape)}) "
                f"to be within the support ({repr(support)}) "
                f"of the distribution {repr(self)}, "
                f"but found invalid values:\n{value}"
            )

    def _get_checked_instance(self, cls, _instance=None):
        """
        Return the instance an ``expand`` implementation should populate.

        If ``_instance`` is given it is returned unchanged; otherwise a new,
        uninitialized object of ``type(self)`` is created with ``__new__``.

        Raises:
            NotImplementedError: if no ``_instance`` is given and
                ``type(self)`` has a different ``__init__`` than ``cls``; such
                subclasses must define their own ``.expand()``.
        """
        if _instance is None and type(self).__init__ != cls.__init__:
            raise NotImplementedError(
                f"Subclass {self.__class__.__name__} of {cls.__name__} that defines a custom __init__ method "
                "must also define a custom .expand() method."
            )
        return self.__new__(type(self)) if _instance is None else _instance

    def __repr__(self) -> str:
        """
        Return ``ClassName(param: value, ...)`` for the parameters listed in
        ``arg_constraints`` that are stored on the instance.
        """
        # ``arg_constraints`` raises NotImplementedError on classes that do not
        # define it. Fall back to no parameters so repr() never fails: it is
        # also used when building validation error messages, where a failure
        # here would mask the real ValueError.
        try:
            arg_constraints = self.arg_constraints
        except NotImplementedError:
            arg_constraints = {}
        parts = []
        for name in arg_constraints:
            if name not in self.__dict__:
                continue
            param = self.__dict__[name]
            # Scalars and non-tensor values are shown directly; larger tensors
            # are summarized by their size to keep the repr short.
            if isinstance(param, Tensor) and param.numel() != 1:
                shown = param.size()
            else:
                shown = param
            parts.append(f"{name}: {shown}")
        return self.__class__.__name__ + "(" + ", ".join(parts) + ")"

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy