Shortcuts

Source code for torch.backends.cudnn

# mypy: allow-untyped-defs
import os
import sys
import warnings
from contextlib import contextmanager
from typing import Optional

import torch
from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule


try:
    from torch._C import _cudnn
except ImportError:
    _cudnn = None  # type: ignore[assignment]

# To globally disable cuDNN/MIOpen, write:
#
#   torch.backends.cudnn.enabled = False

# Integer cuDNN version as reported by the bindings; None until _init() has
# queried it for the first time.
__cudnn_version: Optional[int] = None

if _cudnn is None:

    def _init():
        """Report that cuDNN cannot be initialized.

        This variant is selected when the ``_cudnn`` bindings are ``None``
        (the import from ``torch._C`` failed), so it always returns ``False``.
        """
        return False

else:

    def _init():
        """Query the cuDNN version once and validate runtime/compile compatibility.

        Returns:
            ``True`` when the version has already been cached, when the
            runtime library is compatible with the one PyTorch was compiled
            against, or when ``PYTORCH_SKIP_CUDNN_COMPATIBILITY_CHECK`` is
            set to ``"1"``.

        Raises:
            RuntimeError: if the runtime and compile-time versions are
                incompatible and the check has not been skipped.
        """
        global __cudnn_version

        # Fast path: an earlier call already queried the version.
        if __cudnn_version is not None:
            return True

        # NOTE(review): the version is cached *before* the compatibility
        # check below, so if that check raises, later calls take the fast
        # path above and return True without re-checking.
        __cudnn_version = _cudnn.getVersionInt()
        runtime_version = _cudnn.getRuntimeVersion()
        compile_version = _cudnn.getCompileVersion()
        rt_major, rt_minor, _ = runtime_version
        built_major, built_minor, _ = compile_version

        if rt_major != built_major:
            # A major-version mismatch is never compatible.
            compatible = False
        else:
            # From cuDNN 7 on, minor versions are backwards-compatible.
            # Older cuDNN, and non-CUDA builds (MIOpen/ROCm, where the
            # compatibility story is unclear), get a strict equality check.
            needs_exact_minor = rt_major < 7 or not _cudnn.is_cuda
            compatible = (
                rt_minor == built_minor
                if needs_exact_minor
                else rt_minor >= built_minor
            )

        if compatible:
            return True

        # Escape hatch: let the user opt out of the check entirely.
        if os.environ.get("PYTORCH_SKIP_CUDNN_COMPATIBILITY_CHECK", "0") == "1":
            return True

        base_error_msg = (
            f"cuDNN version incompatibility: "
            f"PyTorch was compiled  against {compile_version} "
            f"but found runtime version {runtime_version}. "
            f"PyTorch already comes bundled with cuDNN. "
            f"One option to resolving this error is to ensure PyTorch "
            f"can find the bundled cuDNN. "
        )

        ld_library_path = os.environ.get("LD_LIBRARY_PATH")
        if ld_library_path is None:
            # No LD_LIBRARY_PATH at all: nothing more specific to suggest.
            raise RuntimeError(base_error_msg)

        if "cuda" in ld_library_path or "cudnn" in ld_library_path:
            raise RuntimeError(
                f"{base_error_msg}"
                f"Looks like your LD_LIBRARY_PATH contains incompatible version of cudnn. "
                f"Please either remove it from the path or install cudnn {compile_version}"
            )

        raise RuntimeError(
            f"{base_error_msg}"
            f"one possibility is that there is a "
            f"conflicting cuDNN in LD_LIBRARY_PATH."
        )


def version():
    """Return the version of cuDNN.

    Returns ``None`` when ``_init()`` reports that cuDNN could not be
    initialized; otherwise returns the module-level cached version value.
    Any exception raised by ``_init()`` propagates to the caller.
    """
    if not _init():
        return None
    return __cudnn_version
# Tensor dtypes that is_acceptable() treats as eligible for cuDNN
# (i.e. torch.half, torch.float and torch.double).
CUDNN_TENSOR_DTYPES = {
    torch.float16,
    torch.float32,
    torch.float64,
}
[docs]def is_available(): r"""Return a bool indicating if CUDNN is currently available.""" return torch._C._has_cudnn
def is_acceptable(tensor):
    """Return whether cuDNN can be used for ``tensor``.

    Returns ``False`` when any of the following holds, checked in order:

    * cuDNN is disabled via the global enabled flag;
    * the tensor is not on a ``"cuda"`` device or its dtype is not in
      ``CUDNN_TENSOR_DTYPES``;
    * ``is_available()`` is false (a warning is emitted);
    * ``_init()`` reports failure (a warning is emitted).

    Otherwise returns ``True``. Exceptions raised by ``_init()`` propagate.
    """
    if not torch._C._get_cudnn_enabled():
        return False
    if tensor.device.type != "cuda" or tensor.dtype not in CUDNN_TENSOR_DTYPES:
        return False
    if not is_available():
        warnings.warn(
            "PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild "
            "PyTorch making sure the library is visible to the build system."
        )
        return False
    if not _init():
        # Name the library search-path variable that applies to this platform.
        warnings.warn(
            "cuDNN/MIOpen library not found. Check your {libpath}".format(
                libpath={"darwin": "DYLD_LIBRARY_PATH", "win32": "PATH"}.get(
                    sys.platform, "LD_LIBRARY_PATH"
                )
            )
        )
        return False
    return True


def set_flags(
    _enabled=None,
    _benchmark=None,
    _benchmark_limit=None,
    _deterministic=None,
    _allow_tf32=None,
):
    """Set the global cuDNN flags and return their previous values.

    Each argument left as ``None`` leaves the corresponding flag unchanged.
    The benchmark limit is only read and written when ``is_available()`` is
    true.

    Returns:
        A tuple ``(enabled, benchmark, benchmark_limit, deterministic,
        allow_tf32)`` holding the values in effect before this call;
        ``benchmark_limit`` is ``None`` when ``is_available()`` is false.
    """
    # Snapshot the current state first so callers can restore it later.
    orig_flags = (
        torch._C._get_cudnn_enabled(),
        torch._C._get_cudnn_benchmark(),
        None if not is_available() else torch._C._cuda_get_cudnn_benchmark_limit(),
        torch._C._get_cudnn_deterministic(),
        torch._C._get_cudnn_allow_tf32(),
    )
    if _enabled is not None:
        torch._C._set_cudnn_enabled(_enabled)
    if _benchmark is not None:
        torch._C._set_cudnn_benchmark(_benchmark)
    if _benchmark_limit is not None and is_available():
        torch._C._cuda_set_cudnn_benchmark_limit(_benchmark_limit)
    if _deterministic is not None:
        torch._C._set_cudnn_deterministic(_deterministic)
    if _allow_tf32 is not None:
        torch._C._set_cudnn_allow_tf32(_allow_tf32)
    return orig_flags


@contextmanager
def flags(
    enabled=False,
    benchmark=False,
    benchmark_limit=10,
    deterministic=False,
    allow_tf32=True,
):
    """Context manager that applies cuDNN flags for the duration of a block.

    On entry the given values are applied through ``set_flags``; on exit
    (including when the body raises) the values that were in effect before
    entry are restored.
    """
    with __allow_nonbracketed_mutation():
        orig_flags = set_flags(
            enabled, benchmark, benchmark_limit, deterministic, allow_tf32
        )
    try:
        yield
    finally:
        # recover the previous values
        with __allow_nonbracketed_mutation():
            set_flags(*orig_flags)


# The magic here is to allow us to intercept code like this:
#
#   torch.backends.<cudnn|mkldnn>.enabled = True


class CudnnModule(PropModule):
    """Module stand-in whose attributes are backed by the global cuDNN flags.

    Each flag is declared as a ``ContextProp`` pairing the matching
    ``torch._C`` getter and setter.
    """

    def __init__(self, m, name):
        super().__init__(m, name)

    enabled = ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled)
    deterministic = ContextProp(
        torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic
    )
    benchmark = ContextProp(
        torch._C._get_cudnn_benchmark, torch._C._set_cudnn_benchmark
    )
    # The benchmark-limit accessors are only wired up when cuDNN is
    # available; otherwise the attribute is a plain None.
    benchmark_limit = None
    if is_available():
        benchmark_limit = ContextProp(
            torch._C._cuda_get_cudnn_benchmark_limit,
            torch._C._cuda_set_cudnn_benchmark_limit,
        )
    allow_tf32 = ContextProp(
        torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32
    )


# This is the sys.modules replacement trick, see
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = CudnnModule(sys.modules[__name__], __name__)

# Add type annotation for the replaced module
enabled: bool
deterministic: bool
benchmark: bool
allow_tf32: bool
benchmark_limit: int

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy