Shortcuts

Source code for torch.backends.opt_einsum

# mypy: allow-untyped-defs
import sys
import warnings
from contextlib import contextmanager
from functools import lru_cache as _lru_cache
from typing import Any

from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule


try:
    import opt_einsum as _opt_einsum  # type: ignore[import]
except ImportError:
    _opt_einsum = None


@_lru_cache
def is_available() -> bool:
    r"""Tell whether the optional ``opt_einsum`` package is usable by torch.

    torch only optimizes einsum contraction order automatically when
    opt-einsum is installed. It can be installed together with torch via
    ``pip install torch[opt-einsum]`` or on its own via
    ``pip install opt-einsum``. When the package is present, torch imports it
    automatically and uses it accordingly; call this function to check that
    the import actually succeeded.

    Returns:
        ``True`` if the module-level ``_opt_einsum`` handle is not ``None``,
        ``False`` otherwise. The answer is memoized by ``lru_cache``.
    """
    # ``_opt_einsum`` is bound once at import time by the guarded import at
    # the top of this module; ``None`` means the import failed.
    if _opt_einsum is None:
        return False
    return True
def get_opt_einsum() -> Any:
    r"""Hand back the imported ``opt_einsum`` module, or ``None`` when it is unavailable.

    Returns:
        The module-level ``_opt_einsum`` handle exactly as it was bound by the
        guarded import at the top of this module.
    """
    return _opt_einsum
# Strategy names accepted by ``_set_strategy``.
_VALID_STRATEGIES = ("auto", "greedy", "optimal")


def _set_enabled(_enabled: bool) -> None:
    """Validate and store the module-level ``enabled`` flag.

    Args:
        _enabled: the new value for ``enabled``.

    Raises:
        ValueError: if ``_enabled`` is truthy while opt_einsum is not available.
    """
    if not is_available() and _enabled:
        raise ValueError(
            f"opt_einsum is not available, so setting `enabled` to {_enabled} will not reap "
            "the benefits of calculating an optimal path for einsum. torch.einsum will "
            "fall back to contracting from left to right. To enable this optimal path "
            "calculation, please install opt-einsum."
        )
    global enabled
    enabled = _enabled


def _get_enabled() -> bool:
    """Return the current module-level ``enabled`` flag."""
    return enabled


def _set_strategy(_strategy: str) -> None:
    """Validate and store the module-level ``strategy``.

    Args:
        _strategy: one of ``"auto"``, ``"greedy"`` or ``"optimal"``.

    Raises:
        ValueError: if opt_einsum is not available, if ``enabled`` is falsy,
            or if ``_strategy`` is not one of the accepted names.
    """
    if not is_available():
        raise ValueError(
            f"opt_einsum is not available, so setting `strategy` to {_strategy} will not be meaningful. "
            "torch.einsum will bypass path calculation and simply contract from left to right. "
            "Please install opt_einsum or unset `strategy`."
        )
    if not enabled:
        raise ValueError(
            f"opt_einsum is not enabled, so setting a `strategy` to {_strategy} will not be meaningful. "
            "torch.einsum will bypass path calculation and simply contract from left to right. "
            "Please set `enabled` to `True` as well or unset `strategy`."
        )
    if _strategy not in _VALID_STRATEGIES:
        raise ValueError(
            f"`strategy` must be one of the following: [auto, greedy, optimal] but is {_strategy}"
        )
    global strategy
    strategy = _strategy


def _get_strategy() -> str:
    """Return the current module-level ``strategy``."""
    return strategy


def _snapshot_flags():
    """Return ``(enabled, strategy)`` as currently stored at module level.

    ``strategy`` is reported as ``None`` when opt_einsum is not available.
    """
    return (enabled, None if not is_available() else strategy)


def _restore_flags(_enabled, _strategy) -> None:
    """Write back a snapshot taken by ``_snapshot_flags`` without validation.

    The validating setters cannot be used for restoration: ``_set_strategy``
    raises whenever ``enabled`` is falsy, so a snapshot such as
    ``(False, "auto")`` could never be re-applied through them. ``None``
    entries are skipped, mirroring how ``set_flags`` treats ``None``.
    """
    global enabled, strategy
    if _enabled is not None:
        enabled = _enabled
    if _strategy is not None:
        strategy = _strategy


def set_flags(_enabled=None, _strategy=None):
    """Apply new flag values and return the previous ones.

    Args:
        _enabled: new value for ``enabled``; ``None`` leaves it untouched.
        _strategy: new value for ``strategy``; ``None`` leaves it untouched.

    Returns:
        A tuple ``(enabled, strategy)`` captured before any change was made;
        ``strategy`` is ``None`` when opt_einsum is not available.

    Raises:
        ValueError: propagated from ``_set_enabled`` / ``_set_strategy``.
            ``enabled`` is applied first, so it may already have been changed
            when ``_set_strategy`` raises.
    """
    orig_flags = _snapshot_flags()
    if _enabled is not None:
        _set_enabled(_enabled)
    if _strategy is not None:
        _set_strategy(_strategy)
    return orig_flags


@contextmanager
def flags(enabled=None, strategy=None):
    """Context manager that temporarily overrides ``enabled`` and/or ``strategy``.

    Args:
        enabled: value for ``enabled`` inside the block; ``None`` keeps the
            current value.
        strategy: value for ``strategy`` inside the block; ``None`` keeps the
            current value.

    Raises:
        ValueError: if the requested values are rejected by the setters. In
            that case the flags are put back to their previous values before
            the error propagates.
    """
    # Snapshot first so that a partially applied update (``enabled`` changed,
    # then ``strategy`` rejected) is still rolled back by the ``finally``.
    orig_flags = _snapshot_flags()
    try:
        with __allow_nonbracketed_mutation():
            set_flags(enabled, strategy)
        yield
    finally:
        # recover the previous values; done by direct assignment because the
        # validating setters refuse to restore a strategy while the restored
        # ``enabled`` is False, which made the exit path raise ValueError.
        with __allow_nonbracketed_mutation():
            _restore_flags(*orig_flags)


# The magic here is to allow us to intercept code like this:
#
#   torch.backends.opt_einsum.enabled = True


class OptEinsumModule(PropModule):
    """Module wrapper installed into ``sys.modules`` in place of this module."""

    def __init__(self, m, name):
        super().__init__(m, name)

    # NOTE(review): layout reconstructed from a flattened source. With the
    # ``global`` declarations these assignments bind module-level names, not
    # class attributes, and the plain assignments at the bottom of this module
    # rebind the same names afterwards -- confirm this is intended.
    global enabled
    enabled = ContextProp(_get_enabled, _set_enabled)
    global strategy
    strategy = None
    if is_available():
        strategy = ContextProp(_get_strategy, _set_strategy)


# This is the sys.modules replacement trick, see
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = OptEinsumModule(sys.modules[__name__], __name__)

# Defaults: optimization is on with the "auto" strategy only when opt_einsum
# could be imported.
enabled = is_available()
strategy = "auto" if is_available() else None

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy