Shortcuts

Source code for torch.utils.data.dataset

# mypy: allow-untyped-defs
import bisect
import itertools
import math
import warnings
from collections.abc import Sequence

# UP006 wants 'Iterable' to be imported from collections.abc but it needs to
# stay from typing for now due to BC concerns. In particular several internal
# targets fail to typecheck with:
#     TypeError: Cannot create a consistent method resolution order (MRO) for
#     bases Iterable, Generic
from typing import cast, Generic, Iterable, Optional, TypeVar, Union  # noqa: UP035
from typing_extensions import deprecated

# No 'default_generator' in torch/__init__.pyi
from torch import default_generator, Generator, randperm, Tensor


# Public names exported by ``from torch.utils.data.dataset import *``.
__all__ = [
    "Dataset",
    "IterableDataset",
    "TensorDataset",
    "StackDataset",
    "ConcatDataset",
    "ChainDataset",
    "Subset",
    "random_split",
]


# Invariant element type; used by ``random_split`` to tie the input dataset's
# sample type to the sample type of the returned subsets.
_T = TypeVar("_T")
# Covariant sample type parameter of the ``Dataset`` hierarchy.
_T_co = TypeVar("_T_co", covariant=True)
# Sample shape produced by ``StackDataset`` when built from keyword datasets.
_T_dict = dict[str, _T_co]
# Sample shape produced by ``StackDataset`` when built from positional datasets.
_T_tuple = tuple[_T_co, ...]
# Constrained TypeVar: a ``StackDataset`` yields either tuples or dicts
# (its constructor rejects mixing positional and keyword datasets).
_T_stack = TypeVar("_T_stack", _T_tuple, _T_dict)


[docs]class Dataset(Generic[_T_co]): r"""An abstract class representing a :class:`Dataset`. All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite :meth:`__len__`, which is expected to return the size of the dataset by many :class:`~torch.utils.data.Sampler` implementations and the default options of :class:`~torch.utils.data.DataLoader`. Subclasses could also optionally implement :meth:`__getitems__`, for speedup batched samples loading. This method accepts list of indices of samples of batch and returns list of samples. .. note:: :class:`~torch.utils.data.DataLoader` by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided. """ def __getitem__(self, index) -> _T_co: raise NotImplementedError("Subclasses of Dataset should implement __getitem__.") # def __getitems__(self, indices: List) -> List[_T_co]: # Not implemented to prevent false-positives in fetcher check in # torch.utils.data._utils.fetch._MapDatasetFetcher def __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]": return ConcatDataset([self, other])
# No `def __len__(self)` default? # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] # in pytorch/torch/utils/data/sampler.py
[docs]class IterableDataset(Dataset[_T_co], Iterable[_T_co]): r"""An iterable Dataset. All datasets that represent an iterable of data samples should subclass it. Such form of datasets is particularly useful when data come from a stream. All subclasses should overwrite :meth:`__iter__`, which would return an iterator of samples in this dataset. When a subclass is used with :class:`~torch.utils.data.DataLoader`, each item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader` iterator. When :attr:`num_workers > 0`, each worker process will have a different copy of the dataset object, so it is often desired to configure each copy independently to avoid having duplicate data returned from the workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker process, returns information about the worker. It can be used in either the dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's :attr:`worker_init_fn` option to modify each copy's behavior. Example 1: splitting workload across all workers in :meth:`__iter__`:: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER) >>> # xdoctest: +SKIP("Fails on MacOS12") >>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example code only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... worker_info = torch.utils.data.get_worker_info() ... if worker_info is None: # single-process data loading, return the full iterator ... iter_start = self.start ... iter_end = self.end ... else: # in a worker process ... # split workload ... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers))) ... worker_id = worker_info.id ... iter_start = self.start + worker_id * per_worker ... iter_end = min(iter_start + per_worker, self.end) ... return iter(range(iter_start, iter_end)) ... 
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. >>> ds = MyIterableDataset(start=3, end=7) >>> # Single-process loading >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [tensor([3]), tensor([4]), tensor([5]), tensor([6])] >>> # xdoctest: +REQUIRES(POSIX) >>> # Multi-process loading with two worker processes >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. >>> # xdoctest: +IGNORE_WANT("non deterministic") >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [tensor([3]), tensor([5]), tensor([4]), tensor([6])] >>> # With even more workers >>> # xdoctest: +IGNORE_WANT("non deterministic") >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12))) [tensor([3]), tensor([5]), tensor([4]), tensor([6])] Example 2: splitting workload across all workers using :attr:`worker_init_fn`:: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER) >>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example code only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... return iter(range(self.start, self.end)) ... >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. >>> ds = MyIterableDataset(start=3, end=7) >>> # Single-process loading >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [3, 4, 5, 6] >>> >>> # Directly doing multi-process loading yields duplicate data >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [3, 3, 4, 4, 5, 5, 6, 6] >>> # Define a `worker_init_fn` that configures each dataset copy differently >>> def worker_init_fn(worker_id): ... worker_info = torch.utils.data.get_worker_info() ... dataset = worker_info.dataset # the dataset copy in this worker process ... overall_start = dataset.start ... overall_end = dataset.end ... # configure the dataset to only process the split workload ... 
per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers))) ... worker_id = worker_info.id ... dataset.start = overall_start + worker_id * per_worker ... dataset.end = min(dataset.start + per_worker, overall_end) ... >>> # Mult-process loading with the custom `worker_init_fn` >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn))) [3, 5, 4, 6] >>> # With even more workers >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn))) [3, 4, 5, 6] """ def __add__(self, other: Dataset[_T_co]): return ChainDataset([self, other])
# No `def __len__(self)` default? Subclasses raise `TypeError` when needed. # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
class TensorDataset(Dataset[tuple[Tensor, ...]]):
    r"""Dataset that wraps a group of tensors.

    Sample ``i`` is the tuple formed by indexing every wrapped tensor at ``i``
    along its first dimension.

    Args:
        *tensors (Tensor): tensors that all have the same size in the first
            dimension.
    """

    tensors: tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        # Every tensor must agree with the first one on the leading dimension.
        # With no tensors at all the check is vacuously true.
        expected = tensors[0].size(0) if tensors else None
        assert all(t.size(0) == expected for t in tensors), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        # One entry per wrapped tensor, in construction order.
        return tuple(t[index] for t in self.tensors)

    def __len__(self):
        # All tensors share the leading size, so the first one is authoritative.
        return self.tensors[0].shape[0]
