# _utils.py (3.7 KB)
  1. from __future__ import annotations
  2. from functools import reduce
  3. from typing import Any, Callable, TYPE_CHECKING, Union, List, Dict
  4. if TYPE_CHECKING:
  5. from .language import core
  6. IterableType = Union[list[Any], tuple[Any, ...], core.tuple, core.tuple_type]
  7. ObjPath = tuple[int, ...]
  8. TRITON_MAX_TENSOR_NUMEL = 1048576
  9. def get_iterable_path(iterable: IterableType, path: ObjPath) -> Any:
  10. return reduce(lambda a, idx: a[idx], path, iterable) # type: ignore[index]
  11. def set_iterable_path(iterable: IterableType, path: tuple[int, ...], val: Any):
  12. from .language import core
  13. assert len(path) != 0
  14. prev = iterable if len(path) == 1 else get_iterable_path(iterable, path[:-1])
  15. assert isinstance(prev, core.tuple)
  16. prev._setitem(path[-1], val)
  17. def find_paths_if(iterable: Union[IterableType, Any], pred: Callable[[ObjPath, Any], bool]) -> list[ObjPath]:
  18. from .language import core
  19. is_iterable: Callable[[Any], bool] = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type))
  20. # We need to use dict so that ordering is maintained, while set doesn't guarantee order
  21. ret: dict[ObjPath, None] = {}
  22. def _impl(path: tuple[int, ...], current: Any):
  23. if is_iterable(current):
  24. for idx, item in enumerate(current):
  25. _impl((*path, idx), item)
  26. elif pred(path, current):
  27. ret[path] = None
  28. _impl((), iterable)
  29. return list(ret.keys())
  30. def is_power_of_two(x):
  31. return (x & (x - 1)) == 0
  32. def validate_block_shape(shape: List[int]):
  33. numel = 1
  34. for i, d in enumerate(shape):
  35. if not isinstance(d, int):
  36. raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]")
  37. if not is_power_of_two(d):
  38. raise ValueError(f"Shape element {i} must be a power of 2")
  39. numel *= d
  40. if numel > TRITON_MAX_TENSOR_NUMEL:
  41. raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})")
  42. return numel
  43. type_canonicalisation_dict = {
  44. # we canonicalise all bools to be unsigned:
  45. "bool": "u1",
  46. "int1": "u1",
  47. "uint1": "u1",
  48. "i1": "u1",
  49. # floating-point dtypes:
  50. "float8e4nv": "fp8e4nv",
  51. "float8e5": "fp8e5",
  52. "float8e4b15": "fp8e4b15",
  53. "float8_e4m3fn": "fp8e4nv",
  54. "float8e4b8": "fp8e4b8",
  55. "float8_e4m3fnuz": "fp8e4b8",
  56. "float8_e5m2": "fp8e5",
  57. "float8e5b16": "fp8e5b16",
  58. "float8_e5m2fnuz": "fp8e5b16",
  59. "half": "fp16",
  60. "float16": "fp16",
  61. "bfloat16": "bf16",
  62. "float": "fp32",
  63. "float32": "fp32",
  64. "double": "fp64",
  65. "float64": "fp64",
  66. # signed integers:
  67. "int8": "i8",
  68. "int16": "i16",
  69. "int": "i32",
  70. "int32": "i32",
  71. "int64": "i64",
  72. # unsigned integers:
  73. "uint8": "u8",
  74. "uint16": "u16",
  75. "uint32": "u32",
  76. "uint64": "u64",
  77. "void": "void",
  78. }
  79. for v in list(type_canonicalisation_dict.values()):
  80. type_canonicalisation_dict[v] = v
  81. def canonicalize_dtype(dtype):
  82. dtype_str = str(dtype).split(".")[-1]
  83. return type_canonicalisation_dict[dtype_str]
  84. def canonicalize_ptr_dtype(dtype, is_const):
  85. return f"{'*k' if is_const else '*'}{canonicalize_dtype(dtype)}"
  86. BITWIDTH_DICT: Dict[str, int] = {
  87. **{f"u{n}": n
  88. for n in (1, 8, 16, 32, 64)},
  89. **{f"i{n}": n
  90. for n in (1, 8, 16, 32, 64)},
  91. **{f"fp{n}": n
  92. for n in (16, 32, 64)},
  93. **{f"fp8{suffix}": 8
  94. for suffix in ("e4nv", "e4b15", "e4b8", "e5", "e5b16")},
  95. "bf16": 16,
  96. "void": 0,
  97. }
  98. for k, v in type_canonicalisation_dict.items():
  99. BITWIDTH_DICT[k] = BITWIDTH_DICT[v]
  100. def get_primitive_bitwidth(dtype: str) -> int:
  101. return BITWIDTH_DICT[dtype]
  102. def is_namedtuple(val):
  103. return isinstance(val, type) and issubclass(val, tuple) and hasattr(val, "_fields")