_internal_testing.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. import os
  2. import re
  3. import numpy as np
  4. import torch
  5. import triton
  6. import triton.language as tl
  7. from triton import knobs
  8. from typing import Optional, Set, Union
  9. import pytest
  10. from numpy.random import RandomState
  11. from triton.runtime.jit import TensorWrapper, reinterpret, type_canonicalisation_dict
  12. int_dtypes = ['int8', 'int16', 'int32', 'int64']
  13. uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
  14. integral_dtypes = int_dtypes + uint_dtypes
  15. float_dtypes = ['float16', 'float32', 'float64']
  16. float_dtypes_with_bfloat16 = float_dtypes + ['bfloat16']
  17. dtypes = integral_dtypes + float_dtypes
  18. dtypes_with_bfloat16 = dtypes + ['bfloat16']
  19. torch_float8_dtypes = ['float8_e4m3fn', 'float8_e5m2']
  20. torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16']
  21. tma_dtypes = sorted(set(dtypes_with_bfloat16) - {"int64", "uint64", "float64"})
  22. def is_interpreter():
  23. return os.environ.get('TRITON_INTERPRET', '0') == '1'
  24. def get_current_target():
  25. if is_interpreter():
  26. return None
  27. return triton.runtime.driver.active.get_current_target()
  28. def is_cuda():
  29. target = get_current_target()
  30. return False if target is None else target.backend == "cuda"
  31. def is_ampere_or_newer():
  32. return is_cuda() and torch.cuda.get_device_capability()[0] >= 8
  33. def is_blackwell():
  34. return is_cuda() and torch.cuda.get_device_capability()[0] == 10
  35. def is_hopper_or_newer():
  36. return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
  37. def is_hopper():
  38. return is_cuda() and torch.cuda.get_device_capability()[0] == 9
  39. def is_sm12x():
  40. return is_cuda() and torch.cuda.get_device_capability()[0] == 12
  41. def is_hip():
  42. target = get_current_target()
  43. return False if target is None else target.backend == "hip"
  44. def is_hip_cdna2():
  45. target = get_current_target()
  46. return target is not None and target.backend == 'hip' and target.arch == 'gfx90a'
  47. def is_hip_cdna3():
  48. target = get_current_target()
  49. return target is not None and target.backend == 'hip' and target.arch == 'gfx942'
  50. def is_hip_cdna4():
  51. target = get_current_target()
  52. return target is not None and target.backend == 'hip' and target.arch == 'gfx950'
  53. def is_hip_gfx11():
  54. target = get_current_target()
  55. return target is not None and target.backend == 'hip' and 'gfx11' in target.arch
  56. def is_hip_gfx12():
  57. target = get_current_target()
  58. return target is not None and target.backend == 'hip' and 'gfx12' in target.arch
  59. def is_hip_gfx1250():
  60. target = get_current_target()
  61. return target is not None and target.backend == 'hip' and 'gfx1250' in target.arch
  62. def is_hip_cdna():
  63. return is_hip_cdna2() or is_hip_cdna3() or is_hip_cdna4()
  64. def get_hip_lds_size():
  65. return 163840 if is_hip_cdna4() else 65536
  66. def is_xpu():
  67. target = get_current_target()
  68. return False if target is None else target.backend == "xpu"
  69. def get_arch():
  70. target = get_current_target()
  71. return "" if target is None else str(target.arch)
  72. def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None):
  73. """
  74. Override `rs` if you're calling this function twice and don't want the same
  75. result for both calls.
  76. """
  77. if isinstance(shape, int):
  78. shape = (shape, )
  79. if rs is None:
  80. rs = RandomState(seed=17)
  81. if dtype_str in int_dtypes + uint_dtypes:
  82. iinfo = np.iinfo(getattr(np, dtype_str))
  83. low = iinfo.min if low is None else max(low, iinfo.min)
  84. high = iinfo.max if high is None else min(high, iinfo.max)
  85. dtype = getattr(np, dtype_str)
  86. x = rs.randint(low, high, shape, dtype=dtype)
  87. x[x == 0] = 1 # Workaround. Never return zero so tests of division don't error out.
  88. return x
  89. elif dtype_str and 'float8' in dtype_str:
  90. x = rs.randint(20, 40, shape, dtype=np.int8)
  91. return x
  92. elif dtype_str in float_dtypes:
  93. return rs.normal(0, 1, shape).astype(dtype_str)
  94. elif dtype_str == 'bfloat16':
  95. return (rs.normal(0, 1, shape).astype('float32').view('uint32') & np.uint32(0xffff0000)).view('float32')
  96. elif dtype_str in ['bool', 'int1', 'bool_']:
  97. return rs.normal(0, 1, shape) > 0.0
  98. else:
  99. raise RuntimeError(f'Unknown dtype {dtype_str}')
  100. def to_triton(x: np.ndarray, device, dst_type=None) -> Union[TensorWrapper, torch.Tensor]:
  101. '''
  102. Note: We need dst_type because the type of x can be different from dst_type.
  103. For example: x is of type `float32`, dst_type is `bfloat16`.
  104. If dst_type is None, we infer dst_type from x.
  105. '''
  106. t = x.dtype.name
  107. if t in uint_dtypes:
  108. signed_type_name = t.lstrip('u') # e.g. "uint16" -> "int16"
  109. x_signed = x.astype(getattr(np, signed_type_name))
  110. return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t))
  111. else:
  112. if dst_type and 'float8' in dst_type:
  113. return reinterpret(torch.tensor(x, device=device), getattr(tl, dst_type))
  114. if t == 'float32' and dst_type == 'bfloat16':
  115. return torch.tensor(x, device=device).bfloat16()
  116. return torch.tensor(x, device=device)
  117. def str_to_triton_dtype(x: str) -> tl.dtype:
  118. return tl.str_to_ty(type_canonicalisation_dict[x], None)
  119. def torch_dtype_name(dtype) -> str:
  120. if isinstance(dtype, triton.language.dtype):
  121. return dtype.name
  122. elif isinstance(dtype, torch.dtype):
  123. # 'torch.int64' -> 'int64'
  124. m = re.match(r'^torch\.(\w+)$', str(dtype))
  125. return m.group(1)
  126. else:
  127. raise TypeError(f'not a triton or torch dtype: {type(dtype)}')
  128. def to_numpy(x):
  129. if isinstance(x, TensorWrapper):
  130. return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype)))
  131. elif isinstance(x, torch.Tensor):
  132. if x.dtype is torch.bfloat16:
  133. return x.cpu().float().numpy()
  134. return x.cpu().numpy()
  135. else:
  136. raise ValueError(f"Not a triton-compatible tensor: {x}")
  137. def supports_tma(byval_only=False):
  138. if is_interpreter():
  139. return True
  140. if not is_cuda():
  141. return False
  142. cuda_version = knobs.nvidia.ptxas.version
  143. min_cuda_version = (12, 0) if byval_only else (12, 3)
  144. cuda_version_tuple = tuple(map(int, cuda_version.split(".")))
  145. assert len(cuda_version_tuple) == 2, cuda_version_tuple
  146. return torch.cuda.get_device_capability()[0] >= 9 and cuda_version_tuple >= min_cuda_version
  147. def supports_ws():
  148. if is_interpreter():
  149. return True
  150. if not is_cuda():
  151. return False
  152. return torch.cuda.get_device_capability()[0] >= 9
  153. def tma_skip_msg(byval_only=False):
  154. if byval_only:
  155. return "Requires __grid_constant__ TMA support (NVIDIA Hopper or higher, CUDA 12.0 or higher)"
  156. else:
  157. return "Requires advanced TMA support (NVIDIA Hopper or higher, CUDA 12.3 or higher)"
  158. requires_tma = pytest.mark.skipif(not supports_tma(), reason=tma_skip_msg())
  159. def default_alloc_fn(size: int, align: int, _):
  160. return torch.empty(size, dtype=torch.int8, device="cuda")
  161. def unwrap_tensor(t: Union[torch.Tensor, triton.runtime.jit.TensorWrapper]) -> torch.Tensor:
  162. if isinstance(t, triton.runtime.jit.TensorWrapper):
  163. return t.base
  164. return t
def _fresh_knobs_impl(skipped_attr: Optional[Set[str]] = None):
    """Create callbacks that swap ``triton.knobs`` knob sets for fresh copies and back.

    Args:
        skipped_attr: Attribute names on ``triton.knobs`` to leave untouched.
            ``None`` means nothing is skipped.

    Returns:
        A ``(fresh_function, reset_function)`` pair:

        * ``fresh_function()`` replaces each collected knob set on ``knobs`` with
          ``knobset.copy().reset()``. It deletes (via a ``pytest.MonkeyPatch``)
          each knob environment variable that is currently set, and remembers the
          ones that are not set. It then sets ``knobs.propagate_env = True`` and
          returns the ``knobs`` module.
        * ``reset_function()`` reinstalls the original knob-set objects and undoes
          the monkeypatched environment deletions. It removes any remembered
          (previously unset) knob environment variables that are now set, and
          restores the previous ``knobs.propagate_env`` value.
    """
    from triton import knobs
    if skipped_attr is None:
        skipped_attr = set()
    # Records the environment deletions so reset_function can undo them.
    monkeypatch = pytest.MonkeyPatch()
    # Attribute name -> original knob-set object, for every attribute of the
    # `knobs` module that is a `base_knobs` instance and not in `skipped_attr`.
    # NOTE(review): `knobset != knobs.base_knobs` compares an instance with the
    # class itself and looks redundant after the isinstance check; kept as-is.
    knobs_map = {
        name: knobset
        for name, knobset in knobs.__dict__.items()
        if isinstance(knobset, knobs.base_knobs) and knobset != knobs.base_knobs and name not in skipped_attr
    }
    # We store which variables we need to unset later (in reset_function) because
    # monkeypatch doesn't appear to reset variables that were never set
    # before the monkeypatch.delenv call below.
    env_to_unset = []
    prev_propagate_env = knobs.propagate_env

    def fresh_function():
        # `env_to_unset` is only appended to, never rebound, so this `nonlocal`
        # is not strictly required.
        nonlocal env_to_unset
        for name, knobset in knobs_map.items():
            # Install a reset copy; the original object stays in `knobs_map`.
            setattr(knobs, name, knobset.copy().reset())
            for knob in knobset.knob_descriptors.values():
                if knob.key in os.environ:
                    monkeypatch.delenv(knob.key, raising=False)
                else:
                    # Not set now; reset_function deletes it if it becomes set later.
                    env_to_unset.append(knob.key)
        # NOTE(review): presumably makes knob assignments during the test write
        # through to os.environ; confirm against triton.knobs.
        knobs.propagate_env = True
        return knobs

    def reset_function():
        for name, knobset in knobs_map.items():
            setattr(knobs, name, knobset)
        # `undo` should be placed before `del os.environ`
        # Otherwise, it may restore environment variables that monkeypatch deleted
        monkeypatch.undo()
        for k in env_to_unset:
            if k in os.environ:
                del os.environ[k]
        knobs.propagate_env = prev_propagate_env

    return fresh_function, reset_function