[docs]class StackDataset(Dataset[_T_stack]): r"""Dataset as a stacking of multiple datasets. This class is useful to assemble different parts of complex input data, given as datasets. Example: >>> # xdoctest: +SKIP >>> images = ImageDataset() >>> texts = TextDataset() >>> tuple_stack = StackDataset(images, texts) >>> tuple_stack[0] == (images[0], texts[0]) >>> dict_stack = StackDataset(image=images, text=texts) >>> dict_stack[0] == {'image': images[0], 'text': texts[0]} Args: *args (Dataset): Datasets for stacking returned as tuple. **kwargs (Dataset): Datasets for stacking returned as dict. """ datasets: Union[tuple, dict] def __init__(self, *args: Dataset[_T_co], **kwargs: Dataset[_T_co]) -> None: if args: if kwargs: raise ValueError( "Supported either ``tuple``- (via ``args``) or" "``dict``- (via ``kwargs``) like input/output, but both types are given." ) self._length = len(args[0]) # type: ignore[arg-type] if any(self._length != len(dataset) for dataset in args): # type: ignore[arg-type] raise ValueError("Size mismatch between datasets") self.datasets = args elif kwargs: tmp = list(kwargs.values()) self._length = len(tmp[0]) # type: ignore[arg-type] if any(self._length != len(dataset) for dataset in tmp): # type: ignore[arg-type] raise ValueError("Size mismatch between datasets") self.datasets = kwargs else: raise ValueError("At least one dataset should be passed") def __getitem__(self, index): if isinstance(self.datasets, dict): return {k: dataset[index] for k, dataset in self.datasets.items()} return tuple(dataset[index] for dataset in self.datasets) def __getitems__(self, indices: list): # add batched sampling support when parent datasets supports it. 
if isinstance(self.datasets, dict): dict_batch: list[_T_dict] = [{} for _ in indices] for k, dataset in self.datasets.items(): if callable(getattr(dataset, "__getitems__", None)): items = dataset.__getitems__(indices) # type: ignore[attr-defined] if len(items) != len(indices): raise ValueError( "Nested dataset's output size mismatch." f" Expected {len(indices)}, got {len(items)}" ) for data, d_sample in zip(items, dict_batch): d_sample[k] = data else: for idx, d_sample in zip(indices, dict_batch): d_sample[k] = dataset[idx] return dict_batch # tuple data list_batch: list[list] = [[] for _ in indices] for dataset in self.datasets: if callable(getattr(dataset, "__getitems__", None)): items = dataset.__getitems__(indices) # type: ignore[attr-defined] if len(items) != len(indices): raise ValueError( "Nested dataset's output size mismatch." f" Expected {len(indices)}, got {len(items)}" ) for data, t_sample in zip(items, list_batch): t_sample.append(data) else: for idx, t_sample in zip(indices, list_batch): t_sample.append(dataset[idx]) tuple_batch: list[_T_tuple] = [tuple(sample) for sample in list_batch] return tuple_batch def __len__(self): return self._length
class ConcatDataset(Dataset[_T_co]):
    r"""Dataset built by concatenating several map-style datasets end to end.

    Useful for assembling one dataset out of several existing ones.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """

    datasets: list[Dataset[_T_co]]
    # Running totals of member lengths: entry ``k`` is the number of samples
    # in ``datasets[0..k]`` inclusive.
    cumulative_sizes: list[int]

    @staticmethod
    def cumsum(sequence):
        # e.g. members of lengths [2, 0, 3] -> [2, 2, 5]
        return list(itertools.accumulate(len(part) for part in sequence))

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, "datasets should not be an empty iterable"  # type: ignore[arg-type]
        for member in self.datasets:
            assert not isinstance(member, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Negative indices count from the end of the concatenation.
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # bisect_right picks the first member whose running total exceeds idx,
        # which also skips over empty members.
        owner = bisect.bisect_right(self.cumulative_sizes, idx)
        local_idx = idx if owner == 0 else idx - self.cumulative_sizes[owner - 1]
        return self.datasets[owner][local_idx]

    @property
    @deprecated(
        "`cummulative_sizes` attribute is renamed to `cumulative_sizes`",
        category=FutureWarning,
    )
    def cummulative_sizes(self):
        return self.cumulative_sizes
class ChainDataset(IterableDataset):
    r"""Dataset that chains several :class:`IterableDataset` s one after another.

    Handy for assembling existing dataset streams. The datasets are chained
    on the fly, so concatenating large-scale datasets this way stays cheap.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        # Stored as given (not materialized), so chaining stays lazy.
        self.datasets = datasets

    @staticmethod
    def _require_iterable(member):
        # Only iterable-style datasets can take part in the chain.
        assert isinstance(member, IterableDataset), "ChainDataset only supports IterableDataset"
        return member

    def __iter__(self):
        for member in self.datasets:
            yield from self._require_iterable(member)

    def __len__(self):
        # Sum of the member lengths; a member without __len__ raises TypeError.
        return sum(len(self._require_iterable(member)) for member in self.datasets)  # type: ignore[arg-type]
[docs]class Subset(Dataset[_T_co]): r""" Subset of a dataset at specified indices. Args: dataset (Dataset): The whole Dataset indices (sequence): Indices in the whole set selected for subset """ dataset: Dataset[_T_co] indices: Sequence[int] def __init__(self, dataset: Dataset[_T_co], indices: Sequence[int]) -> None: self.dataset = dataset self.indices = indices def __getitem__(self, idx): if isinstance(idx, list): return self.dataset[[self.indices[i] for i in idx]] return self.dataset[self.indices[idx]] def __getitems__(self, indices: list[int]) -> list[_T_co]: # add batched sampling support when parent dataset supports it. # see torch.utils.data._utils.fetch._MapDatasetFetcher if callable(getattr(self.dataset, "__getitems__", None)): return self.dataset.__getitems__([self.indices[idx] for idx in indices]) # type: ignore[attr-defined] else: return [self.dataset[self.indices[idx]] for idx in indices] def __len__(self): return len(self.indices)
[docs]def random_split( dataset: Dataset[_T], lengths: Sequence[Union[int, float]], generator: Optional[Generator] = default_generator, ) -> list[Subset[_T]]: r""" Randomly split a dataset into non-overlapping new datasets of given lengths. If a list of fractions that sum up to 1 is given, the lengths will be computed automatically as floor(frac * len(dataset)) for each fraction provided. After computing the lengths, if there are any remainders, 1 count will be distributed in round-robin fashion to the lengths until there are no remainders left. Optionally fix the generator for reproducible results, e.g.: Example: >>> # xdoctest: +SKIP >>> generator1 = torch.Generator().manual_seed(42) >>> generator2 = torch.Generator().manual_seed(42) >>> random_split(range(10), [3, 7], generator=generator1) >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2) Args: dataset (Dataset): Dataset to be split lengths (sequence): lengths or fractions of splits to be produced generator (Generator): Generator used for the random permutation. """ if math.isclose(sum(lengths), 1) and sum(lengths) <= 1: subset_lengths: list[int] = [] for i, frac in enumerate(lengths): if frac < 0 or frac > 1: raise ValueError(f"Fraction at index {i} is not between 0 and 1") n_items_in_split = int( math.floor(len(dataset) * frac) # type: ignore[arg-type] ) subset_lengths.append(n_items_in_split) remainder = len(dataset) - sum(subset_lengths) # type: ignore[arg-type] # add 1 to all the lengths in round-robin fashion until the remainder is 0 for i in range(remainder): idx_to_add_at = i % len(subset_lengths) subset_lengths[idx_to_add_at] += 1 lengths = subset_lengths for i, length in enumerate(lengths): if length == 0: warnings.warn( f"Length of split at index {i} is 0. " f"This might result in an empty dataset." ) # Cannot verify that dataset is Sized if sum(lengths) != len(dataset): # type: ignore[arg-type] raise ValueError( "Sum of input lengths does not equal the length of the input dataset!" 
) indices = randperm(sum(lengths), generator=generator).tolist() # type: ignore[arg-type, call-overload] lengths = cast(Sequence[int], lengths) return [ Subset(dataset, indices[offset - length : offset]) for offset, length in zip(itertools.accumulate(lengths), lengths) ]

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy