ssd_bmm.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. # Copyright (c) 2024, Tri Dao, Albert Gu.
  2. """We want triton==2.1.0 or 2.2.0 for this
  3. """
  4. import math
  5. import torch
  6. import torch.nn.functional as F
  7. import triton
  8. import triton.language as tl
  9. from einops import rearrange, repeat
  10. def init_to_zero(names):
  11. return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
# Forward kernel of the chunked batched matmul. One program instance computes one
# (BLOCK_SIZE_M x BLOCK_SIZE_N) tile of out[batch, chunk, group] where
# out[m, n] = sum_k a[chunk_start + m, k] * b[chunk_start + n, k]  (i.e. A_chunk @ B_chunk^T),
# accumulated in float32 with the inputs cast to dot_dtype.
# Autotuning re-runs whenever chunk_size, K or IS_CAUSAL change (the autotune `key`).
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
    ],
    key=['chunk_size', 'K', 'IS_CAUSAL'],
)
@triton.jit
def _bmm_chunk_fwd_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, out_ptr, seq_idx_ptr,
    # Matrix dimensions
    seqlen, chunk_size, K, ngroups,
    stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
    stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,
    stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    # Meta-parameters
    IS_CAUSAL: tl.constexpr,
    dot_dtype: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    # Grid axis 1 = batch; axis 2 = chunk * ngroups + group (split into pid_c / pid_h);
    # axis 0 = flattened (row tile, column tile) index over the chunk_size x chunk_size output.
    pid_b = tl.program_id(axis=1)
    pid_ch = tl.program_id(axis=2)
    pid_c = pid_ch // ngroups
    pid_h = pid_ch - pid_c * ngroups
    num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    if IS_CAUSAL:
        # Skip tiles where every column index is greater than every row index (strictly above
        # the diagonal); those out entries are never written by this kernel.
        if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
            return
    # Move the base pointers to the first row of chunk pid_c for batch pid_b / group pid_h.
    a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
    b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # a tile is (M x K); b tile is addressed transposed as (K x N), so tl.dot yields A @ B^T.
    a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
    # Number of valid rows in this chunk (the last chunk can be shorter than chunk_size).
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Out-of-range rows / k-columns load as 0 and therefore contribute nothing.
        a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype)
        b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # Recomputed; these are the same values as before the loop.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    if HAS_SEQ_IDX:
        chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
        # Padding positions get different sentinels for rows (-1) and columns (-2) so they
        # never compare equal; entries whose row/column seq_idx differ are zeroed.
        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
        seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)
        acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
    out = acc.to(out_ptr.dtype.element_ty)
    out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
    out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
    # Masked by chunk_size (not chunk_size_limit): padding rows/cols of a partial last chunk
    # are written too (with the zeros produced by the masked loads).
    tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))
