| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204 |
- import functools
- import hashlib
- from typing import Any
@functools.cache
def has_triton_package() -> bool:
    """Report whether the ``triton`` package can be imported.

    The result is memoized by ``functools.cache``, so repeated calls reuse
    the first answer instead of retrying the import.

    Returns:
        ``True`` if ``import triton`` succeeds, ``False`` if it raises
        ``ImportError``.
    """
    try:
        import triton  # noqa: F401
    except ImportError:
        return False
    else:
        return True
@functools.cache
def get_triton_version(fallback: tuple[int, int] = (0, 0)) -> tuple[int, int]:
    """Return the installed Triton version as a ``(major, minor)`` tuple.

    Only the first two dot-separated components of ``triton.__version__``
    are considered; each is converted with ``int``.

    Args:
        fallback: Value returned when ``triton`` cannot be imported.

    Returns:
        ``(major, minor)`` parsed from ``triton.__version__``, or ``fallback``
        if importing ``triton`` raises ``ImportError``.

    Raises:
        ValueError: If the version string does not yield exactly two integer
            components (only ``ImportError`` is handled here).
    """
    try:
        import triton

        leading_parts = triton.__version__.split(".")[:2]
        major, minor = map(int, leading_parts)
    except ImportError:
        return fallback
    return (major, minor)
- @functools.cache
- def _device_supports_tma() -> bool:
- import torch
- return (
- torch.cuda.is_available()
- and torch.cuda.get_device_capability() >= (9, 0)
- and not torch.version.hip
- )
@functools.cache
def has_triton_experimental_host_tma() -> bool:
    """Report whether Triton's experimental host-side TMA descriptors are usable.

    Requires an importable ``triton`` package, a device that passes
    ``_device_supports_tma()``, and the ``create_1d_tma_descriptor`` /
    ``create_2d_tma_descriptor`` helpers in
    ``triton.tools.experimental_descriptor``.

    Returns:
        ``False`` if any prerequisite is missing. Otherwise the result of
        ``enable_in_pytorch()`` when that hook can be imported, or ``True``
        when importing (or calling) it raises ``ImportError``.
    """
    if not has_triton_package() or not _device_supports_tma():
        return False

    try:
        from triton.tools.experimental_descriptor import (  # noqa: F401
            create_1d_tma_descriptor,
            create_2d_tma_descriptor,
        )
    except ImportError:
        return False

    # The opt-in hook is absent in some Triton builds; without it the
    # descriptor helpers alone are taken as sufficient.
    try:
        from triton.tools.experimental_descriptor import enable_in_pytorch

        return enable_in_pytorch()
    except ImportError:
        return True
@functools.cache
def has_triton_tensor_descriptor_host_tma() -> bool:
    """Report whether Triton's ``TensorDescriptor`` host-side TMA API is usable.

    Returns:
        ``True`` when ``triton`` is importable, the device passes
        ``_device_supports_tma()``, and ``TensorDescriptor`` can be imported
        from ``triton.tools.tensor_descriptor``; ``False`` otherwise.
    """
    if not (has_triton_package() and _device_supports_tma()):
        return False

    try:
        from triton.tools.tensor_descriptor import (  # noqa: F401
            TensorDescriptor,
        )
    except ImportError:
        return False
    return True
@functools.cache
def has_triton_tma() -> bool:
    """Report whether any host-side TMA API is available.

    The ``TensorDescriptor`` API is probed first; the experimental
    descriptor API is only consulted when the former is unavailable.
    """
    if has_triton_tensor_descriptor_host_tma():
        return True
    return has_triton_experimental_host_tma()
@functools.cache
def has_triton_tma_device() -> bool:
    """Report whether Triton exposes a device-side TMA API on this machine.

    Prerequisites, checked in order:

    1. The ``triton`` package is importable.
    2. Either the CUDA device passes ``_device_supports_tma()`` or an XPU
       device is available.
    3. At least one of the two device-side APIs can be imported: the old
       ``experimental_device_tensormap_create1d/2d`` helpers from
       ``triton.language.extra.cuda``, or the new
       ``triton.language.make_tensor_descriptor``.

    Returns:
        ``True`` when all prerequisites hold, ``False`` otherwise.
    """
    if not has_triton_package():
        return False

    import torch

    # Reuse the shared CUDA gate rather than repeating the
    # is_available / capability / HIP checks inline.
    if not (_device_supports_tma() or torch.xpu.is_available()):
        return False

    # old API
    try:
        from triton.language.extra.cuda import (  # noqa: F401
            experimental_device_tensormap_create1d,
            experimental_device_tensormap_create2d,
        )
    except ImportError:
        pass
    else:
        return True

    # new API
    try:
        from triton.language import make_tensor_descriptor  # noqa: F401
    except ImportError:
        return False
    return True
@functools.cache
def has_datacenter_blackwell_tma_device() -> bool:
    """Report whether a capability-10.x CUDA device with full TMA support is present.

    The device gate requires CUDA to be available, a device capability in
    the half-open range ``[(10, 0), (11, 0))``, and a falsy
    ``torch.version.hip``. When the gate passes, both the device-side TMA
    API and the ``TensorDescriptor`` host-side API must be available.

    Returns:
        ``True`` only when the device gate and both API checks pass.
    """
    import torch

    if not torch.cuda.is_available():
        return False

    capability = torch.cuda.get_device_capability()
    if not ((10, 0) <= capability < (11, 0)):
        return False
    if torch.version.hip:
        return False

    return has_triton_tma_device() and has_triton_tensor_descriptor_host_tma()
@functools.cache
def has_triton_stable_tma_api() -> bool:
    """Report whether Triton's ``make_tensor_descriptor`` API is usable.

    Prerequisites, checked in order:

    1. The ``triton`` package is importable.
    2. Either the CUDA device passes ``_device_supports_tma()`` or an XPU
       device is available.
    3. ``make_tensor_descriptor`` can be imported from ``triton.language``.

    Returns:
        ``True`` when all prerequisites hold, ``False`` otherwise.
    """
    # NOTE: ``functools.cache`` is the same memoization as
    # ``functools.lru_cache(None)`` and matches the other helpers here.
    if not has_triton_package():
        return False

    import torch

    # Reuse the shared CUDA gate rather than repeating the
    # is_available / capability / HIP checks inline.
    if not (_device_supports_tma() or torch.xpu.is_available()):
        return False

    try:
        from triton.language import make_tensor_descriptor  # noqa: F401
    except ImportError:
        return False
    return True
@functools.cache
def has_triton() -> bool:
    """Report whether Triton is installed and at least one device can use it.

    Returns ``False`` immediately when the ``triton`` package is missing or
    when ``torch._inductor.config.triton_disable_device_detection`` is
    truthy. Otherwise each candidate device type is probed in the order
    ``cuda``, ``xpu``, ``cpu``, ``mtia``; the first one whose interface
    reports itself available and passes its extra check makes the result
    ``True``.
    """
    if not has_triton_package():
        return False

    from torch._inductor.config import triton_disable_device_detection

    if triton_disable_device_detection:
        return False

    from torch._dynamo.device_interface import get_interface_for_device

    def _cuda_check(iface: Any) -> bool:
        # CUDA additionally needs device properties reporting major >= 7.
        return iface.Worker.get_device_properties().major >= 7

    def _cpu_check(iface: Any) -> bool:
        # CPU counts only when Triton has a registered "cpu" backend.
        import triton.backends

        return "cpu" in triton.backends.backends

    def _no_extra_check(iface: Any) -> bool:
        return True

    # Ordered (device name, extra check) pairs; order determines probing order.
    candidates = (
        ("cuda", _cuda_check),
        ("xpu", _no_extra_check),
        ("cpu", _cpu_check),
        ("mtia", _no_extra_check),
    )

    def _usable(device_name: str, extra_check: Any) -> bool:
        iface = get_interface_for_device(device_name)
        return bool(iface.is_available() and extra_check(iface))

    return any(_usable(name, check) for name, check in candidates)
@functools.cache
def triton_backend() -> Any:
    """Build the Triton backend for the active driver's current target.

    The target is obtained from ``driver.active.get_current_target()`` and
    passed to ``make_backend``. The resulting backend object is memoized by
    ``functools.cache``.
    """
    from triton.compiler.compiler import make_backend
    from triton.runtime.driver import driver

    return make_backend(driver.active.get_current_target())
@functools.cache
def triton_hash_with_backend() -> str:
    """Return an upper-case SHA-256 hex digest identifying Triton and its backend.

    The digest is computed over the UTF-8 encoding of
    ``"<triton_key()>-<backend.hash()>"``, where ``backend`` comes from
    ``triton_backend()``.
    """
    from torch._inductor.runtime.triton_compat import triton_key

    backend = triton_backend()
    hasher = hashlib.sha256()
    hasher.update(f"{triton_key()}-{backend.hash()}".encode("utf-8"))
    # Hash is upper case so that it can't contain any Python keywords.
    return hasher.hexdigest().upper()
|