layout.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475
  1. #################################################################################################
  2. #
  3. # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  4. # SPDX-License-Identifier: BSD-3-Clause
  5. #
  6. # Redistribution and use in source and binary forms, with or without
  7. # modification, are permitted provided that the following conditions are met:
  8. #
  9. # 1. Redistributions of source code must retain the above copyright notice, this
  10. # list of conditions and the following disclaimer.
  11. #
  12. # 2. Redistributions in binary form must reproduce the above copyright notice,
  13. # this list of conditions and the following disclaimer in the documentation
  14. # and/or other materials provided with the distribution.
  15. #
  16. # 3. Neither the name of the copyright holder nor the names of its
  17. # contributors may be used to endorse or promote products derived from
  18. # this software without specific prior written permission.
  19. #
  20. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  21. # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  22. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  23. # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
  24. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  25. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  26. # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  27. # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  28. # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  29. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  30. #
  31. #################################################################################################
  32. """
  33. Definition of CuTe Layouts and the functions that manipulate them, using lexicographic
  34. ordering instead of the co-lexicographic ordering implemented in the original layout.py.
  35. """
  36. from itertools import chain
  37. from typing import TypeAlias
  38. from typing_extensions import Self, TypeIs
  39. from .int_tuple import (
  40. crd2idx,
  41. flatten,
  42. has_none,
  43. IntTuple,
  44. is_int,
  45. is_tuple,
  46. product,
  47. slice_,
  48. suffix_product,
  49. )
  50. # Type aliases
  51. CoordinateType: TypeAlias = (
  52. int | IntTuple | tuple[object, ...] | None
  53. ) # Input for slice_ and crd2idx functions
  54. class LayoutBase:
  55. pass
  56. def is_layout(x: object) -> TypeIs["Layout"]:
  57. return isinstance(x, LayoutBase)
  58. class Layout(LayoutBase):
  59. def __init__(self, _shape: IntTuple, _stride: IntTuple | None = None) -> None:
  60. self.shape = _shape
  61. if _stride is None:
  62. self.stride = suffix_product(self.shape)
  63. else:
  64. self.stride = _stride
  65. # operator ==
  66. def __eq__(self, other: object) -> bool:
  67. if not isinstance(other, Layout):
  68. return False
  69. return self.shape == other.shape and self.stride == other.stride
  70. # operator len(L) (len [rank] like tuples)
  71. def __len__(self) -> int:
  72. if is_tuple(self.shape):
  73. return len(self.shape)
  74. else:
  75. return 1
  76. # operator () (map coord to idx)
  77. def __call__(self, *args: CoordinateType) -> Self | int:
  78. """
  79. Map a logical coordinate to a linear index (Coord has no Underscore slice operators)
  80. OR
  81. Slice the layout and return the sublayout (Coord has an Underscore slice op)
  82. Follow the same behavior of `Layout::operator(Coord const&)` in cute C++
  83. """
  84. if has_none(args):
  85. if len(args) == 1:
  86. return Layout(slice_(args[0], self.shape), slice_(args[0], self.stride))
  87. else:
  88. return Layout(slice_(args, self.shape), slice_(args, self.stride))
  89. else:
  90. if len(args) == 1:
  91. return crd2idx(args[0], self.shape, self.stride) # type: ignore[arg-type]
  92. else:
  93. return crd2idx(args, self.shape, self.stride) # type: ignore[arg-type]
  94. # operator [] (get-i like tuples)
  95. def __getitem__(self, i: int) -> Self:
  96. if is_tuple(self.shape):
  97. return Layout(self.shape[i], self.stride[i]) # type: ignore[index]
  98. else:
  99. assert i == 0
  100. return Layout(self.shape, self.stride)
  101. # size(layout) Size of the domain
  102. def size(self) -> int:
  103. return product(self.shape)
  104. # cosize(layout) Size of the codomain
  105. def cosize(self) -> int:
  106. return self(self.size() - 1) + 1 # type: ignore[operator]
  107. # print and str
  108. def __str__(self) -> str:
  109. return f"{self.shape}:{self.stride}"
  110. # error msgs and representation
  111. def __repr__(self) -> str:
  112. return f"Layout({self.shape},{self.stride})"
  113. # Type aliases
  114. LayoutOrIntTuple: TypeAlias = Layout | IntTuple
  115. LayoutProfile: TypeAlias = tuple[object, ...] | Layout | None
  116. LayoutInput: TypeAlias = Layout | IntTuple | tuple[object, ...] | None
  117. # Make Layout from a list of layouts (each layout it's own mode in the result)
  118. def make_layout(*layouts: Layout | tuple[Layout, ...]) -> Layout:
  119. if len(layouts) == 1 and not is_layout(layouts[0]):
  120. layouts = layouts[0]
  121. shape, stride = zip(*((a.shape, a.stride) for a in layouts)) # type: ignore[union-attr]
  122. return Layout(shape, stride)
  123. # Size of the domain
  124. def size(layout: LayoutOrIntTuple) -> int:
  125. if is_layout(layout):
  126. return layout.size()
  127. return product(layout)
  128. # Size of the codomain
  129. def cosize(layout: Layout) -> int:
  130. return layout.cosize()
  131. # Layout coalesce -- flatten and combine as many modes as possible while preserving the int-to-int function
  132. def coalesce(layout: Layout, profile: LayoutProfile = None) -> Layout:
  133. if is_tuple(profile):
  134. assert len(layout) >= len(profile)
  135. return make_layout(
  136. # pyrefly: ignore [bad-argument-type]
  137. chain(
  138. (coalesce(layout[i], profile[i]) for i in range(len(profile))), # type: ignore[arg-type]
  139. (layout[i] for i in range(len(profile), len(layout))),
  140. )
  141. )
  142. result_shape = [1]
  143. result_stride = [0]
  144. # Since we now follow lexicographic order, we need to process from right to left.
  145. # And to make implementation more efficient, we append to the end of list and reverse it in the end.
  146. for shape, stride in zip(
  147. reversed(flatten(layout.shape)), reversed(flatten(layout.stride))
  148. ):
  149. # skip their shape-1s
  150. if shape == 1:
  151. continue
  152. # replace our shape-1 with anything
  153. elif result_shape[-1] == 1:
  154. result_shape[-1] = shape
  155. result_stride[-1] = stride
  156. # merge modes if the shape*stride match
  157. elif result_shape[-1] * result_stride[-1] == stride:
  158. result_shape[-1] = result_shape[-1] * shape
  159. # append a new mode
  160. else:
  161. result_shape.append(shape)
  162. result_stride.append(stride)
  163. if len(result_shape) == 1:
  164. return Layout(result_shape[0], result_stride[0])
  165. else:
  166. result_shape.reverse()
  167. result_stride.reverse()
  168. return Layout(tuple(result_shape), tuple(result_stride))
  169. # Layout filter -- replace all stride-0 modes with size-1 and then coalesce to remove them
  170. def filter(layout: Layout, profile: LayoutProfile = None) -> Layout:
  171. if is_tuple(profile):
  172. assert len(layout) >= len(profile)
  173. return make_layout(
  174. # pyrefly: ignore [bad-argument-type]
  175. chain(
  176. (filter(layout[i], profile[i]) for i in range(len(profile))), # type: ignore[arg-type]
  177. (layout[i] for i in range(len(profile), len(layout))),
  178. )
  179. )
  180. result_shape = []
  181. result_stride = []
  182. for shape, stride in zip(flatten(layout.shape), flatten(layout.stride)):
  183. # skip their shape-1s and stride-0s
  184. if not (shape == 1 or stride == 0):
  185. result_shape.append(shape)
  186. result_stride.append(stride)
  187. if len(result_shape) == 0:
  188. return Layout(1, 0)
  189. else:
  190. return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
  191. # Layout composition
  192. # Use tuples-of-layouts to perform this operation by-mode and None as no-op
  193. def composition(layoutA: Layout, layoutB: LayoutInput) -> Layout:
  194. if layoutB is None:
  195. return layoutA
  196. elif is_int(layoutB):
  197. return composition(layoutA, Layout(layoutB))
  198. elif is_tuple(layoutB):
  199. assert len(layoutA) >= len(layoutB)
  200. return make_layout(
  201. # pyrefly: ignore [bad-argument-type]
  202. chain(
  203. (composition(layoutA[i], layoutB[i]) for i in range(len(layoutB))), # type: ignore[arg-type]
  204. (layoutA[i] for i in range(len(layoutB), len(layoutA))),
  205. )
  206. )
  207. elif is_tuple(layoutB.shape):
  208. return make_layout(composition(layoutA, layoutB_i) for layoutB_i in layoutB) # type: ignore[arg-type, attr-defined]
  209. if layoutB.stride == 0:
  210. return Layout(layoutB.shape, 0)
  211. else:
  212. result_shape = []
  213. result_stride = []
  214. rest_shape = layoutB.shape
  215. rest_stride = layoutB.stride
  216. flat_A = coalesce(layoutA)
  217. # when left layout is multi-dimensional sublayout, aka, self = (a,b,...,c):(x,y,...,z), layout = s:d,
  218. # for integral s and d means that we want:
  219. # (1) “remove” the first d elements from left, starting from rightmost. (This will increase the stride.)
  220. # (2) “keep” the first s of those strided elements. (This does not affect the stride.)
  221. # For example, if self = (6,2):(2,1), layout = (3:2)
  222. # Step 1: remove the first 2 elements from self with stride increase, i.e., (6,2):(2,1) -> (6,1):(2,2)
  223. # Step 2: keep the first 3 of those strided elements, i.e., (6,1):(2,2) -> (3,1):(2,2)
  224. # Because we are going lexicographically, we go through left layout from right to left.
  225. for curr_shape, curr_stride in zip(
  226. reversed(flatten(flat_A.shape)[1:]), reversed(flatten(flat_A.stride)[1:])
  227. ):
  228. assert curr_shape % rest_stride == 0 or rest_stride % curr_shape == 0 # type: ignore[operator]
  229. new_shape = min(max(1, curr_shape // rest_stride), rest_shape) # type: ignore[operator]
  230. if new_shape != 1:
  231. result_shape.append(new_shape) # Append to end, will reverse later
  232. result_stride.append(rest_stride * curr_stride)
  233. rest_shape = rest_shape // new_shape # type: ignore[operator]
  234. rest_stride = -(
  235. -rest_stride // curr_shape # type: ignore[operator]
  236. ) # Python exclusive impl: "//" is always floor div so == ceil_div(abs(rest_stride), curr_shape) * signum(rest_stride)
  237. # When left has single-size sublayout or reach the last sublayout, aka, left = a:b, layout = s:d,
  238. # the result is rather trivial: left o layout = a:b o s:d = s:(b*d).
  239. # For example, if self = (6:2), layout = (3:2), the result is (3:(2*2)) = (3:4).
  240. if rest_shape != 1 or len(result_shape) == 0:
  241. result_shape.append(rest_shape) # Append to end, will reverse later
  242. result_stride.append(rest_stride * flatten(flat_A.stride)[0])
  243. # Reverse the lists because we build lists in reverse order (append to end), this way it is more efficient.
  244. result_shape.reverse()
  245. result_stride.reverse()
  246. if len(result_shape) == 1:
  247. return Layout(result_shape[0], result_stride[0]) # type: ignore[arg-type]
  248. else:
  249. return Layout(tuple(result_shape), tuple(result_stride)) # type: ignore[arg-type]
  250. # Layout complement
  251. def complement(layout: LayoutOrIntTuple, max_idx: int = 1) -> Layout:
  252. if is_int(layout):
  253. return complement(Layout(layout))
  254. result_shape = []
  255. result_stride = []
  256. current_idx = 1
  257. sorted_DS = sorted(zip(flatten(layout.stride), flatten(layout.shape))) # type: ignore[union-attr]
  258. for stride, shape in sorted_DS:
  259. if stride == 0 or shape == 1:
  260. continue
  261. in_bound = current_idx <= shape * stride
  262. # To support symbolic value which can't be evaluated now
  263. assert (type(in_bound) is not bool) or in_bound
  264. result_shape.append(stride // current_idx)
  265. result_stride.append(current_idx)
  266. current_idx = shape * stride
  267. result_shape.append((max_idx + current_idx - 1) // current_idx) # ceil_div
  268. result_stride.append(current_idx)
  269. # This is different from original pycute implementation, because we want to follow the lexicographic order here
  270. # where the right-most dimension is the innermost dimension (smallest stride).
  271. result_shape.reverse()
  272. result_stride.reverse()
  273. return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
  274. # Layout right inverse
  275. def right_inverse(layout: LayoutOrIntTuple | None) -> Layout | None:
  276. if layout is None:
  277. return None
  278. elif is_int(layout):
  279. return Layout(layout)
  280. result_shape = []
  281. result_stride = []
  282. current_idx = 1
  283. flat_shape = flatten(layout.shape) # type: ignore[union-attr]
  284. flat_stride = flatten(layout.stride) # type: ignore[union-attr]
  285. sorted_DSA = sorted(zip(flat_stride, flat_shape, suffix_product(flat_shape))) # type: ignore[arg-type]
  286. for stride, shape, rstride in sorted_DSA:
  287. if shape == 1:
  288. continue
  289. if current_idx != stride:
  290. break
  291. result_shape.append(shape)
  292. result_stride.append(rstride)
  293. current_idx = shape * stride
  294. result_shape.reverse()
  295. result_stride.reverse()
  296. return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
  297. # Layout left inverse
  298. def left_inverse(layout: LayoutOrIntTuple | None) -> Layout | None:
  299. if layout is None:
  300. return None
  301. elif is_int(layout):
  302. return Layout(layout)
  303. return right_inverse(make_layout(complement(layout), layout)) # type: ignore[arg-type]
  304. # Split a layout by the composition of B and the "rest"
  305. # Use tuples-of-layouts to perform this operation by-mode and None as no-op
  306. def logical_divide(layoutA: Layout, layoutB: LayoutInput) -> Layout:
  307. if layoutB is None:
  308. return layoutA
  309. elif is_int(layoutB):
  310. return logical_divide(layoutA, Layout(layoutB))
  311. elif is_tuple(layoutB):
  312. assert len(layoutA) >= len(layoutB)
  313. return make_layout(
  314. # pyrefly: ignore [bad-argument-type]
  315. chain(
  316. (
  317. logical_divide(layoutA[i], layoutB[i]) # type: ignore[arg-type]
  318. for i in range(len(layoutB))
  319. ),
  320. (layoutA[i] for i in range(len(layoutB), len(layoutA))),
  321. )
  322. )
  323. return composition(
  324. layoutA,
  325. make_layout(layoutB, complement(layoutB, size(layoutA))),
  326. )
  327. # Reproduce a layoutA over a layoutB
  328. # Use tuples-of-layouts to perform this operation by-mode and None as no-op
  329. def logical_product(layoutA: Layout, layoutB: LayoutInput) -> Layout:
  330. if layoutB is None:
  331. return layoutA
  332. elif is_int(layoutB):
  333. return logical_divide(layoutA, Layout(layoutB))
  334. elif is_tuple(layoutB):
  335. assert len(layoutA) >= len(layoutB)
  336. return make_layout(
  337. # pyrefly: ignore [bad-argument-type]
  338. chain(
  339. (
  340. logical_product(layoutA[i], layoutB[i]) # type: ignore[arg-type]
  341. for i in range(len(layoutB))
  342. ),
  343. (layoutA[i] for i in range(len(layoutB), len(layoutA))),
  344. )
  345. )
  346. return make_layout(
  347. layoutA,
  348. composition(complement(layoutA, size(layoutA) * cosize(layoutB)), layoutB),
  349. )
  350. # Gather the modes from a hierarchical logical_divide or logical_product
  351. def hier_unzip(
  352. splitter: object,
  353. layoutA: Layout,
  354. layoutB: LayoutInput,
  355. ) -> Layout:
  356. if layoutB is None:
  357. return make_layout(Layout(1, 0), layoutA)
  358. elif is_tuple(layoutB):
  359. assert len(layoutA) >= len(layoutB)
  360. # A layout with shape ((A,a),(B,b),(C,c))
  361. split = make_layout(
  362. hier_unzip(splitter, layoutA[i], layoutB[i]) # type: ignore[arg-type]
  363. for i in range(len(layoutB))
  364. )
  365. # Gather to shape ((A,B,C,...),(a,b,c,...,y,z))
  366. return make_layout(
  367. make_layout(split[i][0] for i in range(len(layoutB))), # type: ignore[arg-type]
  368. make_layout(
  369. chain( # type: ignore[arg-type]
  370. (split[i][1] for i in range(len(layoutB))),
  371. (layoutA[i] for i in range(len(layoutB), len(layoutA))),
  372. )
  373. ),
  374. )
  375. # splitter must return a rank-2 layout
  376. return splitter(layoutA, layoutB) # type: ignore[operator]
  377. # Apply logical divide hierarchically and gather the split modes into two modes
  378. def zipped_divide(layoutA: Layout, layoutB: LayoutInput) -> Layout:
  379. return hier_unzip(logical_divide, layoutA, layoutB)
  380. # Perform logical divide hierarchically and gather tiles (B-layouts) into a new mode
  381. def tiled_divide(layoutA: Layout, layoutB: LayoutInput) -> Layout:
  382. result = zipped_divide(layoutA, layoutB)
  383. return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))]) # type: ignore[arg-type]
  384. # Apply logical product hierarchically and gather the split modes into two modes
  385. def zipped_product(layoutA: Layout, layoutB: LayoutInput) -> Layout:
  386. return hier_unzip(logical_product, layoutA, layoutB)
  387. # Perform logical product hierarchically and gather tiles (B-layouts) into a new mode
  388. def tiled_product(layoutA: Layout, layoutB: LayoutInput) -> Layout:
  389. result = zipped_product(layoutA, layoutB)
  390. return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))]) # type: ignore[arg-type]
  391. def slice_and_offset(crd: tuple[object, ...], layout: Layout) -> tuple[Layout, int]:
  392. return (
  393. Layout(slice_(crd, layout.shape), slice_(crd, layout.stride)),
  394. crd2idx(crd, layout.shape, layout.stride), # type: ignore[arg-type]
  395. )