_triton_ops.py 89 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626
  1. # mypy: allow-untyped-decorators
  2. # mypy: allow-untyped-defs
  3. import math
  4. import os
  5. import weakref
  6. from functools import lru_cache
  7. import torch
  8. from torch._dynamo.utils import warn_once
  9. from torch.utils._triton import has_triton
  10. from ._triton_ops_meta import get_meta
  11. TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE = int(
  12. os.getenv("TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE", 2)
  13. )
  14. def check(cond, msg):
  15. if not cond:
  16. raise ValueError(msg)
  17. def check_bsr_layout(f_name, t):
  18. check(
  19. t.layout == torch.sparse_bsr,
  20. f"{f_name}(): only BSR sparse format is supported for the sparse argument.",
  21. )
  22. def check_device(f_name, t, device):
  23. check(
  24. t.device == device and t.device.type == "cuda",
  25. f"{f_name}(): all inputs are expected to be on the same GPU device.",
  26. )
  27. def check_mm_compatible_shapes(f_name, lhs, rhs):
  28. check(
  29. lhs.dim() >= 2 and rhs.dim() >= 2,
  30. f"{f_name}(): all inputs involved in the matrix product are expected to be at least 2D, "
  31. f"but got lhs.dim() == {lhs.dim()} and rhs.dim() == {rhs.dim()}.",
  32. )
  33. _m, kl = lhs.shape[-2:]
  34. kr, _n = rhs.shape[-2:]
  35. check(
  36. kl == kr,
  37. f"{f_name}(): arguments' sizes involved in the matrix product are not compatible for matrix multiplication, "
  38. f"got lhs.shape[-1] == {kl} which is not equal to rhs.shape[-2] == {kr}.",
  39. )
  40. def check_dtype(f_name, t, dtype, *additional_dtypes):
  41. check(
  42. t.dtype == dtype
  43. and t.dtype
  44. in ((torch.half, torch.bfloat16, torch.float) + tuple(*additional_dtypes)),
  45. f"{f_name}(): all inputs are expected to be of the same dtype "
  46. f"and one of (half, bfloat16, float32) or {additional_dtypes}, "
  47. f"but got dtype == {t.dtype}.",
  48. )
  49. def check_blocksize(f_name, blocksize):
  50. if len(blocksize) != 2:
  51. raise AssertionError(f"blocksize must have length 2, got {len(blocksize)}")
  52. def is_power_of_two(v):
  53. return not (v & (v - 1))
  54. def is_compatible_blocksize(b):
  55. res = True
  56. for blocksize in b:
  57. # Triton loads only blocks which are at least 16 and powers of 2.
  58. res = (blocksize >= 16 and is_power_of_two(blocksize)) and res
  59. return res
  60. check(
  61. is_compatible_blocksize(blocksize),
  62. f"{f_name}(): sparse inputs' blocksize ({blocksize[0]}, {blocksize[1]}) "
  63. "should be at least 16 and a power of 2 in each dimension.",
  64. )
  65. def make_triton_contiguous(t):
  66. """Return input as a triton-contiguous tensor.
  67. A triton-contiguous tensor is defined as a tensor that has strides
  68. with minimal value smaller than or equal to 1.
  69. While triton kernels support triton-non-contiguous tensors (all
  70. strides being greater than 1) arguments, a considerable slow-down
  71. occurs because tensor data is copied element-wise rather than
  72. chunk-wise. Zero strides is assumed to not have this defect.
  73. """
  74. if min(t.stride()) > 1:
  75. # TODO: investigate if contiguity along other axes than the
  76. # last one can be beneficial for performance
  77. return t.contiguous()
  78. else:
  79. return t
  80. def broadcast_batch_dims(f_name, *tensors):
  81. try:
  82. return torch.broadcast_shapes(*(t.shape[:-2] for t in tensors))
  83. except Exception:
  84. check(False, f"{f_name}(): inputs' batch dimensions are not broadcastable!")
  85. def slicer(dim, slice_range, *tensors):
  86. for t in tensors:
  87. slices = [slice(None)] * t.dim()
  88. slices[dim] = slice_range
  89. yield t[slices]
  90. def multidim_slicer(dims, slices, *tensors):
  91. for t in tensors:
  92. s = [slice(None)] * t.dim()
  93. for d, d_slice in zip(dims, slices, strict=False):
  94. if d is not None:
  95. s[d] = d_slice
  96. yield t[tuple(s)]
  97. def ptr_stride_extractor(*tensors):
  98. for t in tensors:
  99. yield t
  100. yield from t.stride()
  101. def grid_partitioner(full_grid, grid_blocks, tensor_dims_map):
  102. if len(full_grid) < 0 or len(full_grid) > 3:
  103. raise AssertionError(f"full_grid length must be 0-3, got {len(full_grid)}")
  104. if len(grid_blocks) < 0 or len(grid_blocks) > 3:
  105. raise AssertionError(f"grid_blocks length must be 0-3, got {len(grid_blocks)}")
  106. import itertools
  107. def generate_grid_points():
  108. for fg, mg in zip(full_grid, grid_blocks, strict=False):
  109. yield range(0, fg, mg)
  110. def generate_sliced_tensors(slices):
  111. for t, t_dims in tensor_dims_map.items():
  112. yield next(multidim_slicer(t_dims, slices, t))
  113. for grid_point in itertools.product(*generate_grid_points()):
  114. grid = [
  115. min(fg - gp, mg)
  116. for fg, gp, mg in zip(full_grid, grid_point, grid_blocks, strict=False)
  117. ]
  118. slices = [slice(gp, gp + g) for gp, g in zip(grid_point, grid, strict=False)]
  119. # grid_points are iterated in a "contiguous" order, i.e.
  120. # left dimensions traversed slower than right dimensions.
  121. # This order is reversed for CUDA grids.
  122. yield grid[::-1], *generate_sliced_tensors(slices)
  123. def launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks=None):
  124. # cuda_max_grid = (2 ** 31 - 1, 2 ** 16 - 1, 2 ** 16 - 1)
  125. cuda_max_grid = (2147483647, 65535, 65535)[::-1]
  126. if grid_blocks is None:
  127. grid_blocks = cuda_max_grid
  128. else:
  129. def valid_grid_dim(g, mg):
  130. if g is None:
  131. return mg
  132. else:
  133. # grid must be at least 1 and no greater than mg
  134. return max(1, min(g, mg))
  135. grid_blocks = tuple(
  136. valid_grid_dim(g, mg)
  137. for g, mg in zip(grid_blocks, cuda_max_grid, strict=False)
  138. ) # type: ignore[assignment]
  139. for grid, *sliced_tensors in grid_partitioner(
  140. full_grid, grid_blocks, tensor_dims_map
  141. ):
  142. kernel(grid, *sliced_tensors)
  143. def prepare_inputs(bsr, *dense_tensors):
  144. # Introduce fake batch dimension if not present for convenience.
  145. crow_indices = bsr.crow_indices().unsqueeze(0)
  146. col_indices = bsr.col_indices().unsqueeze(0)
  147. values = make_triton_contiguous(bsr.values().unsqueeze(0))
  148. tensors = [make_triton_contiguous(t.unsqueeze(0)) for t in dense_tensors]
  149. # Compute broadcasted batch dimension
  150. batch_dims_broadcasted = torch.broadcast_shapes(
  151. values.shape[:-3], *(t.shape[:-2] for t in tensors)
  152. )
  153. # Broadcast batch dimensions and squash.
  154. # The result can be either a view or a copy.
  155. def batch_broadcast_and_squash(t, batch_dims, invariant_dims):
  156. return t.broadcast_to(batch_dims + invariant_dims).flatten(
  157. 0, len(batch_dims) - 1
  158. )
  159. crow_indices = batch_broadcast_and_squash(
  160. crow_indices, batch_dims_broadcasted, (-1,)
  161. )
  162. col_indices = batch_broadcast_and_squash(col_indices, batch_dims_broadcasted, (-1,))
  163. values = batch_broadcast_and_squash(
  164. values, batch_dims_broadcasted, values.shape[-3:]
  165. )
  166. tensors = [
  167. batch_broadcast_and_squash(t, batch_dims_broadcasted, t.shape[-2:])
  168. for t in tensors
  169. ]
  170. return crow_indices, col_indices, values, *tensors
  171. def broadcast_batch_dims_bsr(f_name, bsr, *tensors):
  172. batch_shape = broadcast_batch_dims(f_name, bsr, *tensors)
  173. crow_indices = bsr.crow_indices().broadcast_to(batch_shape + (-1,))
  174. col_indices = bsr.col_indices().broadcast_to(batch_shape + (-1,))
  175. values = bsr.values().broadcast_to(batch_shape + bsr.values().shape[-3:])
  176. size = batch_shape + bsr.shape[-2:]
  177. return torch.sparse_compressed_tensor(
  178. crow_indices, col_indices, values, size=size, layout=bsr.layout
  179. )
  180. # NOTE: this function will ALWAYS create a view
  181. def tile_to_blocksize(t, blocksize):
  182. *rest, m, n = t.shape
  183. new_shape = rest + [
  184. m // blocksize[0],
  185. blocksize[0],
  186. n // blocksize[1],
  187. blocksize[1],
  188. ]
  189. # using .view instead of .reshape to ensure that the result is
  190. # indeed a view:
  191. return t.view(new_shape).transpose(-3, -2)
  192. def as1Dbatch(tensor):
  193. """Return tensor as 3D tensor by either prepending new dimensions to
  194. the tensor shape (when ``tensor.ndim < 3``), or by collapsing
  195. starting dimensions into the first dimension (when ``tensor.ndim >
  196. 3``).
  197. """
  198. while tensor.ndim < 3:
  199. tensor = tensor.unsqueeze(0)
  200. if tensor.ndim > 3:
  201. tensor = tensor.flatten(0, tensor.ndim - 3)
  202. if tensor.ndim != 3:
  203. raise AssertionError(
  204. f"tensor should have 3 dimensions after reshape, got {tensor.shape}"
  205. )
  206. return tensor
def scatter_mm(blocks, others, indices_data, *, accumulators=None):
    r"""Scattered matrix multiplication of tensors.

    A scattered matrix multiplication is defined as a series of matrix
    multiplications applied to input tensors according to the input
    and output mappings specified by indices data.

    The following indices data formats are supported for defining a
    scattered matrix multiplication operation (:attr:`indices_data[0]`
    holds the name of the indices data format as specified below):

    - ``"scatter_mm"`` - matrix multiplications scattered in batches
      of tensors.

      If :attr:`blocks` is a :math:`(* \times M \times K)` tensor,
      :attr:`others` is a :math:`(* \times K \times N)` tensor,
      :attr:`accumulators` is a :math:`(* \times M \times N)` tensor,
      and :attr:`indices = indices_data['indices']` is a :math:`(*
      \times 3)` tensor, then the operation is equivalent to the
      following code::

        c_offsets, pq = indices_data[1:]
        for r in range(len(c_offsets) - 1):
            for g in range(c_offsets[r], c_offsets[r + 1]):
                p, q = pq[g]
                accumulators[r] += blocks[p] @ others[q]

    - ``"bsr_strided_mm"`` - matrix multiplications scattered in
      batches of tensors and a tensor.

      If :attr:`blocks` is a :math:`(Ms \times Ks)` tensor,
      :attr:`others` is a :math:`(* \times K \times N)` tensor,
      :attr:`accumulators` is a :math:`(* \times M \times N)` tensor, then
      the operation is equivalent to the following code::

        c_indices, r_offsets, p_offsets, q_offsets, meta = indices_data[1:]
        for b in range(nbatches):
            for i, r in enumerate(r_offsets):
                r0, r1 = divmod(r, N)
                acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns]
                for g in range(c_indices[i], c_indices[i + 1]):
                    p = p_offsets[g]
                    q0, q1 = divmod(q_offsets[g], N)
                    acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns]

      where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are
      integer multiples of ``Ms`` and ``Ks``, respectively.

    - ``"bsr_strided_mm_compressed"`` - matrix multiplications
      scattered in batches of tensors and a tensor. A memory and
      processor efficient version of ``"bsr_strided_mm"`` format.  If
      :attr:`blocks` is a :math:`(Ms \times Ks)` tensor, :attr:`others`
      is a :math:`(* \times K \times N)` tensor, :attr:`accumulators`
      is a :math:`(* \times M \times N)` tensor, then the operation is
      equivalent to the following code::

        c_indices, r_offsets, q_offsets, meta = indices_data[1:]
        for b in range(nbatches):
            for r in r_offsets:
                m = (r // N) // Ms
                n = (r % N) // Ns
                r0, r1 = divmod(r, N)
                c0, c1 = c_indices[m], c_indices[m + 1]
                acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns]
                for i, p in enumerate(range(c0, c1)):
                    q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i]
                    q0, q1 = divmod(q, N)
                    acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns]

      where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are
      integer multiples of ``Ms`` and ``Ks``, respectively.

      Notice that the order of ``r_offsets`` items can be arbitrary;
      this property enables defining swizzle operators via
      rearrangements of ``r_offsets`` items.

    Auxiliary functions are provided for pre-computing
    :attr:`indices_data`. For example,
    :func:`bsr_scatter_mm_indices_data` is used to define indices data
    for matrix multiplication of BSR and strided tensors.

    Parameters
    ----------
    blocks (Tensor): a 3-D tensor of first matrices to be multiplied

    others (Tensor): a tensor of second matrices to be multiplied. If
      ``indices_data[0]=="scatter_mm"``, the tensor is a 1-D batch
      tensor of second input matrices to be multiplied. Otherwise, the
      second input matrices are slices of the :attr:`others` tensor.
    indices_data (tuple): a format data that defines the inputs and
      outputs of scattered matrix multiplications.

    Keyword arguments
    -----------------

    accumulators (Tensor, optional): a tensor of matrix product
      accumulators. If ``indices_data[0]=="scatter_mm"``, the tensor
      is a 1-D batch tensor of output matrices. Otherwise, output
      matrices are slices of the :attr:`accumulators` tensor.

    Returns
    -------
    The accumulators tensor. When :attr:`accumulators` is not given, a
    zero-initialized tensor with the dtype and device of :attr:`blocks`
    is created and returned.

    Raises
    ------
    AssertionError: if :attr:`blocks` is not 3-D or the shapes of the
      inputs are inconsistent with each other.
    NotImplementedError: if ``indices_data[0]`` is not one of the
      formats listed above.
    """
    indices_format = indices_data[0]

    if blocks.ndim != 3:
        raise AssertionError(f"blocks must be 3D, got {blocks.ndim}D")
    _P, Ms, Ks = blocks.shape

    if indices_format == "scatter_mm":
        c_offsets, pq = indices_data[1:]

        if others.ndim != 3:
            raise AssertionError(f"others must be 3D, got {others.ndim}D")
        _Q, Ks_, Ns = others.shape
        if Ks != Ks_:
            raise AssertionError(f"blocks K ({Ks}) != others K ({Ks_})")

        if accumulators is None:
            # One output matrix per consecutive pair of c_offsets entries.
            R = c_offsets.shape[0] - 1
            accumulators = torch.zeros(
                (R, Ms, Ns), dtype=blocks.dtype, device=blocks.device
            )
        else:
            R, Ms_, Ns_ = accumulators.shape
            if Ms_ != Ms:
                raise AssertionError(f"accumulators Ms ({Ms_}) != blocks Ms ({Ms})")
            if Ns_ != Ns:
                raise AssertionError(f"accumulators Ns ({Ns_}) != others Ns ({Ns})")

        # Plain PyTorch path, taken when a block dimension is not a
        # multiple of 16 or when _scatter_mm2 is None.
        # NOTE(review): _scatter_mm2 is defined outside this part of the
        # file; presumably it is None when Triton is unavailable.
        if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm2 is None:
            for r in range(c_offsets.shape[0] - 1):
                g0 = c_offsets[r]
                g1 = c_offsets[r + 1]
                for g in range(g0, g1):
                    p, q = pq[g]

                    accumulators[r] += blocks[p] @ others[q]
        else:
            _scatter_mm2(blocks, others, c_offsets, pq, accumulators)
        return accumulators

    elif indices_format == "bsr_strided_mm":
        # Remember the caller's shape; the computation runs on a 3-D view.
        others_shape = others.shape
        others = as1Dbatch(others)

        B, K, N = others.shape
        if K % Ks != 0:
            raise AssertionError(f"K ({K}) must be divisible by Ks ({Ks})")

        c_indices, r_offsets, p_offsets, q_offsets, meta = indices_data[1:]
        SPLIT_N = meta["SPLIT_N"]

        if accumulators is None:
            # Number of output rows derived from the largest row offset.
            M = Ms + (r_offsets.max().item() + 1) // N
            accumulators = torch.zeros(
                (*others_shape[:-2], M, N), dtype=blocks.dtype, device=blocks.device
            )
        else:
            M, N_ = accumulators.shape[-2:]
            if N_ != N:
                raise AssertionError(f"accumulators N ({N_}) != others N ({N})")

        accumulators_shape = accumulators.shape
        accumulators = as1Dbatch(accumulators)

        Ns = N // SPLIT_N

        # Plain PyTorch path, taken when a block dimension is not a
        # multiple of 16 or when _scatter_mm6 is None (defined outside
        # this part of the file).
        if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm6 is None:
            # Clears the accumulators in place before accumulating; this
            # includes caller-supplied accumulators.
            # NOTE(review): the "bsr_strided_mm_compressed" branch below
            # does not do this -- confirm which behaviour is intended.
            accumulators.zero_()
            for b in range(B):
                for r in range(r_offsets.shape[0]):
                    r_ = r_offsets[r].item()
                    g0 = c_indices[r].item()
                    g1 = c_indices[r + 1].item()
                    # r_ is a flat offset into an (M, N) output matrix.
                    r0, r1 = divmod(r_, N)
                    acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns]
                    for g in range(g0, g1):
                        p, q = p_offsets[g], q_offsets[g]
                        # q is a flat offset into a (K, N) matrix of others.
                        q0, q1 = divmod(q.item(), N)
                        acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns]
        else:
            _scatter_mm6(
                blocks,
                others,
                c_indices,
                r_offsets,
                p_offsets,
                q_offsets,
                meta,
                accumulators,
            )
        # Restore the shape the accumulators had before as1Dbatch.
        return accumulators.view(accumulators_shape)

    elif indices_format == "bsr_strided_mm_compressed":
        # Remember the caller's shape; the computation runs on a 3-D view.
        others_shape = others.shape
        others = as1Dbatch(others)

        B, K, N = others.shape
        if K % Ks != 0:
            raise AssertionError(f"K ({K}) must be divisible by Ks ({Ks})")

        c_indices, r_offsets, q_offsets, meta = indices_data[1:]
        SPLIT_N = meta["SPLIT_N"]

        if accumulators is None:
            # Number of output rows derived from the largest row offset.
            M = Ms + (r_offsets.max().item() + 1) // N
            accumulators = torch.zeros(
                (*others_shape[:-2], M, N), dtype=blocks.dtype, device=blocks.device
            )
        else:
            M, N_ = accumulators.shape[-2:]
            if N_ != N:
                raise AssertionError(f"accumulators N ({N_}) != others N ({N})")

        accumulators_shape = accumulators.shape
        accumulators = as1Dbatch(accumulators)

        Ns = N // SPLIT_N

        # Plain PyTorch path, taken when a block dimension is not a
        # multiple of 16 or when _scatter_mm6 is None (defined outside
        # this part of the file).
        if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm6 is None:
            for b in range(B):
                for j in range(len(r_offsets)):
                    # The flat output offset gives the row/column start.
                    r0, r1 = divmod(r_offsets[j].item(), N)
                    m = r0 // Ms
                    n = r1 // Ns
                    c0 = c_indices[m].item()
                    c1 = c_indices[m + 1].item()
                    acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns]
                    for i, p in enumerate(range(c0, c1)):
                        q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i].item()
                        q0, q1 = divmod(q, N)
                        acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns]
        else:
            # This format carries no p_offsets; an empty tensor is passed
            # in the p_offsets position of _scatter_mm6.
            p_offsets = torch.empty(
                (0,), dtype=q_offsets.dtype, device=q_offsets.device
            )
            _scatter_mm6(
                blocks,
                others,
                c_indices,
                r_offsets,
                p_offsets,
                q_offsets,
                meta,
                accumulators,
            )
        # Restore the shape the accumulators had before as1Dbatch.
        return accumulators.view(accumulators_shape)

    else:
        raise NotImplementedError(indices_format)
  416. def scatter_mm_meta(
  417. M,
  418. K,
  419. N,
  420. Ms,
  421. Ks,
  422. GROUP_SIZE=None,
  423. TILE_M=None,
  424. TILE_N=None,
  425. SPLIT_N=None,
  426. num_warps=None,
  427. num_stages=None,
  428. **extra,
  429. ):
  430. if {TILE_M, TILE_N, SPLIT_N, num_warps, num_stages, GROUP_SIZE} == {None}:
  431. device_name = torch.cuda.get_device_name()
  432. meta = get_meta(
  433. "scatter_mm",
  434. (M, K, N, Ms, Ks),
  435. device_name,
  436. version=(0, torch.float16, 0.5),
  437. )
  438. if meta is not None:
  439. meta.update(**extra)
  440. return meta
  441. # The following parameters are optimized for the performance
  442. # equilibrium points of bsr-dense and dense-dense matrix
  443. # multiplications when using GPU card NVIDIA GeForce RTX 2060
  444. # SUPER. For points far from the performance equilibrium
  445. # points as well as for other GPU cards, the optimal
  446. # parameters are likely different from what specified below.
  447. if (M, K, N) == (256,) * 3:
  448. if (Ms, Ks) == (16, 16):
  449. SPLIT_N = 1
  450. TILE_M = 16
  451. TILE_N = 16
  452. GROUP_SIZE = 4
  453. num_stages = 1
  454. num_warps = 4 # noqa: E225,E231,E702
  455. elif (Ms, Ks) == (32, 32):
  456. SPLIT_N = 2
  457. TILE_M = 32
  458. TILE_N = 16
  459. GROUP_SIZE = 4
  460. num_stages = 1
  461. num_warps = 4 # noqa: E225,E231,E702
  462. elif (Ms, Ks) == (64, 64):
  463. SPLIT_N = 1
  464. TILE_M = 32
  465. TILE_N = 32
  466. GROUP_SIZE = 4
  467. num_stages = 1
  468. num_warps = 4 # noqa: E225,E231,E702
  469. elif (Ms, Ks) == (128, 128):
  470. SPLIT_N = 1
  471. TILE_M = 32
  472. TILE_N = 32
  473. GROUP_SIZE = 2
  474. num_stages = 1
  475. num_warps = 4 # noqa: E225,E231,E702
  476. elif (M, K, N) == (512,) * 3:
  477. if (Ms, Ks) == (16, 16):
  478. SPLIT_N = 8
  479. TILE_M = 16
  480. TILE_N = 64
  481. GROUP_SIZE = 2
  482. num_stages = 1
  483. num_warps = 2 # noqa: E225,E231,E702
  484. elif (Ms, Ks) == (32, 32):
  485. SPLIT_N = 8
  486. TILE_M = 32
  487. TILE_N = 64
  488. GROUP_SIZE = 4
  489. num_stages = 1
  490. num_warps = 2 # noqa: E225,E231,E702
  491. elif (Ms, Ks) == (64, 64):
  492. SPLIT_N = 4
  493. TILE_M = 32
  494. TILE_N = 128
  495. GROUP_SIZE = 4
  496. num_stages = 1
  497. num_warps = 4 # noqa: E225,E231,E702
  498. elif (Ms, Ks) == (128, 128):
  499. SPLIT_N = 8
  500. TILE_M = 64
  501. TILE_N = 64
  502. GROUP_SIZE = 4
  503. num_stages = 1
  504. num_warps = 4 # noqa: E225,E231,E702
  505. elif (M, K, N) == (1024,) * 3:
  506. if (Ms, Ks) == (16, 16):
  507. SPLIT_N = 4
  508. TILE_M = 16
  509. TILE_N = 128
  510. GROUP_SIZE = 2
  511. num_stages = 1
  512. num_warps = 1 # noqa: E225,E231,E702
  513. elif (Ms, Ks) == (32, 32):
  514. SPLIT_N = 8
  515. TILE_M = 32
  516. TILE_N = 64
  517. GROUP_SIZE = 2
  518. num_stages = 1
  519. num_warps = 1 # noqa: E225,E231,E702
  520. elif (Ms, Ks) == (64, 64):
  521. SPLIT_N = 16
  522. TILE_M = 64
  523. TILE_N = 64
  524. GROUP_SIZE = 4
  525. num_stages = 1
  526. num_warps = 2 # noqa: E225,E231,E702
  527. elif (Ms, Ks) == (128, 128):
  528. SPLIT_N = 16
  529. TILE_M = 64
  530. TILE_N = 64
  531. GROUP_SIZE = 4
  532. num_stages = 1
  533. num_warps = 4 # noqa: E225,E231,E702
  534. elif (Ms, Ks) == (256, 256):
  535. SPLIT_N = 16
  536. TILE_M = 64
  537. TILE_N = 64
  538. GROUP_SIZE = 2
  539. num_stages = 1
  540. num_warps = 4 # noqa: E225,E231,E702
  541. elif (M, K, N) == (2048,) * 3:
  542. if (Ms, Ks) == (16, 16):
  543. SPLIT_N = 4
  544. TILE_M = 16
  545. TILE_N = 128
  546. GROUP_SIZE = 8
  547. num_stages = 1
  548. num_warps = 1 # noqa: E225,E231,E702
  549. elif (Ms, Ks) == (32, 32):
  550. SPLIT_N = 4
  551. TILE_M = 32
  552. TILE_N = 64
  553. GROUP_SIZE = 4
  554. num_stages = 1
  555. num_warps = 1 # noqa: E225,E231,E702
  556. elif (Ms, Ks) == (64, 64):
  557. SPLIT_N = 4
  558. TILE_M = 64
  559. TILE_N = 128
  560. GROUP_SIZE = 4
  561. num_stages = 1
  562. num_warps = 4 # noqa: E225,E231,E702
  563. elif (Ms, Ks) == (128, 128):
  564. SPLIT_N = 8
  565. TILE_M = 64
  566. TILE_N = 64
  567. GROUP_SIZE = 4
  568. num_stages = 1
  569. num_warps = 4 # noqa: E225,E231,E702
  570. elif (Ms, Ks) == (256, 256):
  571. SPLIT_N = 4
  572. TILE_M = 64
  573. TILE_N = 64
  574. GROUP_SIZE = 2
  575. num_stages = 1
  576. num_warps = 4 # noqa: E225,E231,E702
  577. elif (M, K, N) == (4096,) * 3:
  578. if (Ms, Ks) == (16, 16):
  579. SPLIT_N = 2
  580. TILE_M = 16
  581. TILE_N = 256
  582. GROUP_SIZE = 2
  583. num_stages = 1
  584. num_warps = 2 # noqa: E225,E231,E702
  585. elif (Ms, Ks) == (32, 32):
  586. SPLIT_N = 2
  587. TILE_M = 32
  588. TILE_N = 64
  589. GROUP_SIZE = 2
  590. num_stages = 1
  591. num_warps = 1 # noqa: E225,E231,E702
  592. elif (Ms, Ks) == (64, 64):
  593. SPLIT_N = 2
  594. TILE_M = 64
  595. TILE_N = 128
  596. GROUP_SIZE = 2
  597. num_stages = 1
  598. num_warps = 4 # noqa: E225,E231,E702
  599. if SPLIT_N is None:
  600. # Assume NVIDIA GeForce RTX 2060 SUPER:
  601. # With the probality of 92% (99.9% when N > 512), the
  602. # performance will not be worse more than 2% from the
  603. # performance when using an optimal value. Otherwise, when N
  604. # <= 512, using the following heuristics may give upto 15%
  605. # lower performance.
  606. SPLIT_N = {
  607. 16: 1,
  608. 32: 2,
  609. 64: 4,
  610. 128: 8,
  611. 256: 16,
  612. 512: 8,
  613. 1024: 16,
  614. 4096: 32,
  615. 8192: 64,
  616. }.get(N, 16)
  617. if Ms >= 512 and N >= 2048:
  618. SPLIT_N = 1
  619. Ns = N // SPLIT_N
  620. if TILE_M is None:
  621. TILE_M = min(64 if Ns < 512 else 32, Ms)
  622. if TILE_N is None:
  623. TILE_N = min(64 if Ns < 512 else 32, Ns)
  624. num_stages = num_stages or 1
  625. if num_warps is None:
  626. if min(M, N) > 1024:
  627. num_warps = {16: 1, 32: 1, 64: 2}.get(Ms, 4)
  628. elif min(M, N) == 1024:
  629. num_warps = {16: 1, 32: 1, 64: 2}.get(Ms, 4)
  630. elif min(M, N) == 256:
  631. num_warps = {16: 1, 32: 4}.get(Ms, 4)
  632. else:
  633. num_warps = {16: 1, 32: 2}.get(Ms, 4)
  634. GROUP_SIZE = GROUP_SIZE or 4
  635. if TILE_M > Ms:
  636. raise AssertionError(f"TILE_M ({TILE_M}) must be <= Ms ({Ms})")
  637. if TILE_N > Ns:
  638. raise AssertionError(f"TILE_N ({TILE_N}) must be <= Ns ({Ns})")
  639. if Ms > M:
  640. raise AssertionError(f"Ms ({Ms}) must be <= M ({M})")
  641. if Ns > N:
  642. raise AssertionError(f"Ns ({Ns}) must be <= N ({N})")
  643. if Ks > K:
  644. raise AssertionError(f"Ks ({Ks}) must be <= K ({K})")
  645. return dict(
  646. TILE_M=TILE_M,
  647. TILE_N=TILE_N,
  648. GROUP_SIZE=GROUP_SIZE,
  649. num_stages=num_stages,
  650. num_warps=num_warps,
  651. SPLIT_N=SPLIT_N,
  652. **extra,
  653. )
  654. def bsr_dense_addmm_meta(
  655. M,
  656. K,
  657. N,
  658. Ms,
  659. Ks,
  660. beta,
  661. alpha,
  662. SPLIT_N=None,
  663. GROUP_SIZE_ROW=None,
  664. num_warps=None,
  665. num_stages=None,
  666. sparsity=None,
  667. dtype=None,
  668. out_dtype=None,
  669. _version=0,
  670. **extra,
  671. ):
  672. # Specifying _version is useful for situations when one wants to
  673. # discard existing triton kernel tuning results, say, in testing
  674. # bsr_dense_addmm_meta functionality.
  675. if dtype is None:
  676. dtype = torch.float16
  677. if out_dtype is None:
  678. out_dtype = dtype
  679. if sparsity is None:
  680. sparsity = 0.5
  681. if {SPLIT_N, num_warps, num_stages, GROUP_SIZE_ROW} == {None}:
  682. device_name = torch.cuda.get_device_name()
  683. key = (M, K, N, Ms, Ks, beta == 0, beta == 1, alpha == 1)
  684. if dtype is out_dtype:
  685. version_dtype = dtype
  686. else:
  687. version_dtype = dtype, out_dtype
  688. meta = get_meta(
  689. "bsr_dense_addmm",
  690. key,
  691. device_name,
  692. version=(_version, version_dtype, sparsity),
  693. )
  694. if meta is None and sparsity != 0.5:
  695. meta = get_meta(
  696. "bsr_dense_addmm",
  697. key,
  698. device_name,
  699. version=(_version, version_dtype, 0.5),
  700. )
  701. if meta is None and dtype is not out_dtype:
  702. meta = get_meta(
  703. "bsr_dense_addmm", key, device_name, version=(_version, dtype, 0.5)
  704. )
  705. if meta is None:
  706. # find approximate meta such that N % SPLIT_N == 0.
  707. matching_meta = get_meta(
  708. "bsr_dense_addmm",
  709. (*key[:2], "*", *key[3:]),
  710. device_name,
  711. version=(_version, version_dtype, 0.5),
  712. )
  713. if matching_meta is None and dtype is not out_dtype:
  714. matching_meta = get_meta(
  715. "bsr_dense_addmm",
  716. (*key[:2], "*", *key[3:]),
  717. device_name,
  718. version=(_version, dtype, 0.5),
  719. )
  720. for mkey in sorted(matching_meta or {}):
  721. meta_ = matching_meta[mkey]
  722. n = mkey[2]
  723. split_n = meta_["SPLIT_N"]
  724. c = n // split_n
  725. if N % c == 0 and n <= N:
  726. meta = dict(meta_)
  727. meta["SPLIT_N"] = N // c
  728. if meta is not None:
  729. meta.update(**extra)
  730. return meta
  731. else:
  732. # see [Computing optimal kernel parameters] in
  733. # _triton_ops_meta.py for ways to avoid this warning
  734. # message
  735. warn_once(
  736. "bsr_dense_addmm uses non-optimal triton kernel parameters"
  737. f" for {M=} {K=} {N=} {Ms=}, {Ks=} {beta=} {alpha=} {dtype=} {out_dtype=}"
  738. )
  739. SPLIT_N = SPLIT_N or max(N // Ms, 1)
  740. GROUP_SIZE_ROW = GROUP_SIZE_ROW or 4
  741. num_stages = num_stages or 1
  742. num_warps = num_warps or 4
  743. return dict(
  744. SPLIT_N=SPLIT_N,
  745. GROUP_SIZE_ROW=GROUP_SIZE_ROW,
  746. num_stages=num_stages,
  747. num_warps=num_warps,
  748. **extra,
  749. )
  750. class TensorAsKey:
  751. """A light-weight wrapper of a tensor that enables storing tensors as
  752. keys with efficient memory reference based comparison as an
  753. approximation to data equality based keys.
  754. Motivation: the hash value of a torch tensor is tensor instance
  755. based that does not use data equality and makes the usage of
  756. tensors as keys less useful. For instance, the result of
  757. ``len({a.crow_indices(), a.crow_indices()})`` is `2`, although,
  758. the tensor results from `crow_indices` method call are equal, in
  759. fact, these share the same data storage.
  760. On the other hand, for efficient caching of tensors we want to
  761. avoid calling torch.equal that compares tensors item-wise.
  762. TensorAsKey offers a compromise in that it guarantees key equality
  763. of tensors that references data in the same storage in the same
  764. manner and without accessing underlying data. However, this
  765. approach does not always guarantee correctness. For instance, for
  766. a complex tensor ``x``, we have ``TensorAsKey(x) ==
  767. TensorAsKey(x.conj())`` while ``torch.equal(x, x.conj())`` would
  768. return False.
  769. """
  770. def __init__(self, obj):
  771. def get_tensor_key(obj):
  772. # Warning: TensorAsKey does not track negative nor
  773. # conjugate bits of its input object because in the use
  774. # case of wrapping compressed/plain indices of compressed
  775. # sparse tensors (that are always integer tensors with
  776. # non-negative items) these bits are never set. However,
  777. # when extending the use of TensorAsKey to float or
  778. # complex tensors, the values of these bits (see is_neg
  779. # and is_conj methods) must be included in the key as
  780. # well.
  781. if obj.dtype.is_floating_point or obj.dtype.is_complex:
  782. raise AssertionError(
  783. f"TensorAsKey does not support floating point or complex dtype: {obj.dtype}"
  784. )
  785. return (
  786. obj.data_ptr(),
  787. obj.storage_offset(),
  788. obj.shape,
  789. obj.stride(),
  790. obj.dtype,
  791. )
  792. self._obj_ref = weakref.ref(obj)
  793. if obj.layout is torch.strided:
  794. self.key = get_tensor_key(obj)
  795. elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}:
  796. self.key = (
  797. get_tensor_key(obj.crow_indices()),
  798. get_tensor_key(obj.col_indices()),
  799. )
  800. elif obj.layout in {torch.sparse_csc, torch.sparse_bsc}:
  801. self.key = (
  802. get_tensor_key(obj.ccol_indices()),
  803. get_tensor_key(obj.row_indices()),
  804. )
  805. else:
  806. raise NotImplementedError(obj.layout)
  807. self._hash = hash(self.key)
  808. def __hash__(self):
  809. return self._hash
  810. def __eq__(self, other):
  811. if not isinstance(other, TensorAsKey):
  812. return False
  813. if self.obj is None or other.obj is None:
  814. # dead objects always compare unequal unless these are
  815. # same objects
  816. return self is other
  817. return self.key == other.key
  818. @property
  819. def obj(self):
  820. """Return object if alive, otherwise None."""
  821. return self._obj_ref()
  822. @lru_cache(maxsize=TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE)
  823. def _bsr_scatter_mm_indices_data(
  824. indices_format, M, K, N, Ms, Ks, nbatches, SPLIT_N, compressed_sparse_tensor_as_key
  825. ):
  826. bsr = compressed_sparse_tensor_as_key.obj
  827. if bsr is None:
  828. raise AssertionError("compressed_sparse_tensor_as_key.obj is None")
  829. crow_indices, col_indices = bsr.crow_indices(), bsr.col_indices()
  830. device = crow_indices.device
  831. indices_dtype = torch.int32
  832. if indices_format == "bsr_strided_mm_compressed":
  833. Ns = N // SPLIT_N
  834. q_offsets_lst = []
  835. b = torch.arange(SPLIT_N, dtype=indices_dtype, device=device) * Ns
  836. for m in range(M // Ms):
  837. r0 = crow_indices[m].item()
  838. r1 = crow_indices[m + 1].item()
  839. if r1 == r0:
  840. continue
  841. q_offsets_lst.append(
  842. (col_indices[r0:r1] * (Ks * N)).repeat(SPLIT_N)
  843. + b.repeat_interleave(r1 - r0)
  844. )
  845. q_offsets = torch.cat(q_offsets_lst)
  846. crow_indices_diff = crow_indices.diff()
  847. non_zero_row_indices = crow_indices_diff.nonzero()
  848. a = non_zero_row_indices * (Ms * N)
  849. r_offsets = (a + b).view(-1)
  850. c_indices = crow_indices
  851. # swizzle operation: mm elements with longer sums are computed first:
  852. nnz_per_row = crow_indices_diff[non_zero_row_indices].repeat_interleave(SPLIT_N)
  853. nnz_per_row, indices = nnz_per_row.sort(descending=True, stable=True)
  854. r_offsets = r_offsets[indices]
  855. return (indices_format, c_indices, r_offsets, q_offsets)
  856. elif indices_format == "bsr_strided_mm":
  857. Ns = N // SPLIT_N
  858. p_offsets_lst = []
  859. q_offsets_lst = []
  860. b = torch.arange(SPLIT_N, dtype=indices_dtype, device=device) * Ns
  861. for m in range(M // Ms):
  862. r0 = crow_indices[m].item()
  863. r1 = crow_indices[m + 1].item()
  864. if r1 == r0:
  865. continue
  866. p_offsets_lst.append(
  867. torch.arange(r0, r1, dtype=indices_dtype, device=device).repeat(SPLIT_N)
  868. )
  869. q_offsets_lst.append(
  870. (col_indices[r0:r1] * (Ks * N)).repeat(SPLIT_N)
  871. + b.repeat_interleave(r1 - r0)
  872. )
  873. q_offsets = torch.cat(q_offsets_lst)
  874. crow_indices_diff = crow_indices.diff()
  875. non_zero_row_indices = crow_indices_diff.nonzero()
  876. a = non_zero_row_indices * (Ms * N)
  877. r_offsets = (a + b).view(-1)
  878. c_indices = torch.cat(
  879. (
  880. crow_indices[:1],
  881. torch.cumsum(
  882. crow_indices_diff[non_zero_row_indices].repeat_interleave(SPLIT_N),
  883. 0,
  884. ),
  885. )
  886. )
  887. p_offsets = torch.cat(p_offsets_lst)
  888. return (indices_format, c_indices, r_offsets, p_offsets, q_offsets)
  889. elif indices_format == "scatter_mm":
  890. Ns = Ms
  891. c_indices = [0]
  892. pq_offsets = []
  893. # todo: eliminate inner for-loops for efficiency
  894. for b in range(nbatches):
  895. for m in range(M // Ms):
  896. r0 = crow_indices[m].item()
  897. r1 = crow_indices[m + 1].item()
  898. for n in range(N // Ns):
  899. c_indices.append(c_indices[-1] + r1 - r0)
  900. for t in range(r1 - r0):
  901. p = r0 + t
  902. q = (col_indices[p].item() + b * (K // Ks)) * (N // Ns) + n
  903. pq_offsets.append([p, q])
  904. return (
  905. indices_format,
  906. torch.tensor(c_indices, dtype=indices_dtype, device=device),
  907. torch.tensor(pq_offsets, dtype=indices_dtype, device=device),
  908. )
  909. else:
  910. raise ValueError(
  911. f"Invalid {indices_format=}. Expected bsr_strided_mm_compressed|bsr_strided_mm|scatter_mm"
  912. )
  913. def bsr_scatter_mm_indices_data(
  914. bsr, other, indices_format="bsr_strided_mm_compressed", **meta_input
  915. ):
  916. """Computes indices data for :func:`scatter_mm` used in BSR and
  917. strided tensor matrix multiplication.
  918. """
  919. if bsr.dense_dim() != 0:
  920. raise AssertionError(f"bsr.dense_dim() must be 0, got {bsr.dense_dim()}")
  921. if bsr.ndim != 2:
  922. raise AssertionError(f"bsr must be 2D (no batch dims), got {bsr.ndim}D")
  923. blocksize = bsr.values().shape[-2:]
  924. M, K = bsr.shape
  925. Ms, Ks = blocksize
  926. K_, N = other.shape[-2:]
  927. if K_ != K:
  928. raise AssertionError(f"other K ({K_}) != bsr K ({K})")
  929. nbatches = other.shape[:-2].numel()
  930. meta = scatter_mm_meta(M, K, N, Ms, Ks, **meta_input)
  931. if "allow_tf32" not in meta_input:
  932. meta.update(allow_tf32=bsr.dtype in {torch.float16, torch.bfloat16})
  933. SPLIT_N = meta["SPLIT_N"]
  934. indices_data = _bsr_scatter_mm_indices_data(
  935. indices_format, M, K, N, Ms, Ks, nbatches, SPLIT_N, TensorAsKey(bsr)
  936. )
  937. if indices_format == "bsr_strided_mm_compressed":
  938. meta.update(is_compressed=True)
  939. return indices_data + (meta,)
  940. elif indices_format == "bsr_strided_mm":
  941. meta.update(is_compressed=False)
  942. return indices_data + (meta,)
  943. else:
  944. return indices_data
  945. def bsr_scatter_mm(bsr, other, indices_data=None, out=None):
  946. """BSR @ strided -> strided"""
  947. if bsr.ndim != 2:
  948. raise AssertionError(f"bsr must be 2D, got {bsr.ndim}D")
  949. if other.ndim < 2:
  950. raise AssertionError(
  951. f"other must have at least 2 dimensions, got {other.ndim}D"
  952. )
  953. Ms, Ks, Ns = bsr.shape[-2], bsr.shape[-1], other.shape[-1]
  954. blocksize = bsr.values().shape[-2:]
  955. if indices_data is None:
  956. indices_data = bsr_scatter_mm_indices_data(
  957. bsr, other, indices_format="bsr_strided_mm_compressed"
  958. )
  959. indices_format = indices_data[0]
  960. if out is None:
  961. out = torch.empty(
  962. (*other.shape[:-2], Ms, Ns), dtype=bsr.dtype, device=bsr.device
  963. )
  964. out_shape = out.shape
  965. out = as1Dbatch(out)
  966. if bsr._nnz() == 0:
  967. out.zero_()
  968. elif indices_format in {"bsr_strided_mm_compressed", "bsr_strided_mm"}:
  969. out.zero_()
  970. scatter_mm(bsr.values(), other, indices_data, accumulators=out)
  971. elif indices_format == "scatter_mm":
  972. nbatches = other.shape[:-2].numel()
  973. accumulators = torch.zeros(
  974. (
  975. nbatches * Ms // blocksize[0] * Ns // blocksize[0],
  976. blocksize[0],
  977. blocksize[0],
  978. ),
  979. dtype=bsr.dtype,
  980. device=bsr.device,
  981. )
  982. others = (
  983. as1Dbatch(other)
  984. .transpose(-2, -1)
  985. .view(
  986. nbatches,
  987. Ns // blocksize[0],
  988. blocksize[0],
  989. Ks // blocksize[1],
  990. blocksize[1],
  991. )
  992. .movedim(
  993. (3, 1, 4, 2), (1, 2, 3, 4)
  994. ) # equivalent to .transpose(-3, -2).transpose(-2, -1).transpose(-4, -3)
  995. .flatten(0, 2)
  996. )
  997. scatter_mm(bsr.values(), others, indices_data, accumulators=accumulators)
  998. out.copy_(
  999. accumulators.unflatten(
  1000. 0, (nbatches, Ms // blocksize[0], Ns // blocksize[0])
  1001. )
  1002. .movedim(
  1003. (1, 2, 3, 4), (3, 1, 4, 2)
  1004. ) # equivalent to .transpose(-4, -3).transpose(-2, -1).transpose(-3, -2)
  1005. .reshape(nbatches, Ns, Ms)
  1006. .transpose(-2, -1)
  1007. )
  1008. else:
  1009. raise NotImplementedError(indices_format)
  1010. return out.view(out_shape)
  1011. def _int_bsr_dense_addmm(
  1012. input: torch.Tensor,
  1013. bsr: torch.Tensor,
  1014. dense: torch.Tensor,
  1015. *,
  1016. beta=1,
  1017. alpha=1,
  1018. left_alpha: torch.Tensor | None = None,
  1019. right_alpha: torch.Tensor | None = None,
  1020. out: torch.Tensor | None = None,
  1021. skip_checks: bool = False,
  1022. max_grid: tuple[int | None, int | None, int | None] | None = None,
  1023. meta: dict | None = None,
  1024. ):
  1025. if out is None and dense.dtype is torch.int8:
  1026. f_name = "_int_bsr_dense_addmm"
  1027. crow_indices = bsr.crow_indices()
  1028. batch_ndim = crow_indices.dim() - 1
  1029. M = bsr.shape[batch_ndim]
  1030. N = dense.shape[-1]
  1031. original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense)
  1032. out = torch.empty(
  1033. original_batch_dims_broadcasted + (M, N),
  1034. dtype=torch.int32,
  1035. device=dense.device,
  1036. )
  1037. return bsr_dense_addmm(
  1038. input,
  1039. bsr,
  1040. dense,
  1041. beta=beta,
  1042. alpha=alpha,
  1043. left_alpha=left_alpha,
  1044. right_alpha=right_alpha,
  1045. out=out,
  1046. skip_checks=skip_checks,
  1047. max_grid=max_grid,
  1048. meta=meta,
  1049. )
def bsr_dense_addmm(
    input: torch.Tensor,
    bsr: torch.Tensor,
    dense: torch.Tensor,
    *,
    beta=1,
    alpha=1,
    left_alpha: torch.Tensor | None = None,
    right_alpha: torch.Tensor | None = None,
    out: torch.Tensor | None = None,
    skip_checks: bool = False,
    max_grid: tuple[int | None, int | None, int | None] | None = None,
    meta: dict | None = None,
):
    """Compute

      out = beta * input + left_alpha.reshape(-1, 1) * (alpha * (bsr @ dense)) * right_alpha.reshape(1, -1)

    where left_alpha, right_alpha are (* + 1)-D tensors when
    specified, otherwise, these are treated as tensors filled with
    ones.

    Args:
      input: tensor that is scaled by ``beta`` and added to the product.
      bsr: BSR tensor; ``values()``, ``crow_indices()`` and
        ``col_indices()`` are read from it.
      dense: strided tensor; its last dimension defines N.
      beta, alpha: scalar multipliers.
      left_alpha, right_alpha: optional row/column scaling tensors.
      out: optional output tensor. When None, a new tensor is created
        with ``dense.new_empty``.
      skip_checks: not referenced in this function body (see the todo
        below).
      max_grid: optional upper limits for the launch grid; at most the
        first three items are used, in reversed order.
      meta: optional dict of kernel parameters. When None, it is
        obtained from :func:`bsr_dense_addmm_meta`.

    Returns:
      The output tensor (``out`` when it was specified).

    Raises:
      AssertionError: if the expanded ``left_alpha``/``right_alpha`` do
        not have a zero stride along the expanded dimension.
      KeyError: if the dtype of the output is not one of the dtypes
        listed in the accumulator dtype table below.
    """
    f_name = "bsr_dense_addmm"
    values = bsr.values()
    crow_indices = bsr.crow_indices()
    col_indices = bsr.col_indices()
    # crow_indices has one non-batch dimension
    batch_ndim = crow_indices.dim() - 1
    M, K = bsr.shape[batch_ndim : batch_ndim + 2]
    blocksize = values.shape[batch_ndim + 1 : batch_ndim + 3]
    N = dense.shape[-1]

    # todo: implement checks

    original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense)
    if out is None:
        out = dense.new_empty(original_batch_dims_broadcasted + (M, N))

    # Trivial cases: the product term vanishes, so the result is
    # beta * input and no kernel launch is needed.
    if bsr._nnz() == 0 or alpha == 0 or N == 0 or M == 0 or K == 0:
        if beta == 0:
            out.zero_()
        else:
            out.copy_(input)
            if beta != 1:
                out.mul_(beta)
        return out

    left_alpha_is_one = False
    right_alpha_is_one = False
    if left_alpha is None:
        left_alpha_is_one = True
        # placeholder tensor; the *_is_one flag is passed to the kernel
        left_alpha = dense.new_empty(()).expand(
            *original_batch_dims_broadcasted, M, N
        )  # not referenced
    else:
        # one scaling factor per row, expanded (without copying) over columns
        left_alpha = left_alpha.view(*original_batch_dims_broadcasted, M, 1).expand(
            *original_batch_dims_broadcasted, M, N
        )

    if right_alpha is None:
        right_alpha_is_one = True
        right_alpha = dense.new_empty(()).expand(
            *original_batch_dims_broadcasted, M, N
        )  # not referenced
    else:
        # one scaling factor per column, expanded (without copying) over rows
        right_alpha = right_alpha.view(*original_batch_dims_broadcasted, 1, N).expand(
            *original_batch_dims_broadcasted, M, N
        )
    # Both scaling tensors must be expanded views, i.e. have a zero
    # stride along the expanded dimension.
    if left_alpha.stride()[-1] != 0:
        raise AssertionError(
            f"left_alpha.stride()[-1] must be 0, got {left_alpha.stride()[-1]}"
        )
    if right_alpha.stride()[-2] != 0:
        raise AssertionError(
            f"right_alpha.stride()[-2] must be 0, got {right_alpha.stride()[-2]}"
        )

    if meta is None:
        # fraction of elements that are not stored, rounded to two digits
        sparsity = round(1 - bsr._nnz() * blocksize[0] * blocksize[1] / (M * K), 2)
        meta = bsr_dense_addmm_meta(
            M,
            K,
            N,
            blocksize[0],
            blocksize[1],
            beta,
            alpha,
            sparsity=sparsity,
            dtype=dense.dtype,
            out_dtype=out.dtype,
        )
    # Keep a reference to the caller-visible output: `out` is rebound
    # below and may end up referring to different memory.
    out_backup = out

    (
        crow_indices,
        col_indices,
        values,
        input,
        dense,
        left_alpha,
        right_alpha,
        out,
    ) = prepare_inputs(bsr, input, dense, left_alpha, right_alpha, out)

    BM, BK = blocksize
    SPLIT_N = meta.get("SPLIT_N", N // BM)
    # width of one column slice of dense/out
    BN = N // SPLIT_N

    out_untiled = out
    out = tile_to_blocksize(out, (BM, BN))
    dense = tile_to_blocksize(dense, (BK, BN))
    input = tile_to_blocksize(input, (BM, BN))
    left_alpha = tile_to_blocksize(left_alpha, (BM, BN))
    right_alpha = tile_to_blocksize(right_alpha, (BM, BN))

    # tl.dot supports float16, float32, int32 as accumulator types.
    dot_out_dtype = {
        torch.float16: tl.float32,
        torch.bfloat16: tl.float32,
        torch.float32: tl.float64,
        torch.float64: tl.float64,
        torch.int8: tl.int32,
        torch.int32: tl.int32,
    }[out.dtype]

    n_batches = dense.size(0)
    n_block_rows = crow_indices.size(-1) - 1
    n_block_cols = dense.size(-3)

    full_grid = (n_batches, n_block_cols, n_block_rows)
    if max_grid is not None:
        # reversed to match the order of full_grid; missing items become None
        grid_blocks = tuple(max_grid[:3][::-1]) + (None,) * (3 - len(max_grid[:3]))
    else:
        grid_blocks = None

    # For every tensor, one dimension index (or None) per grid
    # dimension; consumed by launch_kernel.
    # NOTE(review): presumably these are the dimensions that get sliced
    # when the grid is split into chunks; confirm against launch_kernel.
    tensor_dims_map = {
        values: (0, None, None),
        crow_indices: (0, None, -1),
        col_indices: (0, None, None),
        input: (0, -3, -4),
        dense: (0, -3, None),
        left_alpha: (0, -3, -4),
        right_alpha: (0, -3, -4),
        out: (0, -3, -4),
    }

    # Defensive check: alpha == 0 has already returned early above.
    if alpha == 0:
        raise AssertionError("alpha must not be 0")

    def kernel(grid, *sliced_tensors):
        # pyrefly: ignore [unsupported-operation]
        _bsr_strided_addmm_kernel[grid](
            *ptr_stride_extractor(*sliced_tensors),
            # pyrefly: ignore [bad-argument-count]
            beta,
            alpha,
            # pyrefly: ignore [bad-keyword-argument, bad-argument-type]
            beta_is_one=beta == 1,
            # pyrefly: ignore [bad-keyword-argument, bad-argument-type]
            beta_is_nonzero=beta != 0,
            # pyrefly: ignore [bad-keyword-argument, bad-argument-type]
            alpha_is_one=alpha == 1,
            # pyrefly: ignore [bad-keyword-argument, bad-argument-type]
            left_alpha_is_one=left_alpha_is_one,
            # pyrefly: ignore [bad-keyword-argument, bad-argument-type]
            right_alpha_is_one=right_alpha_is_one,
            # pyrefly: ignore [bad-keyword-argument, bad-argument-type]
            BLOCKSIZE_ROW=BM,
            # pyrefly: ignore [bad-keyword-argument, bad-argument-type]
            BLOCKSIZE_INNER=BK,
            # pyrefly: ignore [bad-keyword-argument]
            BLOCKSIZE_COL=BN,
            # pyrefly: ignore [bad-keyword-argument, bad-argument-type]
            allow_tf32=dot_out_dtype == tl.float32,
            # pyrefly: ignore [bad-keyword-argument, bad-argument-type]
            acc_dtype=dot_out_dtype,
            **meta,
        )

    launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)

    if out.data_ptr() != out_backup.data_ptr():
        # prepare_inputs has made a copy of out, copy its content back
        # to out_backup:
        out_backup.copy_(out_untiled.view(out_backup.shape))

    return out_backup
  1216. if has_triton():
  1217. import triton
  1218. import triton.language as tl
@triton.jit
def _sampled_addmm_kernel(
    alpha,
    beta,
    IS_BETA_ZERO: tl.constexpr,
    BLOCKSIZE_ROW: tl.constexpr,
    BLOCKSIZE_COL: tl.constexpr,
    k,
    TILE_K: tl.constexpr,
    values_ptr,
    values_batch_stride,
    values_nnz_stride,
    values_row_block_stride,
    values_col_block_stride,
    crow_indices_ptr,
    crow_indices_batch_stride,
    crow_indices_stride,
    col_indices_ptr,
    col_indices_batch_stride,
    col_indices_stride,
    mat1_ptr,
    mat1_batch_stride,
    mat1_tiled_row_stride,
    mat1_tiled_col_stride,
    mat1_row_block_stride,
    mat1_col_block_stride,
    mat2_ptr,
    mat2_batch_stride,
    mat2_tiled_row_stride,
    mat2_tiled_col_stride,
    mat2_row_block_stride,
    mat2_col_block_stride,
    acc_dtype: tl.constexpr,
    allow_tf32: tl.constexpr,
):
    # One program handles one block row (grid axis 0) of one batch
    # (grid axis 1). For every block stored in that row, it computes
    # the corresponding block of mat1 @ mat2, accumulating over the
    # inner dimension k in tiles of TILE_K, and overwrites the stored
    # block in-place with
    #   alpha * (mat1 @ mat2)                  when IS_BETA_ZERO,
    #   alpha * (mat1 @ mat2) + beta * values  otherwise.
    batch_pid = tl.program_id(axis=1)
    row_block_pid = tl.program_id(axis=0)

    crow_indices_offset_ptr = (
        crow_indices_ptr
        + crow_indices_batch_stride * batch_pid
        + crow_indices_stride * row_block_pid
    )
    # Two consecutive crow_indices items delimit the blocks of this row.
    nnz_offset = tl.load(crow_indices_offset_ptr)
    nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)

    # Compute nnz for the row with number row_block_pid.
    # If it is zero, skip the row.
    row_nnz = nnz_offset_next - nnz_offset
    if row_nnz == 0:
        return

    row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
    col_block_arange = tl.arange(0, BLOCKSIZE_COL)

    # Pointers are set to the first block of the current row.
    values_block_ptrs = (
        values_ptr
        + values_batch_stride * batch_pid
        + values_nnz_stride * nnz_offset
        + values_row_block_stride * row_block_arange[:, None]
        + values_col_block_stride * col_block_arange[None, :]
    )

    col_index_nnz_ptr = (
        col_indices_ptr
        + col_indices_batch_stride * batch_pid
        + col_indices_stride * nnz_offset
    )

    # Advance mat1 to the current tiled row, ignore columns.
    mat1_block_ptrs = (
        mat1_ptr
        + mat1_batch_stride * batch_pid
        + mat1_tiled_row_stride * row_block_pid
        + mat1_row_block_stride * row_block_arange[:, None]
    )

    # Advance mat2 in batch and block col dimension.
    mat2_block_ptrs = (
        mat2_ptr
        + mat2_batch_stride * batch_pid
        + mat2_col_block_stride * col_block_arange[None, :]
    )

    k_tile_arange = tl.arange(0, TILE_K)
    for _ in range(row_nnz):
        # fresh accumulator for every stored block
        acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)

        # find column block index
        col_block = tl.load(col_index_nnz_ptr)

        for k_tile in range(0, k, TILE_K):
            k_offsets = k_tile + k_tile_arange
            # The last tile may reach past k: masked-out items are
            # loaded as zeros and so do not contribute to the sum.
            mask_k = k_offsets < k

            mat1_block = tl.load(
                mat1_block_ptrs + mat1_col_block_stride * k_offsets[None, :],
                # pyrefly: ignore [bad-index, index-error]
                mask=mask_k[None, :],
                other=0.0,
            )

            mat2_block = tl.load(
                mat2_block_ptrs
                + mat2_tiled_col_stride * col_block
                + mat2_row_block_stride * k_offsets[:, None],
                # pyrefly: ignore [bad-index, index-error]
                mask=mask_k[:, None],
                other=0.0,
            )

            acc_block += tl.dot(
                mat1_block, mat2_block, allow_tf32=allow_tf32, out_dtype=acc_dtype
            )

        if IS_BETA_ZERO:
            # the existing values are not read at all
            acc_block *= alpha
        else:
            acc_block = alpha * acc_block + beta * tl.load(values_block_ptrs)

        # write result
        tl.store(values_block_ptrs, acc_block.to(values_ptr.dtype.element_ty))

        # advance val/col_index ptrs to the next block in the row.
        values_block_ptrs += values_nnz_stride
        col_index_nnz_ptr += col_indices_stride
  1330. @triton.jit
  1331. def _bsr_strided_dense_rowspace_kernel(
  1332. # values prologue
  1333. values_ptr,
  1334. values_batch_stride,
  1335. values_nnz_stride,
  1336. values_row_block_stride,
  1337. values_col_block_stride,
  1338. # values epilogue
  1339. # crow_indices prologue
  1340. crow_indices_ptr,
  1341. crow_indices_batch_stride,
  1342. crow_indices_stride,
  1343. # crow_indices epilogue
  1344. # col_indices prologue
  1345. col_indices_ptr,
  1346. col_indices_batch_stride,
  1347. col_indices_stride,
  1348. # col_indices epilogue
  1349. # dense prologue
  1350. dense_ptr,
  1351. dense_batch_stride,
  1352. dense_tiled_row_stride,
  1353. dense_tiled_col_stride,
  1354. dense_row_block_stride,
  1355. dense_col_block_stride,
  1356. # dense epilogue
  1357. # output prologue
  1358. output_ptr,
  1359. output_batch_stride,
  1360. output_tiled_row_stride,
  1361. output_tiled_col_stride,
  1362. output_row_block_stride,
  1363. output_col_block_stride,
  1364. # output epilogue
  1365. #
  1366. # gh-113754: Always keep all constexpr arguments at the end of
  1367. # triton kernel arguments list because with triton 2.1 or
  1368. # earlier non-contiguous outputs will corrupt CUDA state due
  1369. # to a triton bug (fixed in openai/triton#2262).
  1370. BLOCKSIZE_ROW: tl.constexpr,
  1371. BLOCKSIZE_COL: tl.constexpr,
  1372. acc_dtype: tl.constexpr,
  1373. allow_tf32: tl.constexpr,
  1374. GROUP_SIZE_ROW: tl.constexpr,
  1375. ):
  1376. batch_pid = tl.program_id(axis=2)
  1377. row_block_pid = tl.program_id(axis=0)
  1378. col_block_pid = tl.program_id(axis=1)
  1379. n_block_rows = tl.num_programs(axis=0)
  1380. n_block_cols = tl.num_programs(axis=1)
  1381. row_block_pid, col_block_pid = tl.swizzle2d(
  1382. row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW
  1383. )
  1384. crow_indices_offset_ptr = (
  1385. crow_indices_ptr
  1386. + crow_indices_batch_stride * batch_pid
  1387. + crow_indices_stride * row_block_pid
  1388. )
  1389. nnz_offset = tl.load(crow_indices_offset_ptr)
  1390. nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)
  1391. # Compute nnz for the row with number row_block_pid.
  1392. # If it is zero, skip the row.
  1393. row_nnz = nnz_offset_next - nnz_offset
  1394. if row_nnz == 0:
  1395. return
  1396. row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
  1397. col_block_arange = tl.arange(0, BLOCKSIZE_COL)
  1398. # Pointers are set to the first block of the current row.
  1399. values_block_ptrs = (
  1400. values_ptr
  1401. + values_batch_stride * batch_pid
  1402. + values_nnz_stride * nnz_offset
  1403. + values_row_block_stride * row_block_arange[:, None]
  1404. + values_col_block_stride * col_block_arange[None, :]
  1405. )
  1406. # NOTE: dense is advanced into all dimensions but the tiled row one.
  1407. # That will be advanced in the loop according to values in col_indices.
  1408. dense_block_ptrs = (
  1409. dense_ptr
  1410. + dense_batch_stride * batch_pid
  1411. + dense_tiled_col_stride * col_block_pid
  1412. + dense_row_block_stride * col_block_arange[:, None]
  1413. + dense_col_block_stride * row_block_arange[None, :]
  1414. )
  1415. # Pointers are set to exact write-to locations
  1416. output_ptrs = (
  1417. output_ptr
  1418. + output_batch_stride * batch_pid
  1419. + output_tiled_row_stride * row_block_pid
  1420. + output_tiled_col_stride * col_block_pid
  1421. + output_row_block_stride * row_block_arange[:, None]
  1422. + output_col_block_stride * row_block_arange[None, :]
  1423. )
  1424. # Set pointer to the first nonzero element in the current row
  1425. col_index_nnz_ptr = (
  1426. col_indices_ptr
  1427. + col_indices_batch_stride * batch_pid
  1428. + col_indices_stride * nnz_offset
  1429. )
  1430. output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)
  1431. for _ in range(row_nnz):
  1432. values_block = tl.load(values_block_ptrs)
  1433. # find which row of dense needs to get loaded
  1434. # for multiplication with values_block.
  1435. dense_row_idx = tl.load(col_index_nnz_ptr)
  1436. dense_block = tl.load(
  1437. dense_block_ptrs + dense_tiled_row_stride * dense_row_idx
  1438. )
  1439. # do block mm
  1440. output_acc_block += tl.dot(
  1441. values_block, dense_block, allow_tf32=allow_tf32, out_dtype=acc_dtype
  1442. )
  1443. # move val/col_index ptrs to the next block in the row
  1444. values_block_ptrs += values_nnz_stride
  1445. col_index_nnz_ptr += col_indices_stride
  1446. # write back the result
  1447. tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))
  1448. def _run_sampled_addmm_kernel(
  1449. alpha,
  1450. beta,
  1451. is_beta_zero,
  1452. blocksize,
  1453. k,
  1454. tile_k,
  1455. values,
  1456. crow_indices,
  1457. col_indices,
  1458. mat1,
  1459. mat2,
  1460. max_grid,
  1461. ):
  1462. n_batches = values.size(0)
  1463. n_block_rows = crow_indices.size(-1) - 1
  1464. full_grid = (n_batches, n_block_rows)
  1465. if max_grid is not None:
  1466. grid_blocks = tuple(max_grid[:2][::-1]) + (None,) * (2 - len(max_grid[:2]))
  1467. else:
  1468. grid_blocks = None
  1469. tensor_dims_map = {
  1470. values: (0, None),
  1471. crow_indices: (0, -1),
  1472. col_indices: (0, None),
  1473. mat1: (0, -4),
  1474. mat2: (0, None),
  1475. }
  1476. if values.dtype in (torch.half, torch.bfloat16):
  1477. acc_dtype = tl.float32
  1478. allow_tf32 = True
  1479. else:
  1480. acc_dtype = tl.float64
  1481. allow_tf32 = False
  1482. def kernel(grid, *sliced_tensors):
  1483. _sampled_addmm_kernel[grid](
  1484. alpha,
  1485. beta,
  1486. is_beta_zero,
  1487. *blocksize,
  1488. # pyrefly: ignore [bad-argument-count]
  1489. k,
  1490. tile_k,
  1491. *ptr_stride_extractor(*sliced_tensors),
  1492. # pyrefly: ignore [bad-keyword-argument, bad-argument-type]
  1493. acc_dtype=acc_dtype,
  1494. # pyrefly: ignore [bad-keyword-argument, bad-argument-type]
  1495. allow_tf32=allow_tf32,
  1496. # pyrefly: ignore [unexpected-keyword]
  1497. num_stages=1,
  1498. # pyrefly: ignore [unexpected-keyword]
  1499. num_warps=4,
  1500. )
  1501. launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)
  1502. def sampled_addmm(
  1503. input: torch.Tensor,
  1504. mat1: torch.Tensor,
  1505. mat2: torch.Tensor,
  1506. *,
  1507. beta=1.0,
  1508. alpha=1.0,
  1509. out: torch.Tensor | None = None,
  1510. skip_checks: bool = False,
  1511. max_grid: tuple[int | None, int | None, int | None] | None = None,
  1512. ):
  1513. f_name = "sampled_addmm"
  1514. check_bsr_layout(f_name, input)
  1515. input_broadcasted = broadcast_batch_dims_bsr(f_name, input, mat1, mat2)
  1516. if not skip_checks:
  1517. check_device(f_name, mat1, input.device)
  1518. check_device(f_name, mat2, input.device)
  1519. if beta != 0.0 and input.dtype is torch.bool:
  1520. check(
  1521. False,
  1522. f"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed.",
  1523. )
  1524. if input.dtype is not torch.bool:
  1525. check_dtype(f_name, mat1, input.dtype)
  1526. check_dtype(f_name, mat2, input.dtype)
  1527. else:
  1528. check_dtype(f_name, mat1, mat2.dtype)
  1529. check_mm_compatible_shapes(f_name, mat1, mat2)
  1530. if out is not None:
  1531. check_bsr_layout(f_name, out)
  1532. check_device(f_name, out, mat1.device)
  1533. check_dtype(f_name, out, input.dtype)
  1534. check(
  1535. out.shape == input_broadcasted.shape and out._nnz() == input._nnz(),
  1536. f"{f_name}(): Expects `out` to be of shape {input_broadcasted.shape} "
  1537. f"and with nnz equal to {input_broadcasted._nnz()} "
  1538. f"but got out.shape = {out.shape} and out.nnz = {out._nnz()}",
  1539. )
  1540. if out is None:
  1541. out = input_broadcasted.to(mat1.dtype, copy=True)
  1542. else:
  1543. out.copy_(input_broadcasted)
  1544. if out.numel() == 0 or out._nnz() == 0:
  1545. return out
  1546. blocksize = out.values().shape[-2:]
  1547. k = mat1.size(-1)
  1548. # NOTE: (m, 0) @ (0, n) == zeros(m, n)
  1549. if alpha == 0.0 or k == 0:
  1550. out.values().mul_(beta)
  1551. return out
  1552. # prepare inputs by reshaping them to be kernel-compatible
  1553. out_backup = out
  1554. crow_indices, col_indices, values, mat1, mat2 = prepare_inputs(out, mat1, mat2)
  1555. mat1 = tile_to_blocksize(mat1, (blocksize[0], k))
  1556. mat2 = tile_to_blocksize(mat2, (k, blocksize[1]))
  1557. tile_k = max(*blocksize)
  1558. _run_sampled_addmm_kernel(
  1559. alpha,
  1560. beta,
  1561. beta == 0.0,
  1562. blocksize,
  1563. k,
  1564. tile_k,
  1565. values,
  1566. crow_indices,
  1567. col_indices,
  1568. mat1,
  1569. mat2,
  1570. max_grid,
  1571. )
  1572. # If nnz x block strides are not the same in out_backup.values and values,
  1573. # it means that out_backup.values and values are not the views of each other,
  1574. # so we have to copy.
  1575. if out_backup.values().stride()[-3:] != values.stride()[-3:]:
  1576. out_backup.values().copy_(values.reshape(out_backup.values().shape))
  1577. return out_backup
  1578. def bsr_dense_mm(
  1579. bsr: torch.Tensor,
  1580. dense: torch.Tensor,
  1581. *,
  1582. out: torch.Tensor | None = None,
  1583. skip_checks: bool = False,
  1584. max_grid: tuple[int | None, int | None, int | None] | None = None,
  1585. meta: dict | None = None,
  1586. ):
  1587. f_name = "bsr_dense_mm"
  1588. m, _kl = bsr.shape[-2:]
  1589. if not skip_checks:
  1590. check_bsr_layout(f_name, bsr)
  1591. check_device(f_name, bsr, dense.device)
  1592. check_dtype(f_name, bsr, dense.dtype, (torch.int8,))
  1593. check_mm_compatible_shapes(f_name, bsr, dense)
  1594. n = dense.size(-1)
  1595. row_block, col_block = bsr.values().shape[-2:]
  1596. check_blocksize(f_name, (row_block, col_block))
  1597. check(
  1598. not n % 16,
  1599. f"{f_name}(): dense.size(-1) == {n} should be divisible by 16",
  1600. )
  1601. else:
  1602. _kr, n = dense.shape[-2:]
  1603. original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense)
  1604. if out is not None and not skip_checks:
  1605. expected_out_shape = original_batch_dims_broadcasted + (m, n)
  1606. check(
  1607. out.shape == expected_out_shape,
  1608. "bsr_dense_mm(): `out` argument has wrong shape, "
  1609. f"expected {expected_out_shape}, but got {out.shape}.",
  1610. )
  1611. check(
  1612. out.is_contiguous() or out.transpose(-2, -1).is_contiguous(),
  1613. "bsr_dense_mm(): only row-major/col-major `out` arguments are supported, "
  1614. "i.e. (out.is_contiguous() or out.transpose(-2, -1).is_contiguous()) "
  1615. "should be True.",
  1616. )
  1617. # Allocate out
  1618. if out is None:
  1619. out = dense.new_empty(original_batch_dims_broadcasted + (m, n))
  1620. # Short circuit if lhs is zero
  1621. if bsr._nnz() == 0:
  1622. return out.zero_()
  1623. # with beta==0, addmm ignores input content, so we can use out
  1624. # as a placeholder for input because their shapes match:
  1625. return bsr_dense_addmm(out, bsr, dense, alpha=1, beta=0, out=out)
@triton.jit
def _bsr_softmax_kernel(
    crow_indices_ptr,
    crow_indices_batch_stride,
    crow_indices_stride,
    values_ptr,
    values_batch_stride,
    values_row_block_stride,
    values_nnz_col_block_stride,
    row_block,
    col_block,
    MAX_ROW_NNZ: tl.constexpr,
    TILE: tl.constexpr,
):
    # In-place, numerically stable softmax over one row of stored values.
    #
    # Grid layout (as read from tl.program_id below):
    #   axis 0 - index into crow_indices (the block row),
    #   axis 1 - row offset used with values_row_block_stride,
    #   axis 2 - batch index.
    #
    # Each program handles `row_nnz * col_block` consecutive elements that
    # start `nnz_offset * col_block` elements into the selected row, and
    # processes them TILE elements at a time.  Elements are addressed with
    # unit stride; `values_nnz_col_block_stride` and `row_block` are accepted
    # but not referenced in the body.
    # Arithmetic is done in float32 and the result is cast back to the
    # element type of `values_ptr` before being stored over the input.
    batch_pid = tl.program_id(axis=2)
    row_block_offset_pid = tl.program_id(axis=1)
    row_block_pid = tl.program_id(axis=0)

    crow_indices_offset_ptr = (
        crow_indices_ptr
        + crow_indices_batch_stride * batch_pid
        + crow_indices_stride * row_block_pid
    )
    nnz_offset = tl.load(crow_indices_offset_ptr)
    nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)

    # Compute nnz for the row with number row_block_pid.
    # If it is zero, skip the row.
    row_nnz = nnz_offset_next - nnz_offset
    if row_nnz == 0:
        return

    row_arange = tl.arange(0, TILE)
    # Lanes beyond the end of the row are masked out.
    mask = row_arange < row_nnz * col_block

    curr_row_values_ptrs = (
        values_ptr
        + values_batch_stride * batch_pid
        + values_row_block_stride * row_block_offset_pid
        + nnz_offset * col_block
    )

    # find max in the row
    # Masked lanes read -inf, so they neither affect the maximum nor
    # contribute to the sum of exponentials computed later.
    row_tile = tl.load(
        curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf")
    ).to(tl.float32)
    max_row_value = tl.max(row_tile, axis=0)
    # Pass 1 (forward): visit the remaining tiles, tracking the running max.
    for _ in range(TILE, MAX_ROW_NNZ, TILE):
        row_arange += TILE
        mask = row_arange < row_nnz * col_block
        row_tile = tl.load(
            curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf")
        ).to(tl.float32)
        curr_max_row_value = tl.max(row_tile, axis=0)
        max_row_value = tl.where(
            max_row_value > curr_max_row_value, max_row_value, curr_max_row_value
        )

    # find denominator for stable softmax
    # `row_tile` still holds the last tile loaded by pass 1, so the sum
    # starts there and pass 2 walks the tiles backwards; this leaves
    # `row_arange`, `mask` and `num` positioned on the first tile.
    num = tl.exp(row_tile - max_row_value)
    denom = tl.sum(num, axis=0)
    for _ in range(TILE, MAX_ROW_NNZ, TILE):
        row_arange -= TILE
        mask = row_arange < row_nnz * col_block
        row_tile = tl.load(
            curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf")
        ).to(tl.float32)
        num = tl.exp(row_tile - max_row_value)
        denom += tl.sum(num, axis=0)

    # populate output
    # The first tile is written using the `num` left over from pass 2 ...
    tl.store(
        curr_row_values_ptrs + row_arange,
        (num / denom).to(values_ptr.dtype.element_ty),
        mask=mask,
    )
    # ... and pass 3 (forward) reloads and normalizes the remaining tiles.
    for _ in range(TILE, MAX_ROW_NNZ, TILE):
        row_arange += TILE
        mask = row_arange < row_nnz * col_block
        row_tile = tl.load(
            curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf")
        ).to(tl.float32)
        num = tl.exp(row_tile - max_row_value)
        tl.store(
            curr_row_values_ptrs + row_arange,
            (num / denom).to(values_ptr.dtype.element_ty),
            mask=mask,
        )
  1707. def bsr_softmax(input, max_row_nnz=None):
  1708. f_name = "bsr_softmax"
  1709. check_bsr_layout(f_name, input)
  1710. check_dtype(f_name, input, input.dtype)
  1711. if input._nnz() == 0 or input.numel() == 0:
  1712. return input.clone()
  1713. m, n = input.shape[-2:]
  1714. nnz = input._nnz()
  1715. row_block, col_block = input.values().shape[-2:]
  1716. if max_row_nnz is None:
  1717. max_row_nnz = triton.next_power_of_2(n)
  1718. else:
  1719. max_row_nnz = triton.next_power_of_2(max_row_nnz)
  1720. crow_indices = input.crow_indices().unsqueeze(0).flatten(0, -2)
  1721. # reshape values from
  1722. # (b1, ..., bn, nnz, row_block, col_block) to
  1723. # (b1 * ... * bn, row_block, nnz * col_block).
  1724. # This simplifies batch dim manipulation and unlocks
  1725. # the possibility to access all nnzs in any given row.
  1726. if input.values().transpose(-3, -2).is_contiguous():
  1727. # Need to clone to avoid `contiguous` returning a view.
  1728. values = input.values().clone()
  1729. else:
  1730. values = input.values()
  1731. values = (
  1732. values.transpose(-3, -2)
  1733. .contiguous()
  1734. .unsqueeze(0)
  1735. .flatten(0, -4)
  1736. .reshape(-1, row_block, nnz * col_block)
  1737. )
  1738. full_grid = (values.shape[0], row_block, m // row_block)
  1739. grid_blocks = None
  1740. tensor_dims_map = {
  1741. # We span nnz number of blocks, not nnz + 1,
  1742. # hence crow_indices[..., :-1]
  1743. crow_indices[..., :-1]: (0, None, -1),
  1744. values: (0, None, None),
  1745. }
  1746. def kernel(grid, *sliced_tensors):
  1747. _bsr_softmax_kernel[grid](
  1748. *ptr_stride_extractor(*sliced_tensors),
  1749. # pyrefly: ignore [bad-argument-count]
  1750. row_block,
  1751. col_block,
  1752. max_row_nnz,
  1753. # Triton's max numel is bounded by 2 ** 17.
  1754. min(2**17, max_row_nnz),
  1755. )
  1756. launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)
  1757. values = (
  1758. values.reshape(-1, row_block, nnz, col_block)
  1759. .transpose(-3, -2)
  1760. .reshape(*input.values().shape)
  1761. )
  1762. return torch.sparse_compressed_tensor(
  1763. input.crow_indices().clone(),
  1764. input.col_indices().clone(),
  1765. values,
  1766. size=input.shape,
  1767. layout=input.layout,
  1768. )
  1769. def _scaled_dot_product_attention(
  1770. query: torch.Tensor,
  1771. key: torch.Tensor,
  1772. value: torch.Tensor,
  1773. attn_mask: torch.Tensor | None,
  1774. dropout_p: float = 0.0,
  1775. is_causal: bool = False,
  1776. scale: float | None = None,
  1777. ):
  1778. f_name = "_scaled_dot_product_attention"
  1779. check(not is_causal, f"{f_name}(): is_causal == True is not supported.")
  1780. check(attn_mask is not None, f"{f_name}(): attn_mask == None is not supported.")
  1781. if attn_mask is None:
  1782. raise AssertionError("attn_mask must not be None")
  1783. check(
  1784. attn_mask.layout == torch.sparse_bsr,
  1785. f"{f_name}(): "
  1786. f"attn_mask.layout must be {torch.sparse_bsr}, but got "
  1787. f"attn_mask.layout == {attn_mask.layout}.",
  1788. )
  1789. check_device(f_name, key, query.device)
  1790. check_device(f_name, value, query.device)
  1791. check_device(f_name, attn_mask, query.device)
  1792. check_dtype(f_name, key, query.dtype)
  1793. check_dtype(f_name, value, query.dtype)
  1794. if attn_mask.dtype is not torch.bool:
  1795. check_dtype(f_name, attn_mask, query.dtype)
  1796. # pyrefly: ignore [not-callable]
  1797. sdpa = sampled_addmm(
  1798. attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False
  1799. )
  1800. if scale is None and query.size(-1) == 0 or scale == 0.0:
  1801. check(
  1802. False,
  1803. f"{f_name}(): current value of scale == {scale} "
  1804. "results in division by zero.",
  1805. )
  1806. scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
  1807. sdpa.values().mul_(scale_factor)
  1808. # pyrefly: ignore [not-callable]
  1809. sdpa = bsr_softmax(sdpa)
  1810. torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True)
  1811. # pyrefly: ignore [not-callable]
  1812. sdpa = bsr_dense_mm(sdpa, value)
  1813. return sdpa
  1814. @triton.jit
  1815. def _scatter_mm2_kernel(
  1816. M: tl.constexpr,
  1817. K: tl.constexpr,
  1818. N: tl.constexpr,
  1819. blocks_ptr,
  1820. blocks_stride_P,
  1821. blocks_stride_M,
  1822. blocks_stride_K,
  1823. others_ptr,
  1824. others_stride_Q,
  1825. others_stride_K,
  1826. others_stride_N,
  1827. accumulators_ptr,
  1828. accumulators_stride_R,
  1829. accumulators_stride_M,
  1830. accumulators_stride_N,
  1831. pq_offsets_ptr,
  1832. pq_offsets_stride,
  1833. pq_ptr,
  1834. pq_stride_T,
  1835. pq_stride_1,
  1836. dot_out_dtype: tl.constexpr,
  1837. TILE_M: tl.constexpr,
  1838. TILE_N: tl.constexpr,
  1839. allow_tf32: tl.constexpr,
  1840. ):
  1841. Ms = M // TILE_M
  1842. pid_t = tl.program_id(axis=0)
  1843. pid = tl.program_id(axis=1)
  1844. pid_m = pid // Ms
  1845. pid_n = pid % Ms
  1846. rm = pid_m * TILE_M + tl.arange(0, TILE_M)
  1847. rn = pid_n * TILE_N + tl.arange(0, TILE_N)
  1848. rk = tl.arange(0, K)
  1849. A_ptr = blocks_ptr + (
  1850. rm[:, None] * blocks_stride_M + rk[None, :] * blocks_stride_K
  1851. )
  1852. B_ptr = others_ptr + (
  1853. rk[:, None] * others_stride_K + rn[None, :] * others_stride_N
  1854. )
  1855. g0 = tl.load(pq_offsets_ptr + pid_t * pq_offsets_stride)
  1856. g1 = tl.load(pq_offsets_ptr + (pid_t + 1) * pq_offsets_stride)
  1857. if g0 == g1:
  1858. return
  1859. acc_block = tl.zeros((TILE_M, TILE_N), dtype=dot_out_dtype)
  1860. for i in range(g0, g1):
  1861. p = tl.load(pq_ptr + i * pq_stride_T)
  1862. q = tl.load(pq_ptr + i * pq_stride_T + pq_stride_1)
  1863. A = tl.load(A_ptr + p * blocks_stride_P)
  1864. B = tl.load(B_ptr + q * others_stride_Q)
  1865. acc_block += tl.dot(A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
  1866. C_ptr = (
  1867. accumulators_ptr
  1868. + pid_t * accumulators_stride_R
  1869. + (
  1870. rm[:, None] * accumulators_stride_M
  1871. + rn[None, :] * accumulators_stride_N
  1872. )
  1873. )
  1874. tl.store(C_ptr, acc_block.to(accumulators_ptr.dtype.element_ty))
  1875. def _scatter_mm2(
  1876. blocks: torch.Tensor,
  1877. others: torch.Tensor,
  1878. pq_offsets: torch.Tensor,
  1879. pq_indices: torch.Tensor,
  1880. accumulators: torch.Tensor,
  1881. ):
  1882. _P, M, K = blocks.shape
  1883. _Q, _, N = others.shape
  1884. meta = dict(
  1885. TILE_M=max(16, M // 4), TILE_N=max(16, N // 4), num_stages=1, num_warps=2
  1886. )
  1887. def grid(META):
  1888. return (
  1889. pq_offsets.shape[0] - 1,
  1890. triton.cdiv(M, META["TILE_M"]) * triton.cdiv(N, META["TILE_N"]),
  1891. 1,
  1892. )
  1893. dot_out_dtype = {
  1894. torch.float16: tl.float32,
  1895. torch.bfloat16: tl.float32,
  1896. torch.float32: tl.float64,
  1897. torch.float64: tl.float64,
  1898. }[accumulators.dtype]
  1899. if "allow_tf32" not in meta:
  1900. meta.update(allow_tf32=dot_out_dtype == tl.float32)
  1901. _scatter_mm2_kernel[grid](
  1902. # pyrefly: ignore [bad-argument-type]
  1903. M,
  1904. # pyrefly: ignore [bad-argument-type]
  1905. K,
  1906. # pyrefly: ignore [bad-argument-type]
  1907. N,
  1908. blocks,
  1909. blocks.stride(0),
  1910. blocks.stride(1),
  1911. blocks.stride(2),
  1912. others,
  1913. others.stride(0),
  1914. others.stride(1),
  1915. others.stride(2),
  1916. accumulators,
  1917. accumulators.stride(0),
  1918. accumulators.stride(1),
  1919. accumulators.stride(2),
  1920. pq_offsets,
  1921. pq_offsets.stride(0),
  1922. pq_indices,
  1923. pq_indices.stride(0),
  1924. pq_indices.stride(1),
  1925. # pyrefly: ignore [bad-argument-type]
  1926. dot_out_dtype=dot_out_dtype,
  1927. # pyrefly: ignore [bad-argument-type]
  1928. **meta,
  1929. )
@triton.jit
def _scatter_mm6_kernel(
    nbatches,
    Ms,
    Ks: tl.constexpr,
    N,
    blocks_ptr,
    blocks_stride_P,
    blocks_stride_M,
    blocks_stride_K,
    others_ptr,
    others_stride_B,
    others_stride_K,
    others_stride_N,
    accumulators_ptr,
    accumulators_stride_B,
    accumulators_stride_M,
    accumulators_stride_N,
    c_indices_ptr,
    r_offsets_ptr,
    p_offsets_ptr,
    q_offsets_ptr,
    is_compressed: tl.constexpr,
    dot_out_dtype: tl.constexpr,
    SPLIT_N: tl.constexpr,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    allow_tf32: tl.constexpr,
):
    # Accumulates products of (Ms, Ks) blocks with (Ks, Ns) windows of
    # `others` into (Ms, Ns) windows of `accumulators`, Ns = N // SPLIT_N.
    #
    # Grid axis 0 enumerates (output window, batch) pairs, with the batch
    # index varying fastest; grid axis 1 enumerates the (TILE_M, TILE_N)
    # tiles of one output window.
    # c_indices, r_offsets, p_offsets and q_offsets are indexed with unit
    # stride.  r_offsets and q_offsets hold element offsets that are added
    # directly to the accumulators/others pointers; p_offsets holds block
    # indices that are multiplied by blocks_stride_P.
    Ns = N // SPLIT_N
    BLOCKS_M = Ms // TILE_M
    BLOCKS_N = Ns // TILE_N

    pid_t_ = tl.program_id(axis=0)
    pid = tl.program_id(axis=1)
    pid_b = pid_t_ % nbatches
    pid_t = pid_t_ // nbatches

    # Grouped ordering of the tiles: consecutive program ids first walk
    # up to GROUP_SIZE tile rows before moving to the next tile column.
    num_pid_in_group = GROUP_SIZE * BLOCKS_N
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE
    group_size_m = min(BLOCKS_M - first_pid_m, GROUP_SIZE)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    rm = pid_m * TILE_M + tl.arange(0, TILE_M)
    rn = pid_n * TILE_N + tl.arange(0, TILE_N)
    rk = tl.arange(0, Ks)
    # Tile pointers without the per-product offsets, which are added below.
    A_ptr = blocks_ptr + (
        rm[:, None] * blocks_stride_M + rk[None, :] * blocks_stride_K
    )
    B_ptr = (
        others_ptr
        + pid_b * others_stride_B
        + (rk[:, None] * others_stride_K + rn[None, :] * others_stride_N)
    )

    # When is_compressed is True, r is the only variable that
    # depends on pid_t. This property allows sorting r values
    # before calling the kernel. The sorting of r is equivalent to
    # defining swizzle operator outside of the kernel.
    r = tl.load(r_offsets_ptr + pid_t)

    if is_compressed:
        # Recover the block-row index m and the column-split index n from
        # the output offset r, then read the range [r0, r1) of blocks that
        # belong to block row m.
        m = (r // N) // Ms
        n = (r % N) // Ns
        r0 = tl.load(c_indices_ptr + m)
        r1 = tl.load(c_indices_ptr + m + 1)
        # Equals SPLIT_N * r0 + n * (r1 - r0); presumably q_offsets stores
        # SPLIT_N consecutive groups of (r1 - r0) entries per block row -
        # confirm against the code that builds the indices.
        g0 = n * r1 + (SPLIT_N - n) * r0
        nnz = r1 - r0
    else:
        # c_indices directly delimits the range of (p, q) entries for
        # this output window.
        g0 = tl.load(c_indices_ptr + pid_t)
        g1 = tl.load(c_indices_ptr + pid_t + 1)
        nnz = g1 - g0

    q_ptr = q_offsets_ptr + g0
    acc_block = tl.zeros((TILE_M, TILE_N), dtype=dot_out_dtype)

    if is_compressed:
        # The blocks of one block row are consecutive, starting at r0, so
        # A_ptr is simply advanced by one block per iteration.
        A_ptr += r0 * blocks_stride_P  # type: ignore[possibly-undefined]
        for _ in range(nnz):
            q = tl.load(q_ptr)
            B = tl.load(B_ptr + q)
            A = tl.load(A_ptr)
            acc_block += tl.dot(
                A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32
            )
            A_ptr += blocks_stride_P
            q_ptr += 1
    else:
        # Block indices are looked up explicitly in p_offsets.
        p_ptr = p_offsets_ptr + g0
        for _ in range(nnz):
            q = tl.load(q_ptr)
            B = tl.load(B_ptr + q)
            p = tl.load(p_ptr)
            A = tl.load(A_ptr + p * blocks_stride_P)
            p_ptr += 1
            q_ptr += 1
            acc_block += tl.dot(
                A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32
            )

    C_ptr = (
        accumulators_ptr
        + r
        + pid_b * accumulators_stride_B
        + (
            rm[:, None] * accumulators_stride_M
            + rn[None, :] * accumulators_stride_N
        )
    )
    # Cast from the accumulation dtype back to the output element type.
    tl.store(C_ptr, acc_block.to(accumulators_ptr.dtype.element_ty))
  2035. def _scatter_mm6(
  2036. blocks: torch.Tensor,
  2037. others: torch.Tensor,
  2038. c_indices: torch.Tensor,
  2039. r_offsets: torch.Tensor,
  2040. p_offsets: torch.Tensor,
  2041. q_offsets: torch.Tensor,
  2042. meta: dict,
  2043. accumulators: torch.Tensor,
  2044. force_contiguous: bool = True,
  2045. ):
  2046. SPLIT_N = meta["SPLIT_N"]
  2047. _P, Ms, Ks = blocks.shape
  2048. B, _K, N = others.shape
  2049. B_, _M, N_ = accumulators.shape
  2050. if N_ != N:
  2051. raise AssertionError(f"accumulators N ({N_}) != others N ({N})")
  2052. Ns = N // SPLIT_N
  2053. if B_ != B:
  2054. raise AssertionError(f"accumulators B ({B_}) != others B ({B})")
  2055. def grid(META):
  2056. return (
  2057. r_offsets.shape[0] * B,
  2058. triton.cdiv(Ms, META["TILE_M"]) * triton.cdiv(Ns, META["TILE_N"]),
  2059. )
  2060. dot_out_dtype = {
  2061. torch.float16: tl.float32,
  2062. torch.bfloat16: tl.float32,
  2063. torch.float32: tl.float64,
  2064. torch.float64: tl.float64,
  2065. }[accumulators.dtype]
  2066. if "allow_tf32" not in meta:
  2067. meta.update(allow_tf32=dot_out_dtype == tl.float32)
  2068. if c_indices.stride(0) != 1:
  2069. raise AssertionError(
  2070. f"c_indices.stride(0) must be 1, got {c_indices.stride(0)}"
  2071. )
  2072. if r_offsets.stride(0) != 1:
  2073. raise AssertionError(
  2074. f"r_offsets.stride(0) must be 1, got {r_offsets.stride(0)}"
  2075. )
  2076. if p_offsets.stride(0) != 1:
  2077. raise AssertionError(
  2078. f"p_offsets.stride(0) must be 1, got {p_offsets.stride(0)}"
  2079. )
  2080. if q_offsets.stride(0) != 1:
  2081. raise AssertionError(
  2082. f"q_offsets.stride(0) must be 1, got {q_offsets.stride(0)}"
  2083. )
  2084. # Re non-contiguous tensor arguments. Sometimes triton kernel
  2085. # launches may fail with
  2086. #
  2087. # RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
  2088. #
  2089. # that appears to be case when the size of a non-contiguous
  2090. # tensor argument is larger than a certain threshold. Could
  2091. # this be related to shared memory or L1 cache size of a GPU
  2092. # card? In anycase, ensuring that tensor arguments are
  2093. # contiguous seems to avoid the above exception. So, in the
  2094. # following we'll always convert tensor arguments to
  2095. # C-contiguous tensors.
  2096. if force_contiguous:
  2097. blocks = blocks.contiguous()
  2098. others = others.contiguous()
  2099. if not accumulators.is_contiguous():
  2100. accumulators_ = accumulators.contiguous()
  2101. else:
  2102. accumulators_ = accumulators
  2103. else:
  2104. accumulators_ = accumulators
  2105. _scatter_mm6_kernel[grid](
  2106. B,
  2107. Ms,
  2108. # pyrefly: ignore [bad-argument-type]
  2109. Ks,
  2110. N,
  2111. blocks,
  2112. blocks.stride(0),
  2113. blocks.stride(1),
  2114. blocks.stride(2),
  2115. others,
  2116. others.stride(0),
  2117. others.stride(1),
  2118. others.stride(2),
  2119. accumulators_,
  2120. accumulators_.stride(0),
  2121. accumulators_.stride(1),
  2122. accumulators_.stride(2),
  2123. c_indices,
  2124. r_offsets,
  2125. p_offsets,
  2126. q_offsets,
  2127. # pyrefly: ignore [bad-argument-type]
  2128. dot_out_dtype=dot_out_dtype,
  2129. **meta,
  2130. )
  2131. if force_contiguous and not accumulators.is_contiguous():
  2132. accumulators.copy_(accumulators_)
@triton.jit
def _bsr_strided_addmm_kernel(
    # values prologue
    values_ptr,
    values_batch_stride,
    values_nnz_stride,
    values_row_block_stride,
    values_col_block_stride,
    # values epilogue
    # crow_indices prologue
    crow_indices_ptr,
    crow_indices_batch_stride,
    crow_indices_stride,
    # crow_indices epilogue
    # col_indices prologue
    col_indices_ptr,
    col_indices_batch_stride,
    col_indices_stride,
    # col_indices epilogue
    # input prologue
    input_ptr,
    input_batch_stride,
    input_tiled_row_stride,
    input_tiled_col_stride,
    input_row_block_stride,
    input_col_block_stride,
    # input epilogue
    # dense prologue
    dense_ptr,
    dense_batch_stride,
    dense_tiled_row_stride,
    dense_tiled_col_stride,
    dense_row_block_stride,
    dense_col_block_stride,
    # dense epilogue
    # left_alpha prologue
    left_alpha_ptr,
    left_alpha_batch_stride,
    left_alpha_tiled_row_stride,
    left_alpha_tiled_col_stride: tl.constexpr,
    left_alpha_row_block_stride,
    left_alpha_col_block_stride: tl.constexpr,
    # left_alpha epilogue
    # right_alpha prologue
    right_alpha_ptr,
    right_alpha_batch_stride,
    right_alpha_tiled_row_stride: tl.constexpr,
    right_alpha_tiled_col_stride,
    right_alpha_row_block_stride: tl.constexpr,
    right_alpha_col_block_stride,
    # right_alpha epilogue
    # output prologue
    output_ptr,
    output_batch_stride,
    output_tiled_row_stride,
    output_tiled_col_stride,
    output_row_block_stride,
    output_col_block_stride,
    # output epilogue
    beta,
    alpha,
    beta_is_one: tl.constexpr,
    beta_is_nonzero: tl.constexpr,
    alpha_is_one: tl.constexpr,
    left_alpha_is_one: tl.constexpr,
    right_alpha_is_one: tl.constexpr,
    BLOCKSIZE_ROW: tl.constexpr,
    BLOCKSIZE_COL: tl.constexpr,
    BLOCKSIZE_INNER: tl.constexpr,
    acc_dtype: tl.constexpr,
    allow_tf32: tl.constexpr,
    GROUP_SIZE_ROW: tl.constexpr,
    SPLIT_N: tl.constexpr,
):
    # Computes one (BLOCKSIZE_ROW, BLOCKSIZE_COL) tile of
    #
    #   output = beta * input + left_alpha * right_alpha * alpha * (bsr @ dense)
    #
    # where the BSR operand is given by values/crow_indices/col_indices and
    # the left_alpha/right_alpha factors are applied element-wise to the tile.
    # Each scaling step is compiled in only when its `*_is_one` /
    # `beta_is_nonzero` flag requires it.
    # Grid layout: axis 0 - row block, axis 1 - column block, axis 2 - batch.
    # SPLIT_N is accepted but not referenced in the body.

    # left/right_alpha tensors are originally (* + 1)-dimensional
    # The four strides checked here are tl.constexpr, so the conditions are
    # compile-time constants.
    # NOTE(review): `raise` inside a jitted function - presumably the
    # branch is never compiled when the stride is 0; confirm how Triton
    # reports a violation (tl.static_assert may be the clearer construct).
    if left_alpha_tiled_col_stride != 0:
        raise AssertionError(
            f"left_alpha_tiled_col_stride must be 0, got {left_alpha_tiled_col_stride}"
        )
    if left_alpha_col_block_stride != 0:
        raise AssertionError(
            f"left_alpha_col_block_stride must be 0, got {left_alpha_col_block_stride}"
        )
    if right_alpha_tiled_row_stride != 0:
        raise AssertionError(
            f"right_alpha_tiled_row_stride must be 0, got {right_alpha_tiled_row_stride}"
        )
    if right_alpha_row_block_stride != 0:
        raise AssertionError(
            f"right_alpha_row_block_stride must be 0, got {right_alpha_row_block_stride}"
        )

    batch_pid = tl.program_id(axis=2)
    row_block_pid = tl.program_id(axis=0)
    col_block_pid = tl.program_id(axis=1)
    n_block_rows = tl.num_programs(axis=0)
    n_block_cols = tl.num_programs(axis=1)

    # Remap the (row, column) program ids in groups of GROUP_SIZE_ROW rows.
    row_block_pid, col_block_pid = tl.swizzle2d(
        row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW
    )

    crow_indices_offset_ptr = (
        crow_indices_ptr
        + crow_indices_batch_stride * batch_pid
        + crow_indices_stride * row_block_pid
    )
    nnz_offset = tl.load(crow_indices_offset_ptr)
    nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)

    # Compute nnz for the row with number row_block_pid.
    row_nnz = nnz_offset_next - nnz_offset

    row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
    inner_block_arange = tl.arange(0, BLOCKSIZE_INNER)
    col_block_arange = tl.arange(0, BLOCKSIZE_COL)

    # Pointers are set to the first block of the current row.
    values_block_ptrs = (
        values_ptr
        + values_batch_stride * batch_pid
        + values_nnz_stride * nnz_offset
        + values_row_block_stride * row_block_arange[:, None]
        + values_col_block_stride * inner_block_arange[None, :]
    )

    # NOTE: dense is advanced into all dimensions but the tiled row one.
    # That will be advanced in the loop according to values in col_indices.
    dense_block_ptrs = (
        dense_ptr
        + dense_batch_stride * batch_pid
        + dense_tiled_col_stride * col_block_pid
        + dense_row_block_stride * inner_block_arange[:, None]
        + dense_col_block_stride * col_block_arange[None, :]
    )

    # Pointers are set to exact write-to locations
    output_ptrs = (
        output_ptr
        + output_batch_stride * batch_pid
        + output_tiled_row_stride * row_block_pid
        + output_tiled_col_stride * col_block_pid
        + output_row_block_stride * row_block_arange[:, None]
        + output_col_block_stride * col_block_arange[None, :]
    )

    # Set pointer to the first nonzero element in the current row
    col_index_nnz_ptr = (
        col_indices_ptr
        + col_indices_batch_stride * batch_pid
        + col_indices_stride * nnz_offset
    )

    # Accumulate in acc_dtype; a row without nonzero blocks leaves this at
    # zero, and the scaling/input terms below are still applied and stored.
    output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)
    for _ in range(row_nnz):
        values_block = tl.load(values_block_ptrs)

        # find which row of dense needs to get loaded
        # for multiplication with values_block.
        dense_row_idx = tl.load(col_index_nnz_ptr)
        dense_block = tl.load(
            dense_block_ptrs + dense_tiled_row_stride * dense_row_idx
        )

        # do block mm
        output_acc_block += tl.dot(
            values_block, dense_block, allow_tf32=allow_tf32, out_dtype=acc_dtype
        )

        # move val/col_index ptrs to the next block in the row
        values_block_ptrs += values_nnz_stride
        col_index_nnz_ptr += col_indices_stride

    # Scale the product by the scalar alpha ...
    if not alpha_is_one:
        output_acc_block *= alpha

    # ... by left_alpha, whose column strides are required to be 0 above,
    # so it varies along rows only ...
    if not left_alpha_is_one:
        left_alpha_ptrs = (
            left_alpha_ptr
            + left_alpha_batch_stride * batch_pid
            + left_alpha_tiled_row_stride * row_block_pid
            + left_alpha_tiled_col_stride * col_block_pid
            + left_alpha_row_block_stride * row_block_arange[:, None]
            + left_alpha_col_block_stride * col_block_arange[None, :]
        )
        output_acc_block *= tl.load(left_alpha_ptrs)

    # ... and by right_alpha, whose row strides are required to be 0 above,
    # so it varies along columns only.
    if not right_alpha_is_one:
        right_alpha_ptrs = (
            right_alpha_ptr
            + right_alpha_batch_stride * batch_pid
            + right_alpha_tiled_row_stride * row_block_pid
            + right_alpha_tiled_col_stride * col_block_pid
            + right_alpha_row_block_stride * row_block_arange[:, None]
            + right_alpha_col_block_stride * col_block_arange[None, :]
        )
        output_acc_block *= tl.load(right_alpha_ptrs)

    # Add the (optionally beta-scaled) input tile; input is not read at all
    # when beta is zero.
    if beta_is_nonzero:
        input_ptrs = (
            input_ptr
            + input_batch_stride * batch_pid
            + input_tiled_row_stride * row_block_pid
            + input_tiled_col_stride * col_block_pid
            + input_row_block_stride * row_block_arange[:, None]
            + input_col_block_stride * col_block_arange[None, :]
        )
        if beta_is_one:
            output_acc_block += tl.load(input_ptrs)
        else:
            output_acc_block += beta * tl.load(input_ptrs)

    # write back the result
    tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))
  2329. else:
  2330. bsr_softmax = None # type: ignore[assignment]
  2331. bsr_dense_mm = None # type: ignore[assignment]
  2332. sampled_addmm = None # type: ignore[assignment]
  2333. _scaled_dot_product_attention = None # type: ignore[assignment]
  2334. _scatter_mm2 = None # type: ignore[assignment]
  2335. _scatter_mm6 = None # type: ignore[assignment]
  2336. _bsr_strided_addmm_kernel = None # type: ignore[assignment]