_triton.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. import functools
  2. import hashlib
  3. from typing import Any
  4. @functools.cache
  5. def has_triton_package() -> bool:
  6. try:
  7. import triton # noqa: F401
  8. return True
  9. except ImportError:
  10. return False
  11. @functools.cache
  12. def get_triton_version(fallback: tuple[int, int] = (0, 0)) -> tuple[int, int]:
  13. try:
  14. import triton
  15. major, minor = tuple(int(v) for v in triton.__version__.split(".")[:2])
  16. return (major, minor)
  17. except ImportError:
  18. return fallback
  19. @functools.cache
  20. def _device_supports_tma() -> bool:
  21. import torch
  22. return (
  23. torch.cuda.is_available()
  24. and torch.cuda.get_device_capability() >= (9, 0)
  25. and not torch.version.hip
  26. )
  27. @functools.cache
  28. def has_triton_experimental_host_tma() -> bool:
  29. if has_triton_package():
  30. if _device_supports_tma():
  31. try:
  32. from triton.tools.experimental_descriptor import ( # noqa: F401
  33. create_1d_tma_descriptor,
  34. create_2d_tma_descriptor,
  35. )
  36. try:
  37. from triton.tools.experimental_descriptor import enable_in_pytorch
  38. return enable_in_pytorch()
  39. except ImportError:
  40. return True
  41. except ImportError:
  42. pass
  43. return False
  44. @functools.cache
  45. def has_triton_tensor_descriptor_host_tma() -> bool:
  46. if has_triton_package():
  47. if _device_supports_tma():
  48. try:
  49. from triton.tools.tensor_descriptor import ( # noqa: F401
  50. TensorDescriptor,
  51. )
  52. return True
  53. except ImportError:
  54. pass
  55. return False
  56. @functools.cache
  57. def has_triton_tma() -> bool:
  58. return has_triton_tensor_descriptor_host_tma() or has_triton_experimental_host_tma()
  59. @functools.cache
  60. def has_triton_tma_device() -> bool:
  61. if has_triton_package():
  62. import torch
  63. if (
  64. torch.cuda.is_available()
  65. and torch.cuda.get_device_capability() >= (9, 0)
  66. and not torch.version.hip
  67. ) or torch.xpu.is_available():
  68. # old API
  69. try:
  70. from triton.language.extra.cuda import ( # noqa: F401
  71. experimental_device_tensormap_create1d,
  72. experimental_device_tensormap_create2d,
  73. )
  74. return True
  75. except ImportError:
  76. pass
  77. # new API
  78. try:
  79. from triton.language import make_tensor_descriptor # noqa: F401
  80. return True
  81. except ImportError:
  82. pass
  83. return False
  84. @functools.cache
  85. def has_datacenter_blackwell_tma_device() -> bool:
  86. import torch
  87. if (
  88. torch.cuda.is_available()
  89. and torch.cuda.get_device_capability() >= (10, 0)
  90. and torch.cuda.get_device_capability() < (11, 0)
  91. and not torch.version.hip
  92. ):
  93. return has_triton_tma_device() and has_triton_tensor_descriptor_host_tma()
  94. return False
  95. @functools.lru_cache(None)
  96. def has_triton_stable_tma_api() -> bool:
  97. if has_triton_package():
  98. import torch
  99. if (
  100. torch.cuda.is_available()
  101. and torch.cuda.get_device_capability() >= (9, 0)
  102. and not torch.version.hip
  103. ) or torch.xpu.is_available():
  104. try:
  105. from triton.language import make_tensor_descriptor # noqa: F401
  106. return True
  107. except ImportError:
  108. pass
  109. return False
  110. @functools.cache
  111. def has_triton() -> bool:
  112. if not has_triton_package():
  113. return False
  114. from torch._inductor.config import triton_disable_device_detection
  115. if triton_disable_device_detection:
  116. return False
  117. from torch._dynamo.device_interface import get_interface_for_device
  118. def cuda_extra_check(device_interface: Any) -> bool:
  119. return device_interface.Worker.get_device_properties().major >= 7
  120. def cpu_extra_check(device_interface: Any) -> bool:
  121. import triton.backends
  122. return "cpu" in triton.backends.backends
  123. def _return_true(device_interface: Any) -> bool:
  124. return True
  125. triton_supported_devices = {
  126. "cuda": cuda_extra_check,
  127. "xpu": _return_true,
  128. "cpu": cpu_extra_check,
  129. "mtia": _return_true,
  130. }
  131. def is_device_compatible_with_triton() -> bool:
  132. for device, extra_check in triton_supported_devices.items():
  133. device_interface = get_interface_for_device(device)
  134. if device_interface.is_available() and extra_check(device_interface):
  135. return True
  136. return False
  137. return is_device_compatible_with_triton()
  138. @functools.cache
  139. def triton_backend() -> Any:
  140. from triton.compiler.compiler import make_backend
  141. from triton.runtime.driver import driver
  142. target = driver.active.get_current_target()
  143. return make_backend(target)
  144. @functools.cache
  145. def triton_hash_with_backend() -> str:
  146. from torch._inductor.runtime.triton_compat import triton_key
  147. backend = triton_backend()
  148. key = f"{triton_key()}-{backend.hash()}"
  149. # Hash is upper case so that it can't contain any Python keywords.
  150. return hashlib.sha256(key.encode("utf-8")).hexdigest().upper()