_semantic.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573
  1. from typing import Sequence, List, TypeVar, Tuple, Callable
  2. import math
  3. from triton.language.semantic import TritonSemantic
  4. from . import _core as ttgl
  5. from ._layouts import AutoLayout, DistributedLayout, DistributedLinearLayout, SliceLayout, SharedLayout, CoalescedLayout
  6. from triton._C.libtriton.gluon_ir import GluonOpBuilder, compute_tmem_reg_layout
  7. from triton.compiler.code_generator import flatten_values_to_ir, unflatten_ir_values
# Type variable for the tensor type handled by the semantic layer; it is the
# type parameter of the ``TritonSemantic`` base class of ``GluonSemantic`` and
# appears in that class's method annotations.
TensorTy = TypeVar("TensorTy")
  9. def _check(cond: bool, msg_fn: Callable[[], str], category=ValueError):
  10. if not cond:
  11. raise category(msg_fn())
  12. def _is_int_list(value):
  13. return isinstance(value, Sequence) and all(isinstance(i, int) for i in value)
  14. def _compute_tmem_reg_layout(element_ty, shape, layout, num_warps, instr_variant, cga_layout=None):
  15. _check(isinstance(instr_variant, str), lambda: "instr_variant must be a string")
  16. _check(instr_variant in ("32x32b", "16x64b", "16x128b", "16x256b", "16x32bx2", "32x32b_splitn"),
  17. lambda: f"unknown instr_variant: {instr_variant}")
  18. _check(isinstance(num_warps, int), lambda: f"num_warps must be an int but got {type(num_warps)!r}")
  19. _check(num_warps >= 4 and (num_warps & (num_warps - 1)) == 0, lambda: "num_warps must be a power of two and >= 4")
  20. shape = list(shape)
  21. _check(all(isinstance(dim, int) for dim in shape), lambda: f"shape entries must be ints but got {shape}")
  22. rank = len(shape)
  23. _check(rank == 2, lambda: "expected a 2D tensor")
  24. if cga_layout is None:
  25. cga_layout = []
  26. splitn = instr_variant == "32x32b_splitn"
  27. atom_variant = "32x32b" if splitn else instr_variant
  28. if cga_layout:
  29. for basis in cga_layout:
  30. _check(len(basis) == rank, lambda: "cga_layout basis rank mismatch")
  31. layout_obj = compute_tmem_reg_layout(
  32. element_ty,
  33. shape,
  34. layout,
  35. num_warps,
  36. atom_variant,
  37. cga_layout,
  38. )
  39. _check(layout_obj is not None,
  40. lambda: f"TMEM layout '{atom_variant}' unsupported for shape {shape} and num_warps {num_warps}")
  41. if splitn:
  42. N = shape[1]
  43. if not layout_obj.reg_bases:
  44. # We cannot use this layout in a load or a store ATM due to a PTX bug!
  45. # You can work around this by loading to 32x32b and follow by a convert_layout to this layout.
  46. _check(layout_obj.lane_bases[-1] == [0, N // 2],
  47. lambda: f"splitn with 1 register requires the last lane basis to be [0, N / 2]. Got {layout_obj}")
  48. layout_obj.reg_bases.append([0, N // 2])
  49. layout_obj.lane_bases[-1] = [0, 0]
  50. elif layout_obj.reg_bases[-1] != [0, N // 2]:
  51. bitwidth = element_ty.primitive_bitwidth
  52. num_reg = 2**len(layout_obj.reg_bases)
  53. _check(
  54. num_reg > 32 // bitwidth, lambda: "To be able to `tmem.load` into `tl.split` you need to have more "
  55. f"than {32 // bitwidth} {bitwidth}-bit registers, as you need to use "
  56. "the instruction 32x32b.x1 twice. You can always load into "
  57. "instr_variant=\"32x32b\" and then convert_layout to this layout otherwise.")
  58. reg_bases = layout_obj.reg_bases
  59. for bases_str in ("lane_bases", "warp_bases"):
  60. bases = getattr(layout_obj, bases_str)
  61. for i, basis in enumerate(bases):
  62. if basis == [0, N // 2]:
  63. reg_bases[-1], bases[i] = bases[i], reg_bases[-1]
  64. return layout_obj
  65. assert False, f"splitn requires at least one basis of [0, N / 2]. Got {layout}"
  66. return layout_obj
  67. _compute_tmem_reg_layout.__triton_builtin__ = True
  68. class GluonCallerContext:
  69. def __init__(self, num_warps: int):
  70. self.num_warps = num_warps
  71. def mangle(self):
  72. return f"_NW{self.num_warps}"
  73. def initialize_callee(self, fn, builder):
  74. fn.set_attr("ttg.num-warps", builder.get_int32_attr(self.num_warps))
  75. class GluonSemantic(TritonSemantic[TensorTy]):
  76. tensor = ttgl.tensor
  77. lang = ttgl
  78. builder: GluonOpBuilder
  79. def __init__(self, builder: GluonOpBuilder):
  80. self.builder = builder
  81. def _wrap_handle_infer_layout(self, handle, scalar_ty, shape):
  82. if shape == []:
  83. ty = scalar_ty
  84. else:
  85. ty = ttgl.distributed_type(scalar_ty, shape, self.builder.get_gluon_layout_from_tensor(handle))
  86. return self.tensor(handle, ty)
  87. def _wrap_tensor_infer_layout(self, tensor):
  88. return self._wrap_handle_infer_layout(tensor.handle, tensor.type.scalar, tensor.shape)
  89. def _broadcast_shapes(self, lhs_shape: List[int], rhs_shape: List[int]):
  90. if len(lhs_shape) != len(rhs_shape):
  91. raise ValueError(f"Cannot broadcast, rank mismatch: {lhs_shape}, {rhs_shape}")
  92. ret_shape = []
  93. for i, left in enumerate(lhs_shape):
  94. right = rhs_shape[i]
  95. if left == 1:
  96. ret_shape.append(right)
  97. elif (right == 1) or (right == left):
  98. ret_shape.append(left)
  99. else:
  100. raise ValueError("Cannot make_shape_compatible: incompatible dimensions "
  101. "at index " + str(i) + ": " + str(left) + " and " + str(right))
  102. return ret_shape
  103. def expand_dims(self, input: TensorTy, axis: int) -> TensorTy:
  104. dst_shape = [ttgl._unwrap_if_constexpr(x) for x in input.shape]
  105. dst_shape.insert(axis, 1)
  106. if axis < 0:
  107. axis += len(input.shape)
  108. _check(isinstance(input.type, ttgl.distributed_type),
  109. lambda: f"expected expand_dims input to be a distributed_type but got: {input.type!r}")
  110. layout = input.type.layout
  111. _check(isinstance(layout, (SliceLayout, AutoLayout, CoalescedLayout)),
  112. lambda: f"expected expand_dims input to have a SliceLayout, but got: {layout}")
  113. _check(
  114. isinstance(layout, (AutoLayout, CoalescedLayout)) or layout.dim == axis,
  115. lambda: f"expected expand_dims input layout to be sliced in axis {axis} but got {layout.dim}")
  116. handle = self.builder.create_expand_dims(input.handle, axis)
  117. return self._wrap_handle_infer_layout(handle, input.type.scalar, dst_shape)
  118. def join(self, a: TensorTy, b: TensorTy) -> TensorTy:
  119. a, b = self.broadcast_impl_value(a, b)
  120. _check(a.shape != [], lambda: "Cannot join scalars in gluon")
  121. value = super().join(a, b)
  122. return self._wrap_tensor_infer_layout(value)
  123. def split(self, a: TensorTy) -> Tuple[TensorTy, TensorTy]:
  124. lhs, rhs = super().split(a)
  125. return self._wrap_tensor_infer_layout(lhs), self._wrap_tensor_infer_layout(rhs)
  126. def permute(self, input: TensorTy, dims: Tuple[int]) -> TensorTy:
  127. value = super().permute(input, dims)
  128. return self._wrap_tensor_infer_layout(value)
  129. def broadcast_impl_shape(self, input: TensorTy, shape: Tuple[int]) -> TensorTy:
  130. _check(isinstance(input.type, ttgl.distributed_type),
  131. lambda: f"expected expand_dims input to be a distributed_type but got: {input.type!r}")
  132. src_shape = input.type.get_block_shapes()
  133. _check(len(src_shape) == len(shape), lambda: f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
  134. if shape == src_shape:
  135. return input
  136. for i, item in enumerate(src_shape):
  137. if shape[i] != item and item != 1:
  138. raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
  139. f" must match the existing size ({item}) at non-singleton dimension"
  140. f" {i}: {src_shape}, {shape}")
  141. ret_ty = ttgl.distributed_type(input.type.scalar, shape, input.type.layout)
  142. handle = self.builder.create_broadcast(input.handle, ret_ty.to_ir(self.builder))
  143. return self.tensor(handle, ret_ty)
  144. def broadcast_impl_value(self, lhs: TensorTy, rhs: TensorTy) -> TensorTy:
  145. lhs_ty = lhs.type
  146. rhs_ty = rhs.type
  147. if not lhs_ty.is_block() or not rhs_ty.is_block():
  148. return super().broadcast_impl_value(lhs, rhs)
  149. _check(isinstance(lhs_ty, ttgl.distributed_type),
  150. lambda: f"expected broadcast left input to be a distributed_type but got: {lhs_ty!r}")
  151. _check(isinstance(rhs_ty, ttgl.distributed_type),
  152. lambda: f"expected broadcast right input to be a distributed_type but got: {rhs_ty!r}")
  153. lhs_shape = lhs_ty.get_block_shapes()
  154. rhs_shape = rhs_ty.get_block_shapes()
  155. ret_shape = self._broadcast_shapes(lhs_shape, rhs_shape)
  156. is_lhs_auto = isinstance(lhs_ty.layout, AutoLayout)
  157. is_rhs_auto = isinstance(rhs_ty.layout, AutoLayout)
  158. if is_lhs_auto and not is_rhs_auto:
  159. lhs = self.set_auto_layout(lhs, rhs_ty.layout)
  160. elif is_rhs_auto and not is_lhs_auto:
  161. rhs = self.set_auto_layout(rhs, lhs_ty.layout)
  162. elif lhs_ty.layout != rhs_ty.layout:
  163. raise ValueError(f"Layout mismatch in broadcast: {lhs_ty.layout} vs {rhs_ty.layout}")
  164. lhs = self.broadcast_impl_shape(lhs, ret_shape)
  165. rhs = self.broadcast_impl_shape(rhs, ret_shape)
  166. return lhs, rhs
  167. def arange(self, start, end, layout):
  168. shape = [end - start]
  169. if layout is None:
  170. layout = AutoLayout()
  171. ret_ty = ttgl.distributed_type(ttgl.int32, shape, layout)
  172. return super().arange(start, end, ret_ty=ret_ty)
  173. def reshape(self, input: TensorTy, dst_shape: List[int], can_reorder: bool):
  174. _check(not can_reorder, lambda: "can_reorder is not supported in gluon")
  175. value = super().reshape(input, dst_shape, can_reorder)
  176. return self._wrap_tensor_infer_layout(value)
  177. def splat(self, value, shape, layout):
  178. if len(shape) == 0:
  179. return value
  180. ret_ty = ttgl.distributed_type(value.dtype, shape, layout)
  181. handle = self.builder.create_splat(ret_ty.to_ir(self.builder), value.handle)
  182. return ttgl.tensor(handle, ret_ty)
  183. def full(self, shape, value, dtype, layout):
  184. scalar = self.make_scalar(value, dtype)
  185. if layout is None:
  186. layout = AutoLayout()
  187. return self.splat(scalar, shape, layout)
  188. def convert_layout(self, value, layout, assert_trivial=False):
  189. ty = value.type
  190. _check(isinstance(ty, ttgl.distributed_type),
  191. lambda: f"expected convert_layout input to be a distributed_type but got: {ty!r}")
  192. _check(isinstance(layout, ttgl.DistributedLayout),
  193. lambda: f"expected 'layout' to be a DistributedLayout but got {layout}")
  194. ret_ty = ttgl.distributed_type(ty.element_ty, ty.shape, layout)
  195. ret_ty_ir = ret_ty.to_ir(self.builder)
  196. if assert_trivial and not self.builder.is_convert_layout_trivial(ret_ty_ir, value.handle):
  197. raise TypeError(f"layout conversion from {ty.layout} to {layout} is not trivial.\n"
  198. f"The linear layouts are:\n{self.to_linear_layout(ty.layout, ty.shape)}\n"
  199. f"{self.to_linear_layout(layout, ty.shape)}")
  200. handle = self.builder.create_convert_layout(ret_ty_ir, value.handle)
  201. return ttgl.tensor(handle, ret_ty)
  202. def allocate_shared(self, element_ty, shape, layout, value):
  203. _check(isinstance(element_ty, ttgl.dtype), lambda: f"expected 'element_ty' to be a dtype but got {element_ty}")
  204. _check(_is_int_list(shape), lambda: f"all elements of 'shape' must be integers but got {shape}")
  205. _check(isinstance(layout, ttgl.SharedLayout),
  206. lambda: f"expected 'layout' to be a SharedLayout but got {layout}")
  207. ty = ttgl.shared_memory_descriptor_type(element_ty, shape, layout, shape)
  208. if value is not None:
  209. handle = self.builder.create_local_alloc(ty.to_ir(self.builder), value.handle)
  210. else:
  211. handle = self.builder.create_local_alloc(ty.to_ir(self.builder))
  212. return ttgl.shared_memory_descriptor(handle, element_ty, shape, layout, shape)
  213. def shared_load(self, mem_desc, layout):
  214. _check(isinstance(layout, ttgl.DistributedLayout),
  215. lambda: f"expected 'layout' to be a DistributedLayout but got {layout}")
  216. ret_ty = ttgl.distributed_type(mem_desc.dtype, mem_desc.shape, layout)
  217. handle = self.builder.create_local_load(ret_ty.to_ir(self.builder), mem_desc.handle)
  218. return ttgl.tensor(handle, ret_ty)
  219. def shared_store(self, mem_desc, value):
  220. _check(isinstance(value, ttgl.tensor), lambda: f"expected 'value' to be a tensor, but got a {type(value)}")
  221. _check(value.shape == mem_desc.shape,
  222. lambda: f"source shape {value.shape} and destination shape {mem_desc.shape} must match")
  223. _check(value.dtype == mem_desc.dtype,
  224. lambda: f"source dtype {value.dtype} and destination dtype {mem_desc.dtype} must match")
  225. self.builder.create_local_store(mem_desc.handle, value.handle)
  226. def bank_conflicts(self, distr_ty, shared_ty):
  227. if not isinstance(distr_ty, ttgl.distributed_type):
  228. raise TypeError(
  229. f"bank_conflicts expects the register layout to be a distributed_type, got {type(distr_ty)}")
  230. if not isinstance(shared_ty, ttgl.shared_memory_descriptor_type):
  231. raise TypeError(
  232. f"bank_conflicts expects the shared layout to be a shared_memory_descriptor_type, got {type(shared_ty)}"
  233. )
  234. if distr_ty.shape != shared_ty.shape:
  235. raise ValueError(f"register shape {distr_ty.shape} and shared shape {shared_ty.shape} must match")
  236. if shared_ty.element_ty != distr_ty.element_ty:
  237. raise ValueError(
  238. f"mismatched dtypes between register ({distr_ty.element_ty}) and shared ({shared_ty.element_ty}) layouts"
  239. )
  240. if shared_ty.shape != shared_ty.alloc_shape[-len(shared_ty.shape):]:
  241. raise ValueError(
  242. f"bank_conflicts NYI for subslices. Got shape {shared_ty.shape} and alloc_shape {shared_ty.alloc_shape}"
  243. )
  244. reg_attr = distr_ty.layout._to_ir(self.builder)
  245. shared_attr = shared_ty.layout._to_ir(self.builder)
  246. return self.builder.get_shared_bank_conflicts(reg_attr, shared_attr, list(distr_ty.shape),
  247. distr_ty.element_ty.primitive_bitwidth)
  248. def to_linear_layout(self, layout, shape):
  249. _check(isinstance(layout, (DistributedLayout, SharedLayout)),
  250. lambda: f"Expected a DistributedLayout or SharedLayout, got {type(layout)}")
  251. if not isinstance(shape, list):
  252. shape = list(shape)
  253. layout = ttgl._unwrap_if_constexpr(layout)
  254. if isinstance(layout, (AutoLayout, DistributedLinearLayout)):
  255. return ttgl.constexpr(layout)
  256. return ttgl.constexpr(self.builder.to_linear_layout(layout._to_ir(self.builder), shape))
  257. def shared_dealloc(self, mem_desc):
  258. self.builder.create_local_dealloc(mem_desc.handle)
  259. def set_auto_layout(self, value, layout):
  260. src_ty = value.type
  261. _check(isinstance(layout, DistributedLayout),
  262. lambda: f"set_auto_layout must set to a distributed layout but got {layout}")
  263. _check(isinstance(src_ty.layout, AutoLayout),
  264. lambda: f"set_auto_layout input must have auto layout but got {value.type.layout}")
  265. handle = self.builder.create_set_auto_layout(layout._to_ir(self.builder), value.handle)
  266. res_ty = ttgl.distributed_type(src_ty.element_ty, src_ty.shape, layout)
  267. return self.tensor(handle, res_ty)
  268. def memdesc_slice(self, mem_desc, start, length, dim):
  269. _check(isinstance(start, int), lambda: f"expected 'start' to be an int but got {start}")
  270. _check(isinstance(length, int), lambda: f"expected 'length' to be an int but got {length}")
  271. _check(isinstance(dim, int), lambda: f"expected 'dim' to be an int but got {dim}")
  272. offsets = [0] * mem_desc.rank
  273. offsets[dim] = start
  274. shape = list(mem_desc.shape)
  275. shape[dim] = length
  276. layout = mem_desc.layout
  277. ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)
  278. builder = self.builder
  279. handle = builder.create_memdesc_subslice(ty.to_ir(builder), mem_desc.handle, offsets)
  280. return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
  281. def memdesc_index(self, mem_desc, index):
  282. index = self.to_tensor(index)
  283. _check(index.type == ttgl.int32, lambda: f"expected 'index' to be int32 but got {index.type}")
  284. shape = mem_desc.shape[1:]
  285. index = self.to_tensor(index).handle
  286. layout = mem_desc.layout
  287. ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, shape)
  288. builder = self.builder
  289. handle = builder.create_memdesc_index(ty.to_ir(builder), mem_desc.handle, index)
  290. return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
  291. def memdesc_trans(self, mem_desc, order):
  292. _check(_is_int_list(order), lambda: f"all elements of 'order' must be integers but got {order}")
  293. _check(
  294. len(order) == len(mem_desc.shape),
  295. lambda: f"source rank ({mem_desc.rank}) and order length ({len(order)}) must match")
  296. shape = [mem_desc.shape[i] for i in order]
  297. alloc_shape = mem_desc.type.alloc_shape
  298. new_alloc_shape = alloc_shape[:len(alloc_shape) - mem_desc.rank]
  299. new_alloc_shape += [alloc_shape[len(alloc_shape) - mem_desc.rank:][i] for i in order]
  300. handle = self.builder.create_memdesc_trans(mem_desc.handle, order)
  301. layout = self.builder.get_gluon_layout_from_memdesc(handle)
  302. return ttgl.shared_memory_descriptor(handle, element_ty=mem_desc.dtype, shape=shape,
  303. alloc_shape=new_alloc_shape, layout=layout)
  304. def memdesc_reshape(self, mem_desc, shape):
  305. _check(_is_int_list(shape), lambda: f"all elements of 'shape' must be integers but got {shape}")
  306. _check(
  307. math.prod(shape) == math.prod(mem_desc.shape),
  308. lambda: (f"memdesc_reshape total elements mismatch: "
  309. f"{mem_desc.shape} -> {shape}"),
  310. )
  311. handle = self.builder.create_memdesc_reshape(mem_desc.handle, shape)
  312. layout = self.builder.get_gluon_layout_from_memdesc(handle)
  313. alloc_shape = mem_desc.type.alloc_shape
  314. prefix_len = len(alloc_shape) - mem_desc.rank
  315. new_alloc_shape = alloc_shape[:prefix_len] + list(shape)
  316. return ttgl.shared_memory_descriptor(
  317. handle,
  318. element_ty=mem_desc.dtype,
  319. shape=shape,
  320. alloc_shape=new_alloc_shape,
  321. layout=layout,
  322. )
  323. def memdesc_reinterpret(self, mem_desc, dtype, shape, layout):
  324. _check(isinstance(dtype, ttgl.dtype), lambda: f"expected 'dtype' to be a dtype but got {dtype}")
  325. _check(_is_int_list(shape), lambda: f"all elements of 'shape' must be integers but got {shape}")
  326. _check(isinstance(layout, ttgl.SharedLayout),
  327. lambda: f"expected 'layout' to be a SharedLayout but got {layout}")
  328. ty = ttgl.shared_memory_descriptor_type(dtype, shape, layout, shape)
  329. handle = self.builder.create_memdesc_reinterpret(ty.to_ir(self.builder), mem_desc.handle)
  330. return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
  331. def wrap_tensor(self, x, scalar_ty, ret_shape, layout):
  332. if ret_shape:
  333. res_ty = ttgl.distributed_type(scalar_ty, ret_shape, layout)
  334. else:
  335. res_ty = scalar_ty
  336. return self.tensor(x, res_ty)
  337. @staticmethod
  338. def _check_same_layout(xs):
  339. for x in xs:
  340. _check(isinstance(x.type, ttgl.distributed_type), lambda: f"expected distributed_type but got: {x.type!r}")
  341. layouts = [x.type.layout for x in xs]
  342. l0 = layouts[0]
  343. _check(all(l == l0 for l in layouts[1:]),
  344. lambda: f"Expected inputs to have matching layouts, but got: {layouts}")
  345. def associative_scan(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn,
  346. reverse: bool) -> Tuple[TensorTy, ...]:
  347. shape = inputs[0].type.shape
  348. rank = len(shape)
  349. assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})"
  350. if axis < 0:
  351. axis += rank
  352. for t in inputs:
  353. assert t.type.shape == shape, "all scan inputs must have the same shape"
  354. scan_op = self.builder.create_scan([t.handle for t in inputs], axis, reverse)
  355. region_builder_fn(scan_op)
  356. assert scan_op.verify()
  357. return tuple(
  358. self._wrap_handle_infer_layout(scan_op.get_result(i), inputs[i].type.scalar, shape)
  359. for i in range(len(inputs)))
  360. def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) -> Tuple[TensorTy, ...]:
  361. if axis is None:
  362. inputs = tuple(self.reshape(t, [t.numel.value], can_reorder=False) for t in inputs)
  363. axis = 0
  364. # get result shape
  365. shape = inputs[0].type.shape
  366. rank = len(shape)
  367. _check(0 <= axis < rank, lambda: f"expected reduction axis to be in the range [0, {rank}) but got {axis}")
  368. self._check_same_layout(inputs)
  369. ret_shape = [s for i, s in enumerate(shape) if i != axis]
  370. assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape"
  371. reduce_op = self.builder.create_reduce([t.handle for t in inputs], axis)
  372. region_builder_fn(reduce_op)
  373. assert reduce_op.verify()
  374. return tuple(
  375. self._wrap_handle_infer_layout(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape)
  376. for i in range(len(inputs)))
  377. def histogram(self, input: TensorTy, num_bins: int, mask: TensorTy, layout) -> TensorTy:
  378. _check(len(input.shape) == 1, lambda: "histogram only supports 1D input")
  379. _check(input.dtype.is_int(), lambda: "histogram only supports integer input")
  380. _check(layout is not None, lambda: "histogram requires a destination layout")
  381. if mask is not None:
  382. mask, input = self.broadcast_impl_value(mask, input)
  383. _check(mask.type.scalar.is_bool(), lambda: "Mask must have boolean scalar type")
  384. mask = mask.handle
  385. layout_attr = layout._to_ir(self.builder)
  386. handle = self.builder.create_histogram(input.handle, num_bins, mask, layout_attr)
  387. return self.wrap_tensor(handle, ttgl.int32, [num_bins], layout)
  388. def cat(self, lhs: TensorTy, rhs: TensorTy, can_reorder: bool, layout) -> TensorTy:
  389. _check(layout is not None, lambda: "cat requires a destination layout")
  390. _check(can_reorder, lambda: "current implementation of `cat` always may reorder elements")
  391. _check(len(lhs.shape) == 1, lambda: "cat requires a rank-1 input")
  392. ret_type = ttgl.distributed_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]], layout)
  393. return self.tensor(self.builder.create_cat(lhs.handle, rhs.handle, ret_type.to_ir(self.builder)), ret_type)
  394. def gather(self, src: TensorTy, index: TensorTy, axis: int) -> TensorTy:
  395. _check(isinstance(src.type, ttgl.distributed_type), lambda: f"expected distributed_type but got: {src.type!r}")
  396. _check(isinstance(index.type, ttgl.distributed_type),
  397. lambda: f"expected distributed_type but got: {index.type!r}")
  398. _check(index.type.scalar.is_int(), lambda: f"expected integer scalar type but got: {index.type.scalar!r}")
  399. rank = len(src.type.shape)
  400. _check(len(index.type.shape) == rank, lambda: "source and index tensors must have the same rank")
  401. _check(-rank <= axis < rank, lambda: f"gather axis {axis} must be < source rank ({rank})")
  402. if axis < 0:
  403. axis += rank
  404. for d in range(rank):
  405. if d == axis:
  406. continue
  407. _check(
  408. index.type.shape[d] == src.type.shape[d],
  409. lambda: f"index dim {axis} must match the corresponding source dim",
  410. )
  411. gather = self.builder.create_gather(src.handle, index.handle, axis)
  412. return self.wrap_tensor(gather, src.type.scalar, index.type.shape, index.type.layout)
  413. def fp4_to_fp(self, src: TensorTy, elem_type, axis) -> TensorTy:
  414. result = self.builder.create_fp4_to_fp(src.handle, elem_type.to_ir(self.builder), axis)
  415. shape = list(src.type.shape)
  416. shape[axis] *= 2
  417. return self._wrap_handle_infer_layout(result, elem_type, shape)
  418. def warp_specialize(self, functions_and_args, worker_num_warps: Sequence[int], worker_num_regs: Sequence[int],
  419. generator):
  420. for _, args in functions_and_args:
  421. _check(isinstance(args, (tuple, ttgl.tuple)),
  422. lambda: f"function arguments must be a tuple of arguments, but got {type(args)}")
  423. assert len(functions_and_args) >= 1, "expected at least one function for the default partition"
  424. default_partition, default_args = functions_and_args[0]
  425. num_partitions = len(functions_and_args) - 1
  426. workers = functions_and_args[1:]
  427. assert num_partitions == len(
  428. worker_num_warps
  429. ), f"warp specialize got {num_partitions} partitions but {len(worker_num_warps)} warp counts"
  430. assert num_partitions == len(
  431. worker_num_regs
  432. ), f"warp specialize got {num_partitions} partitions but {len(worker_num_regs)} register counts"
  433. builder = self.builder
  434. insert_pt = builder.get_insertion_point()
  435. # Emit the default partition to get the result types.
  436. default_block = builder.new_block()
  437. builder.set_insertion_point_to_start(default_block)
  438. default_results = generator.call_JitFunction(default_partition, default_args, kwargs={})
  439. mlir_results = []
  440. if default_results is not None:
  441. mlir_results = flatten_values_to_ir(default_results)
  442. builder.create_warp_yield(mlir_results)
  443. result_types = [r.get_type() for r in mlir_results]
  444. # Create the warp specialize op.
  445. worker_args = [flatten_values_to_ir(args) for _, args in workers]
  446. mlir_args = sum(worker_args, [])
  447. builder.restore_insertion_point(insert_pt)
  448. ws_op = builder.create_warp_specialize(result_types, mlir_args, worker_num_warps)
  449. ws_op.get_default_region().push_back(default_block)
  450. ws_op.set_requested_registers(worker_num_regs)
  451. # Emit the partition regions.
  452. builder.create_block_with_parent(ws_op.get_partition_op_holder(), [])
  453. partitions_op = builder.create_warp_specialize_partitions(num_partitions)
  454. arg_types = [arg.get_type() for arg in mlir_args]
  455. arg_it = 0
  456. for i, (func, args) in enumerate(workers):
  457. caller_context = GluonCallerContext(num_warps=worker_num_warps[i])
  458. block = builder.create_block_with_parent(partitions_op.get_region(i), arg_types)
  459. mlir_args = worker_args[i]
  460. block_args = [block.get_argument(arg_it + j) for j in range(len(mlir_args))]
  461. block_args = unflatten_ir_values(block_args, [arg.type for arg in args])
  462. generator.call_JitFunction(func, block_args, kwargs={}, caller_context=caller_context)
  463. builder.create_warp_return()
  464. arg_it += len(mlir_args)
  465. builder.set_insertion_point_after(ws_op.get_operation())
  466. mlir_results = [ws_op.get_result(i) for i in range(len(result_types))]
  467. if default_results is None:
  468. return
  469. return tuple(unflatten_ir_values(mlir_results, [r.type for r in default_results]))
  470. def num_ctas(self):
  471. return ttgl.constexpr(self.builder.options.num_ctas)
  472. def num_warps(self, generator):
  473. if generator.caller_context is not None:
  474. assert isinstance(generator.caller_context, GluonCallerContext)
  475. return ttgl.constexpr(generator.caller_context.num_warps)
  476. return ttgl.constexpr(self.builder.options.num_warps)