_layouts.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676
  1. from dataclasses import dataclass, field
  2. from typing import List
  3. from triton.language.core import _unwrap_if_constexpr, _unwrap_shape, constexpr_type
  4. from triton.runtime.jit import constexpr_function
  5. import math
  6. class DistributedLayout:
  7. """
  8. Base class for distributed memory layouts in Gluon IR.
  9. """
  10. @property
  11. def type(self):
  12. return constexpr_type(self)
  13. @property
  14. def rank(self):
  15. raise NotImplementedError("DistributedLayout subclasses must define rank")
  16. @dataclass(frozen=True)
  17. class AutoLayout(DistributedLayout):
  18. def _to_ir(self, builder):
  19. return builder.get_auto_layout()
  20. def mangle(self):
  21. return "AL"
  22. @property
  23. def rank(self):
  24. raise ValueError("AutoLayout has no rank")
  25. @dataclass(frozen=True)
  26. class CoalescedLayout(DistributedLayout):
  27. def _to_ir(self, builder):
  28. return builder.get_coalesced_layout()
  29. def mangle(self):
  30. return "CL"
  31. @property
  32. def rank(self):
  33. raise ValueError("CoalescedLayout has no rank")
  34. @dataclass(frozen=True)
  35. class BlockedLayout(DistributedLayout):
  36. """
  37. Represents a blocked layout, partitioning a tensor across threads, warps, and CTAs.
  38. Args:
  39. size_per_thread (List[int]): Number of elements per thread per dimension.
  40. threads_per_warp (List[int]): Number of threads per warp per dimension.
  41. warps_per_cta (List[int]): Number of warps per CTA per dimension.
  42. order (List[int]): The ordering of dimensions for partitioning.
  43. cga_layout (Optional[List[List[int]]]): Bases describing how CTAs tile each dimension.
  44. """
  45. size_per_thread: List[int]
  46. threads_per_warp: List[int]
  47. warps_per_cta: List[int]
  48. order: List[int]
  49. cga_layout: List[List[int]] = field(default_factory=list)
  50. def __post_init__(self):
  51. super().__setattr__("size_per_thread", _unwrap_if_constexpr(self.size_per_thread))
  52. super().__setattr__("threads_per_warp", _unwrap_if_constexpr(self.threads_per_warp))
  53. super().__setattr__("warps_per_cta", _unwrap_if_constexpr(self.warps_per_cta))
  54. super().__setattr__("order", _unwrap_if_constexpr(self.order))
  55. rank = len(self.size_per_thread)
  56. object.__setattr__(self, "cga_layout", self.cga_layout)
  57. assert len(self.threads_per_warp) == rank
  58. assert len(self.warps_per_cta) == rank
  59. assert len(self.order) == rank
  60. def _to_ir(self, builder):
  61. return builder.get_blocked_layout(
  62. self.size_per_thread,
  63. self.threads_per_warp,
  64. self.warps_per_cta,
  65. self.order,
  66. self.cga_layout,
  67. )
  68. def mangle(self) -> str:
  69. def stringify(x):
  70. if x is None:
  71. return ""
  72. return "_".join(map(str, x))
  73. size_per_thread = stringify(self.size_per_thread)
  74. threads_per_warp = stringify(self.threads_per_warp)
  75. warps_per_cta = stringify(self.warps_per_cta)
  76. order = stringify(self.order)
  77. cga_layout = "_".join("~".join(map(str, vec)) for vec in self.cga_layout) if self.cga_layout else ""
  78. return f"B{size_per_thread}_{threads_per_warp}_{warps_per_cta}_{order}_{cga_layout}B"
  79. def __hash__(self):
  80. return hash((tuple(self.size_per_thread), tuple(self.threads_per_warp), tuple(self.warps_per_cta),
  81. tuple(self.order), tuple(tuple(vec) for vec in self.cga_layout)))
  82. @property
  83. def rank(self):
  84. return len(self.order)
  85. @dataclass(frozen=True)
  86. class SliceLayout(DistributedLayout):
  87. """
  88. Represents a layout corresponding to slicing a distributed tensor along one dimension.
  89. Args:
  90. dim (int): The dimension index to slice.
  91. parent (DistributedLayout): The parent layout before slicing.
  92. """
  93. dim: int
  94. parent: DistributedLayout
  95. def __post_init__(self):
  96. super().__setattr__("dim", _unwrap_if_constexpr(self.dim))
  97. super().__setattr__("parent", _unwrap_if_constexpr(self.parent))
  98. def _to_ir(self, builder):
  99. return builder.get_slice_layout(
  100. self.dim,
  101. self.parent._to_ir(builder),
  102. )
  103. def mangle(self) -> str:
  104. return f"SL{self.dim}_{self.parent.mangle()}SL"
  105. def __hash__(self):
  106. return hash((self.dim, self.parent))
  107. @property
  108. def rank(self):
  109. return self.parent.rank - 1
  110. @property
  111. def cga_layout(self):
  112. parent_cga_layout = self.parent.cga_layout
  113. if not parent_cga_layout:
  114. return []
  115. rank = self.parent.rank
  116. assert 0 <= self.dim < rank
  117. return [basis[:self.dim] + basis[self.dim + 1:] for basis in parent_cga_layout]
  118. @dataclass(frozen=True)
  119. class DistributedLinearLayout(DistributedLayout):
  120. """
  121. Represents a linear distributed layout with explicit bases at register, lane, warp, and block levels.
  122. See: https://arxiv.org/abs/2505.23819 for reference.
  123. Args:
  124. reg_bases (List[List[int]]): Bases for register-level distribution.
  125. lane_bases (List[List[int]]): Bases for lane-level distribution.
  126. warp_bases (List[List[int]]): Bases for warp-level distribution.
  127. block_bases (List[List[int]]): Bases for block-level distribution.
  128. shape (List[int]): The tensor global shape.
  129. """
  130. reg_bases: List[List[int]]
  131. lane_bases: List[List[int]]
  132. warp_bases: List[List[int]]
  133. block_bases: List[List[int]]
  134. shape: List[int]
  135. def __post_init__(self):
  136. super().__setattr__("reg_bases", _unwrap_shape(self.reg_bases))
  137. super().__setattr__("lane_bases", _unwrap_shape(self.lane_bases))
  138. super().__setattr__("warp_bases", _unwrap_shape(self.warp_bases))
  139. super().__setattr__("block_bases", _unwrap_shape(self.block_bases))
  140. super().__setattr__("shape", _unwrap_shape(self.shape))
  141. rank = len(self.shape)
  142. for basis in self.reg_bases:
  143. assert len(basis) == rank
  144. for basis in self.lane_bases:
  145. assert len(basis) == rank
  146. for basis in self.warp_bases:
  147. assert len(basis) == rank
  148. for basis in self.block_bases:
  149. assert len(basis) == rank
  150. def _to_ir(self, builder):
  151. return builder.get_distributed_linear_layout(self.reg_bases, self.lane_bases, self.warp_bases, self.block_bases,
  152. self.shape)
  153. def mangle(self):
  154. return f"DLL{self.reg_bases}_{self.lane_bases}_{self.warp_bases}_{self.block_bases}_{self.shape}DLL"
  155. def __hash__(self):
  156. return hash((
  157. tuple(map(tuple, self.reg_bases)),
  158. tuple(map(tuple, self.lane_bases)),
  159. tuple(map(tuple, self.warp_bases)),
  160. tuple(map(tuple, self.block_bases)),
  161. tuple(self.shape),
  162. ))
  163. @property
  164. def rank(self):
  165. return len(self.shape)
  166. @dataclass(frozen=True)
  167. class DotOperandLayout(DistributedLayout):
  168. """
  169. Represents a layout for a dot operand.
  170. Args:
  171. operand_index (int): 0 for LHS and 1 for RHS of the dot operation.
  172. parent (DistributedLayout): The parent layout, representing the MMA.
  173. k_width (int): Number of elements per 32-bits.
  174. """
  175. operand_index: int
  176. parent: DistributedLayout
  177. k_width: int
  178. def __post_init__(self):
  179. super().__setattr__("operand_index", _unwrap_if_constexpr(self.operand_index))
  180. super().__setattr__("parent", _unwrap_if_constexpr(self.parent))
  181. super().__setattr__("k_width", _unwrap_if_constexpr(self.k_width))
  182. def _to_ir(self, builder):
  183. return builder.get_dot_operand_layout(self.operand_index, self.parent._to_ir(builder), self.k_width)
  184. def mangle(self) -> str:
  185. return f"DO{self.operand_index}_{self.parent.mangle()}_{self.k_width}DO"
  186. def __hash__(self):
  187. return hash((self.operand_index, self.parent, self.k_width))
  188. @property
  189. def rank(self):
  190. return self.parent.rank
  191. @property
  192. def cga_layout(self):
  193. parent_cga_layout = _unwrap_if_constexpr(getattr(self.parent, "cga_layout", [])) or []
  194. if not parent_cga_layout:
  195. return []
  196. rank = self.parent.rank
  197. assert all(len(basis) == rank for basis in parent_cga_layout)
  198. k_dim = rank - 1 if self.operand_index == 0 else rank - 2
  199. assert 0 <= k_dim < rank
  200. derived = []
  201. for basis in parent_cga_layout:
  202. new_basis = list(basis)
  203. new_basis[k_dim] = 0
  204. derived.append(new_basis)
  205. return derived
  206. @dataclass(frozen=True, eq=True)
  207. class NVMMADistributedLayout(DistributedLayout):
  208. """
  209. Represents a layout for NVIDIA MMA (tensor core) operations.
  210. Args:
  211. version (List[int]): Version identifier for the MMA instruction.
  212. warps_per_cta (List[int]): Number of warps per CTA.
  213. instr_shape (List[int]): Instruction shape for MMA.
  214. cga_layout (Optional[List[List[int]]]): Bases describing CTA tiling.
  215. """
  216. version: List[int]
  217. warps_per_cta: List[int]
  218. instr_shape: List[int]
  219. cga_layout: List[List[int]] = field(default_factory=list)
  220. def __post_init__(self):
  221. super().__setattr__("version", _unwrap_if_constexpr(self.version))
  222. super().__setattr__("warps_per_cta", _unwrap_if_constexpr(self.warps_per_cta))
  223. super().__setattr__("instr_shape", _unwrap_if_constexpr(self.instr_shape))
  224. object.__setattr__(self, "cga_layout", self.cga_layout)
  225. def _to_ir(self, builder):
  226. return builder.get_mma_layout(
  227. self.version,
  228. self.warps_per_cta,
  229. self.cga_layout,
  230. self.instr_shape,
  231. )
  232. def mangle(self) -> str:
  233. cga_layout = "_".join("~".join(map(str, vec)) for vec in self.cga_layout) if self.cga_layout else ""
  234. return f"MMA_{self.version}_{self.warps_per_cta}_{self.instr_shape}_{cga_layout}_MMA"
  235. def __hash__(self):
  236. return hash((tuple(self.version), tuple(self.warps_per_cta), tuple(self.instr_shape),
  237. tuple(tuple(vec) for vec in self.cga_layout)))
  238. @property
  239. def rank(self):
  240. return len(self.warps_per_cta)
  241. class SharedLayout:
  242. """
  243. Base class for shared memory layouts in Gluon IR.
  244. """
  245. @property
  246. def type(self):
  247. return constexpr_type(self)
  248. @constexpr_function
  249. def _get_shape_per_cta(shape, cga_layout):
  250. if not cga_layout:
  251. return shape
  252. shape_per_cta = list(shape)
  253. rank = len(cga_layout[0])
  254. cga_shape = [1] * rank
  255. for basis in cga_layout:
  256. assert len(basis) == rank
  257. for i in range(rank):
  258. cga_shape[i] = max(cga_shape[i], basis[i])
  259. # The shape is the largest stride * 2
  260. for i in range(rank):
  261. cga_shape[i] *= 2
  262. for dim in range(rank):
  263. assert shape_per_cta[dim] % cga_shape[dim] == 0, f"Shape {shape} is not divisible by CGA layout {cga_layout}"
  264. shape_per_cta[dim] //= cga_shape[dim]
  265. return shape_per_cta
  266. @dataclass(frozen=True)
  267. class NVMMASharedLayout(SharedLayout):
  268. """
  269. Represents a layout for shared memory suitable for NVIDIA MMA operations.
  270. Args:
  271. swizzle_byte_width (int): Width in bytes for swizzling.
  272. element_bitwidth (int): Bitwidth of element type.
  273. rank (int): Rank of the tensor.
  274. transposed (bool): Whether the layout is transposed.
  275. fp4_padded (bool): Whether FP4 padding is used.
  276. cga_layout (Optional[List[List[int]]]): Bases describing CTA tiling.
  277. """
  278. swizzle_byte_width: int
  279. element_bitwidth: int
  280. rank: int = 2
  281. transposed: bool = False
  282. fp4_padded: bool = False
  283. cga_layout: List[List[int]] = field(default_factory=list)
  284. def __post_init__(self):
  285. super().__setattr__("swizzle_byte_width", _unwrap_if_constexpr(self.swizzle_byte_width))
  286. super().__setattr__("element_bitwidth", _unwrap_if_constexpr(self.element_bitwidth))
  287. super().__setattr__("transposed", _unwrap_if_constexpr(self.transposed))
  288. super().__setattr__("fp4_padded", _unwrap_if_constexpr(self.fp4_padded))
  289. # TODO: Make rank optional and check that (rank or cga_layout)
  290. cga_layout = self.cga_layout or []
  291. if cga_layout:
  292. assert len(cga_layout[0]) == self.rank
  293. super().__setattr__("rank", _unwrap_if_constexpr(self.rank))
  294. super().__setattr__("cga_layout", _unwrap_if_constexpr(cga_layout))
  295. assert self.element_bitwidth in [8, 16, 32, 64]
  296. assert self.swizzle_byte_width in [0, 32, 64, 128]
  297. def _to_ir(self, builder):
  298. return builder.get_nvmma_shared_layout(
  299. self.swizzle_byte_width,
  300. self.element_bitwidth,
  301. self.transposed,
  302. self.fp4_padded,
  303. self.cga_layout,
  304. self.rank,
  305. )
  306. @staticmethod
  307. @constexpr_function
  308. def get_default_for(block_shape, dtype, transposed=False, fp4_padded=False, cga_layout=None):
  309. """Returns an NVMMASharedLayout with default swizzling for a given shape.
  310. This picks the largest swizzle pattern compatible with the shape, which
  311. allows emitting the fewest TMA or MMA messages.
  312. """
  313. packing_factor = 2 if fp4_padded else 1
  314. shape_per_cta = block_shape if cga_layout is None else _get_shape_per_cta(block_shape, cga_layout)
  315. rank = len(block_shape)
  316. if transposed:
  317. shape_per_cta = shape_per_cta[1:] + shape_per_cta[:1]
  318. contig_dim_size = shape_per_cta[-1] * packing_factor
  319. contig_dim_bytes = contig_dim_size * dtype.primitive_bitwidth // 8
  320. if contig_dim_bytes >= 128 and contig_dim_bytes % 128 == 0:
  321. swizzle_byte_width = 128
  322. elif contig_dim_bytes >= 64 and contig_dim_bytes % 64 == 0:
  323. swizzle_byte_width = 64
  324. elif contig_dim_bytes >= 32 and contig_dim_bytes % 32 == 0:
  325. swizzle_byte_width = 32
  326. else:
  327. swizzle_byte_width = 0
  328. flatten_outer_dim = 1
  329. for size in shape_per_cta[:-1]:
  330. flatten_outer_dim *= size
  331. if len(block_shape) < 2 or flatten_outer_dim < 8:
  332. swizzle_byte_width = 0
  333. return NVMMASharedLayout(
  334. swizzle_byte_width=swizzle_byte_width,
  335. element_bitwidth=dtype.primitive_bitwidth,
  336. rank=rank,
  337. transposed=transposed,
  338. fp4_padded=fp4_padded,
  339. cga_layout=cga_layout,
  340. )
  341. def mangle(self) -> str:
  342. cga_layout = "_".join("~".join(map(str, vec)) for vec in self.cga_layout) if self.cga_layout else ""
  343. return f"NVMMA_{self.swizzle_byte_width}_{self.element_bitwidth}_{self.transposed}_{self.fp4_padded}_{cga_layout}_NVMMA"
  344. def __hash__(self):
  345. return hash((self.swizzle_byte_width, self.element_bitwidth, self.rank, self.transposed, self.fp4_padded,
  346. tuple(tuple(vec) for vec in self.cga_layout) if self.cga_layout else None))
  347. @dataclass(frozen=True, eq=True)
  348. class SwizzledSharedLayout(SharedLayout):
  349. """
  350. Represents a generic swizzled shared memory layout.
  351. Args:
  352. vec (int): Vector width for swizzling.
  353. per_phase (int): Elements per swizzle phase.
  354. max_phase (int): Maximum number of swizzle phases.
  355. order (List[int]): Dimension ordering for swizzling.
  356. cga_layout (Optional[List[List[int]]]): Bases describing CTA tiling.
  357. """
  358. vec: int
  359. per_phase: int
  360. max_phase: int
  361. order: List[int]
  362. cga_layout: List[List[int]] = field(default_factory=list)
  363. def __post_init__(self):
  364. super().__setattr__("vec", _unwrap_if_constexpr(self.vec))
  365. super().__setattr__("per_phase", _unwrap_if_constexpr(self.per_phase))
  366. super().__setattr__("max_phase", _unwrap_if_constexpr(self.max_phase))
  367. super().__setattr__("order", _unwrap_if_constexpr(self.order))
  368. object.__setattr__(self, "cga_layout", self.cga_layout)
  369. def _to_ir(self, builder):
  370. return builder.get_swizzled_shared_layout(
  371. self.vec,
  372. self.per_phase,
  373. self.max_phase,
  374. self.order,
  375. self.cga_layout,
  376. )
  377. def mangle(self) -> str:
  378. def stringify(x):
  379. if x is None:
  380. return ""
  381. return "_".join(map(str, x))
  382. cga_layout = "_".join("~".join(map(str, vec)) for vec in self.cga_layout) if self.cga_layout else ""
  383. return f"SSS_{self.vec}_{self.per_phase}_{self.max_phase}_{stringify(self.order)}_{cga_layout}_SSS"
  384. def __hash__(self):
  385. return hash(
  386. (self.vec, self.per_phase, self.max_phase, tuple(self.order), tuple(tuple(vec) for vec in self.cga_layout)))
  387. @dataclass(frozen=True, eq=True)
  388. class PaddedSharedLayout(SharedLayout):
  389. """
  390. Represents a layout for the access to shared memory. Compared to SwizzledSharedLayout,
  391. it combined padding and element reordering via linear transformation (e.g. row permutation)
  392. to avoid shared memory bank conflicts. After every interval tensor elements, the
  393. corresponding number of padding elements are inserted. If a position corresponds to
  394. multiple intervals, the padding amounts are summed.
  395. In the following example of a tensor,
  396. `eM` represents original elements in the and `pN` represents padded element.
  397. Before padding, the shared memory looks like:
  398. [e0, e1,
  399. e2, e3,
  400. e4, e5,
  401. e6, e7,
  402. ...]
  403. After padding with interval-padding list [[2, 1], [4, 2]] with an identity remapping,
  404. the shared memory will be
  405. [e0, e1, p0,
  406. e2, e3, p1, p2, p3,
  407. e4, e5, p4,
  408. e6, e7, p5, p6, p7,
  409. ...]
  410. Furthermore this encoding allows for a linear remapping from the 1-D shared
  411. memory offset to logical n-D tensor elements. The remapping is given in the form
  412. of linear bases mapping from offset to [dim0, dim1...dimN-1].
  413. See LinearLayout.h for more details how linear layouts are applied to remap
  414. elements.
  415. Some concrete examples using `xN` and `yN` to mean the logical n-D tensor elements
  416. and `pN` to mean padding:
  417. After padding for shape = [8] with interval-padding list [[2, 2]], offset_bases = [[2], [1]] and block_bases = []:
  418. [x0, x2, p0 p1, x1, x3]
  419. After padding for shape = [8, 4] with interval_padding_pairs = [[8, 1]], offset_bases = [[0, 1], [0, 2], /*gap, stride by 2 rows*/[2, 0], [4, 0], [1, 0]]] and block_bases = []:
  420. [
  421. x0y0, x0y1, x0y2, x0y3,
  422. x2y0, x2y1, x2y2, x2y3,
  423. p0,
  424. x4y0, x4y1, x4y2, x4y3,
  425. x6y0, x6y1, x6y2, x6y3,
  426. p1,
  427. x1y0, x1y1, x1y2, x1y3,
  428. x3y0, x3y1, x3y2, x3y3,
  429. p2,
  430. x5y0, x5y1, x5y2, x5y3,
  431. x7y0, x7y1, x7y2, x7y3,
  432. ]
  433. Args:
  434. interval_padding_pairs (List[int]): List of [interval, padding] pair and both interval and padding must be powers of 2.
  435. offset_bases (List[int]): Bases for shared memory offsets
  436. block_bases (List[List[int]]): Bases for block-level shared memory offsets.
  437. shape (List[int]): n-D logical shared memory shape
  438. """
  439. interval_padding_pairs: List[List[int]]
  440. offset_bases: List[List[int]]
  441. block_bases: List[List[int]]
  442. shape: List[int]
  443. def __post_init__(self):
  444. super().__setattr__("interval_padding_pairs", _unwrap_shape(self.interval_padding_pairs))
  445. super().__setattr__("offset_bases", _unwrap_shape(self.offset_bases))
  446. super().__setattr__("block_bases", _unwrap_shape(self.block_bases))
  447. super().__setattr__("shape", _unwrap_shape(self.shape))
  448. rank = len(self.shape)
  449. for basis in self.offset_bases:
  450. assert len(basis) == rank
  451. for basis in self.block_bases:
  452. assert len(basis) == rank
  453. self.verify()
  454. def _to_ir(self, builder):
  455. intervals, paddings = zip(*self.interval_padding_pairs)
  456. return builder.get_padded_shared_layout(intervals, paddings, self.offset_bases, self.block_bases, self.shape)
  457. def mangle(self) -> str:
  458. return f"PaddedShared_{self.interval_padding_pairs}_{self.offset_bases}_{self.block_bases}_{self.shape}_PaddedShared"
  459. def verify(self):
  460. pairs = self.interval_padding_pairs
  461. assert len(pairs) > 0, "PaddedSharedLayout interval_padding_pairs must have at least one interval-padding pair"
  462. assert all(len(pair) == 2 for pair in pairs)
  463. intervals, paddings = zip(*pairs)
  464. unique_intervals = list(set(intervals))
  465. assert len(unique_intervals) == len(intervals)
  466. is_power_of_2 = lambda n: n > 0 and n & (n - 1) == 0
  467. assert all(is_power_of_2(n) for n in intervals), "PaddedSharedLayout interval values must all be power of two"
  468. assert all(is_power_of_2(n) for n in paddings), "PaddedSharedLayout padding values must all be power of two"
  469. rank = len(self.shape)
  470. assert rank > 0, "PaddedSharedLayout order must not be empty"
  471. @staticmethod
  472. @constexpr_function
  473. def with_identity_for(interval_padding_pairs, shape, order):
  474. """Returns a PaddedSharedLayout with the given interval and padding pairs and an identity mapping as the linear component for the given shape and order.
  475. """
  476. assert len(shape) == len(order)
  477. is_power_of_2 = lambda n: n > 0 and n & (n - 1) == 0
  478. assert all(is_power_of_2(n) for n in shape)
  479. rank = len(shape)
  480. # Create a idendity mapping based on shape + order
  481. offset_bases = []
  482. for dim in order:
  483. for basis in range(int(math.log2(shape[dim]))):
  484. offset_bases.append([1 << basis if i == dim else 0 for i in range(rank)])
  485. return PaddedSharedLayout(interval_padding_pairs, offset_bases, [], shape)
  486. def __hash__(self):
  487. return hash((tuple(map(tuple, self.interval_padding_pairs)), tuple(map(tuple, self.offset_bases)),
  488. tuple(map(tuple, self.block_bases)), tuple(self.shape)))
  489. @dataclass(frozen=True)
  490. class SharedLinearLayout(SharedLayout):
  491. """Represents a shared memory layout defined via an explicit LinearLayout."""
  492. offset_bases: List[List[int]]
  493. block_bases: List[List[int]] = field(default_factory=list)
  494. alignment: int = 16
  495. def __post_init__(self):
  496. super().__setattr__("offset_bases", _unwrap_shape(self.offset_bases))
  497. super().__setattr__("block_bases", _unwrap_shape(self.block_bases))
  498. super().__setattr__("alignment", _unwrap_if_constexpr(self.alignment))
  499. assert len(self.offset_bases) != 0, "SharedLinearLayout offset_bases must not be empty"
  500. rank = len(self.offset_bases[0])
  501. assert rank > 0, "SharedLinearLayout offset_bases must not be empty"
  502. for basis in self.offset_bases:
  503. assert len(basis) == rank
  504. for basis in self.block_bases:
  505. assert len(basis) == rank
  506. assert self.alignment > 0 and (self.alignment & (self.alignment - 1)) == 0, \
  507. "SharedLinearLayout alignment must be a positive power of two"
  508. def _to_ir(self, builder):
  509. return builder.get_shared_linear_layout(self.offset_bases, self.block_bases, self.alignment)
  510. def mangle(self) -> str:
  511. return f"SharedLinear_{self.offset_bases}_{self.block_bases}_{self.alignment}_SharedLinear"
  512. def __hash__(self):
  513. return hash((
  514. tuple(map(tuple, self.offset_bases)),
  515. tuple(map(tuple, self.block_bases)),
  516. self.alignment,
  517. ))
  518. # Python impl of LinearEncodingAttr::basesPerDim
  519. def bases_per_dim(bases, rank, skip_broadcast=True):
  520. result = [1] * rank
  521. if not bases:
  522. return result
  523. non_zero_idx = None
  524. for basis in bases:
  525. # Find the first non-zero index in the current basis
  526. idx = next((i for i, v in enumerate(basis) if v != 0), None)
  527. if idx is not None:
  528. non_zero_idx = idx
  529. result[idx] *= 2
  530. elif not skip_broadcast:
  531. # If no non-zero found and we're not skipping broadcasts, use the last found non-zero index
  532. assert non_zero_idx is not None
  533. result[non_zero_idx] *= 2
  534. return result
  535. def warps_per_cta(layout, shape):
  536. if isinstance(layout, DistributedLinearLayout):
  537. return bases_per_dim(layout.warp_bases, len(shape))
  538. elif isinstance(layout, (SliceLayout, DotOperandLayout)):
  539. return warps_per_cta(layout.parent, shape)
  540. else:
  541. return layout.warps_per_cta