# Backward kernel of the chunked batched matmul. One program instance computes one
# (BLOCK_SIZE_M x BLOCK_SIZE_N) tile of db[batch, chunk rows, group, K] where
# db[m, n] = sum_cs dout[cs, m] * a[chunk_start + cs, n] (+ residual[m, n]),
# i.e. db_chunk = dout_chunk^T @ a_chunk, accumulated in float32.
# Autotuning re-runs whenever chunk_size or K change (the autotune `key`).
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2),
    ],
    key=['chunk_size', 'K'],
)
@triton.jit
def _bmm_chunk_bwd_kernel(
    # Pointers to matrices
    a_ptr, dout_ptr, db_ptr, res_ptr,
    # Matrix dimensions
    seqlen, chunk_size, K, ngroups,
    stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
    stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,
    stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,
    stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,
    # Meta-parameters
    dot_dtype: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,
):
    # Grid axis 1 = batch; axis 2 = chunk * ngroups + group (split into pid_c / pid_h);
    # axis 0 = flattened (row tile over chunk positions, column tile over K) index.
    pid_b = tl.program_id(axis=1)
    pid_ch = tl.program_id(axis=2)
    pid_c = pid_ch // ngroups
    pid_h = pid_ch - pid_c * ngroups
    num_pid_n = tl.cdiv(K, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
    dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_cs = tl.arange(0, BLOCK_SIZE_CS)
    # dout is read transposed: offs_m walks dout's last (column) dim and offs_cs walks its
    # row dim, so the (M x CS) tile is dout^T.
    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m)
    a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak)
    # Number of valid rows in this chunk (the last chunk can be shorter than chunk_size).
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Reduce only over the valid rows of the chunk; masked loads fill the tail with 0.
    for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)):
        dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype)
        a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype)
        acc += tl.dot(dout, a)
        dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m
        a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen
    # Recomputed; these are the same values as before the loop.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    if HAS_RESIDUAL:
        res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head
        res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k)
        # No `other=` here: masked-out lanes hold unspecified values, but the final store
        # below uses the same mask, so they are never written.
        res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32)
        acc += res
    db = acc.to(db_ptr.dtype.element_ty)
    db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head
    db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k)
    # Masked by chunk_size_limit: rows past seqlen do not exist in db and are not written.
    tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K))
  142. def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):
  143. """
  144. Argument:
  145. a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
  146. b: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
  147. seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
  148. causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
  149. guaranteed to be correct.
  150. Return:
  151. out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
  152. """
  153. # Check constraints.
  154. has_groups = a.dim() == 4
  155. if not has_groups:
  156. batch, seqlen, k = a.shape
  157. else:
  158. batch, seqlen, ngroups, k = a.shape
  159. assert b.shape == a.shape
  160. if seq_idx is not None:
  161. assert seq_idx.shape == (batch, seqlen)
  162. if a.stride(-1) != 1 and a.stride(1) != 1:
  163. a = a.contiguous()
  164. if b.stride(-1) != 1 and b.stride(1) != 1:
  165. b = b.contiguous()
  166. nchunks = math.ceil(seqlen / chunk_size)
  167. # Allocates output.
  168. out_dtype = a.dtype if output_dtype is None else output_dtype
  169. out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),
  170. device=a.device, dtype=out_dtype)
  171. dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
  172. (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))
  173. grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),
  174. batch, nchunks if not has_groups else nchunks * ngroups)
  175. with torch.cuda.device(a.device.index):
  176. _bmm_chunk_fwd_kernel[grid](
  177. a, b, out, seq_idx,
  178. seqlen, chunk_size, k, ngroups if has_groups else 1,
  179. a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
  180. b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),
  181. out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),
  182. *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
  183. causal,
  184. dot_dtype,
  185. HAS_SEQ_IDX=seq_idx is not None,
  186. )
  187. return out
  188. def _bmm_chunk_bwd(a, dout, residual=None, out=None):
  189. """
  190. Argument:
  191. a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
  192. dout: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
  193. residual: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
  194. Return:
  195. out: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
  196. If there was seq_idx in the fwd pass, then dout[i, j] for seq_idx[i] != seq_idx[j] should already be
  197. zeroed out before calling this function.
  198. """
  199. # Check constraints.
  200. has_groups = a.dim() == 4
  201. if not has_groups:
  202. batch, seqlen, k = a.shape
  203. else:
  204. batch, seqlen, ngroups, k = a.shape
  205. nchunks, chunk_size = dout.shape[1], dout.shape[-1]
  206. if a.stride(-1) != 1 and a.stride(-2) != 1:
  207. a = a.contiguous()
  208. if dout.stride(-1) != 1 and dout.stride(-2) != 1:
  209. dout = dout.contiguous()
  210. if residual is not None:
  211. assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k)
  212. if residual.stride(-1) != 1 and residual.stride(1) != 1:
  213. residual = residual.contiguous()
  214. # Allocates output.
  215. if out is not None:
  216. assert out.shape == a.shape
  217. assert out.stride(-1) == 1 or out.stride(1) == 1
  218. else:
  219. out = torch.empty_like(a)
  220. dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else
  221. (tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32))
  222. grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch,
  223. nchunks if not has_groups else nchunks * ngroups)
  224. residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2),
  225. residual.stride(-1))
  226. if residual is not None else (0, 0, 0, 0))
  227. with torch.cuda.device(a.device.index):
  228. _bmm_chunk_bwd_kernel[grid](
  229. a, dout, out, residual,
  230. seqlen, chunk_size, k, ngroups if has_groups else 1,
  231. a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
  232. dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1),
  233. out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1),
  234. residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3],
  235. dot_dtype,
  236. HAS_RESIDUAL=residual is not None,
  237. )
  238. return out