# Copyright (c) 2024, Tri Dao, Albert Gu.
"""This module is written for triton==2.1.0 or 2.2.0.
"""
- import math
- import torch
- import torch.nn.functional as F
- import triton
- import triton.language as tl
- from einops import rearrange, repeat
def init_to_zero(names):
    """Build a hook that zeroes selected entries of an argument mapping in place.

    Argument:
        names: iterable of keys to look up in the mapping handed to the hook.
    Return:
        A callable taking a mapping ``nargs``. For every name in ``names`` whose
        entry is not None it calls ``nargs[name].zero_()`` and returns the list of
        those calls' results, in the order of ``names``.

    NOTE(review): nothing in the visible source references this helper; it looks
    like it is meant to be used as an autotune pre-hook -- confirm against callers.
    """
    def _zero_selected(nargs):
        # Lazy generator: each entry is checked and zeroed before the next is looked up.
        present = (nargs[name] for name in names if nargs[name] is not None)
        return [buf.zero_() for buf in present]

    return _zero_selected
@triton.autotune(
    configs=[
        triton.Config(
            {'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k},
            num_stages=stages, num_warps=warps,
        )
        # Columns: BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps
        for block_m, block_n, block_k, stages, warps in (
            (128, 256, 64, 3, 8),
            (64, 256, 32, 4, 4),
            (128, 128, 32, 4, 4),
            (128, 64, 32, 4, 4),
            (64, 128, 32, 4, 4),
            (128, 32, 32, 4, 4),
            (64, 32, 32, 5, 2),
            (32, 64, 32, 5, 2),
            (64, 64, 32, 4, 2),
        )
    ],
    key=['chunk_size', 'K', 'IS_CAUSAL'],
)
@triton.jit
def _bmm_chunk_fwd_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, out_ptr, seq_idx_ptr,
    # Matrix dimensions
    seqlen, chunk_size, K, ngroups,
    stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
    stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,
    stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    # Meta-parameters
    IS_CAUSAL: tl.constexpr,
    dot_dtype: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    # Each program produces one BLOCK_SIZE_M x BLOCK_SIZE_N tile of
    #     out[m, n] = sum_k a[m, k] * b[n, k]
    # for a single (batch, chunk, group) triple; m and n are positions inside the chunk.

    # Grid axis 1 is the batch index; axis 2 packs (chunk, group) as chunk * ngroups + group.
    pid_b = tl.program_id(axis=1)
    pid_cg = tl.program_id(axis=2)
    pid_c = pid_cg // ngroups
    pid_h = pid_cg - pid_c * ngroups
    # Grid axis 0 enumerates (row tile, column tile) pairs, column tile varying fastest.
    tiles_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
    pid_mn = tl.program_id(axis=0)
    pid_m = pid_mn // tiles_n
    pid_n = pid_mn % tiles_n
    if IS_CAUSAL:
        # The tile's first column is at or beyond one-past its last row, i.e. every
        # column index exceeds every row index: nothing is computed or stored for it.
        if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
            return

    # Advance the base pointers to the first row of this (batch, chunk, group).
    chunk_start = pid_c * chunk_size
    a_ptr += pid_b * stride_a_batch + chunk_start * stride_a_seqlen + pid_h * stride_a_head
    b_ptr += pid_b * stride_b_batch + chunk_start * stride_b_seqlen + pid_h * stride_b_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + chunk_start * stride_seq_idx_seqlen

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # a is addressed as (row, k); b as (k, column), i.e. b is read transposed.
    a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
    # Number of sequence positions that actually exist in this chunk (the last chunk may be short).
    valid_len = min(chunk_size, seqlen - chunk_start)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_step in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Elements of the K dimension still left from the current pointer position.
        k_left = K - k_step * BLOCK_SIZE_K
        a_mask = (offs_m[:, None] < valid_len) & (offs_k[None, :] < k_left)
        b_mask = (offs_k[:, None] < k_left) & (offs_n[None, :] < valid_len)
        # Masked-out elements load as 0.0 and therefore contribute nothing to the dot.
        a = tl.load(a_ptrs, mask=a_mask, other=0.0).to(dot_dtype)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0).to(dot_dtype)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Offsets are rebuilt here (same values as above) rather than reused; presumably so
    # they need not stay live across the loop -- unverified.
    out_rows = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    out_cols = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    if HAS_SEQ_IDX:
        seq_valid_len = min(chunk_size, seqlen - pid_c * chunk_size)
        # Out-of-range rows and columns get different fill values (-1 vs -2), so two
        # padded positions never compare equal below.
        seq_idx_m = tl.load(seq_idx_ptr + out_rows * stride_seq_idx_seqlen, mask=out_rows < seq_valid_len, other=-1)
        seq_idx_n = tl.load(seq_idx_ptr + out_cols * stride_seq_idx_seqlen, mask=out_cols < seq_valid_len, other=-2)
        # Zero every entry whose row and column carry different sequence indices.
        acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
    out = acc.to(out_ptr.dtype.element_ty)

    out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
    out_ptrs = out_ptr + (stride_outm * out_rows[:, None] + out_cols[None, :] * stride_outn)
    # The store is bounded by chunk_size, not valid_len: rows/columns past valid_len are
    # written too, holding the zeros produced by the masked loads above.
    tl.store(out_ptrs, out, mask=(out_rows[:, None] < chunk_size) & (out_cols[None, :] < chunk_size))
@triton.autotune(
    configs=[
        triton.Config(
            {'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_CS': block_cs},
            num_stages=stages, num_warps=warps,
        )
        # Columns: BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_CS, num_stages, num_warps
        for block_m, block_n, block_cs, stages, warps in (
            (128, 256, 64, 3, 8),
            (64, 256, 32, 4, 4),
            (128, 128, 32, 4, 4),
            (128, 64, 32, 4, 4),
            (64, 128, 32, 4, 4),
            (128, 32, 32, 4, 4),
            (64, 32, 32, 5, 2),
            (32, 64, 32, 5, 2),
            (64, 64, 32, 4, 2),
        )
    ],
    key=['chunk_size', 'K'],
)
@triton.jit
def _bmm_chunk_bwd_kernel(
    # Pointers to matrices
    a_ptr, dout_ptr, db_ptr, res_ptr,
    # Matrix dimensions
    seqlen, chunk_size, K, ngroups,
    stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
    stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,
    stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,
    stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,
    # Meta-parameters
    dot_dtype: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,
):
    # Each program produces one BLOCK_SIZE_M x BLOCK_SIZE_N tile of
    #     db[j, k] = sum_i dout[i, j] * a[i, k]   (+ residual[j, k] if HAS_RESIDUAL)
    # for a single (batch, chunk, group) triple; i and j are positions inside the chunk.

    # Grid axis 1 is the batch index; axis 2 packs (chunk, group) as chunk * ngroups + group.
    pid_b = tl.program_id(axis=1)
    pid_cg = tl.program_id(axis=2)
    pid_c = pid_cg // ngroups
    pid_h = pid_cg - pid_c * ngroups
    # Grid axis 0 enumerates (row tile, column tile) pairs, column tile varying fastest.
    tiles_n = tl.cdiv(K, BLOCK_SIZE_N)
    pid_mn = tl.program_id(axis=0)
    pid_m = pid_mn // tiles_n
    pid_n = pid_mn % tiles_n

    # Advance the base pointers to this (batch, chunk, group).
    chunk_start = pid_c * chunk_size
    a_ptr += pid_b * stride_a_batch + chunk_start * stride_a_seqlen + pid_h * stride_a_head
    dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_cs = tl.arange(0, BLOCK_SIZE_CS)
    # dout is read transposed: the tile's rows step along dout's last (n) axis while the
    # reduction index offs_cs steps along dout's second-to-last (m) axis.
    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m)
    a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak)
    # Number of sequence positions that actually exist in this chunk (the last chunk may be short).
    valid_len = min(chunk_size, seqlen - chunk_start)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # The reduction only runs over positions that exist in the sequence.
    for cs_step in range(0, tl.cdiv(valid_len, BLOCK_SIZE_CS)):
        # Reduction positions still left from the current pointer position.
        cs_left = valid_len - cs_step * BLOCK_SIZE_CS
        dout_mask = (offs_m[:, None] < chunk_size) & (offs_cs[None, :] < cs_left)
        a_mask = (offs_cs[:, None] < cs_left) & (offs_n[None, :] < K)
        # Masked-out elements load as 0.0 and therefore contribute nothing to the dot.
        dout = tl.load(dout_ptrs, mask=dout_mask, other=0.0).to(dot_dtype)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0).to(dot_dtype)
        acc += tl.dot(dout, a)
        dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m
        a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen

    # Offsets are rebuilt here (same values as above) rather than reused; presumably so
    # they need not stay live across the loop -- unverified.
    out_rows = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    out_cols = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # Only rows that exist in the sequence and columns below K are read or written.
    tile_mask = (out_rows[:, None] < valid_len) & (out_cols[None, :] < K)
    if HAS_RESIDUAL:
        res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head
        res_ptrs = res_ptr + (out_rows[:, None] * stride_res_seqlen + out_cols[None, :] * stride_res_k)
        # No fill value is given for masked-out elements; they are never stored because
        # the store below uses the same mask.
        res = tl.load(res_ptrs, mask=tile_mask).to(tl.float32)
        acc += res
    db = acc.to(db_ptr.dtype.element_ty)

    db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head
    db_ptrs = db_ptr + (out_rows[:, None] * stride_db_seqlen + out_cols[None, :] * stride_db_k)
    tl.store(db_ptrs, db, mask=tile_mask)
- def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):
- """
- Argument:
- a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
- b: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
- seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
- causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
- guaranteed to be correct.
- Return:
- out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
- """
- # Check constraints.
- has_groups = a.dim() == 4
- if not has_groups:
- batch, seqlen, k = a.shape
- else:
- batch, seqlen, ngroups, k = a.shape
- assert b.shape == a.shape
- if seq_idx is not None:
- assert seq_idx.shape == (batch, seqlen)
- if a.stride(-1) != 1 and a.stride(1) != 1:
- a = a.contiguous()
- if b.stride(-1) != 1 and b.stride(1) != 1:
- b = b.contiguous()
- nchunks = math.ceil(seqlen / chunk_size)
- # Allocates output.
- out_dtype = a.dtype if output_dtype is None else output_dtype
- out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),
- device=a.device, dtype=out_dtype)
- dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
- (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))
- grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),
- batch, nchunks if not has_groups else nchunks * ngroups)
- with torch.cuda.device(a.device.index):
- _bmm_chunk_fwd_kernel[grid](
- a, b, out, seq_idx,
- seqlen, chunk_size, k, ngroups if has_groups else 1,
- a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
- b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),
- out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),
- *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
- causal,
- dot_dtype,
- HAS_SEQ_IDX=seq_idx is not None,
- )
- return out
- def _bmm_chunk_bwd(a, dout, residual=None, out=None):
- """
- Argument:
- a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
- dout: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
- residual: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
- Return:
- out: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
- If there was seq_idx in the fwd pass, then dout[i, j] for seq_idx[i] != seq_idx[j] should already be
- zeroed out before calling this function.
- """
- # Check constraints.
- has_groups = a.dim() == 4
- if not has_groups:
- batch, seqlen, k = a.shape
- else:
- batch, seqlen, ngroups, k = a.shape
- nchunks, chunk_size = dout.shape[1], dout.shape[-1]
- if a.stride(-1) != 1 and a.stride(-2) != 1:
- a = a.contiguous()
- if dout.stride(-1) != 1 and dout.stride(-2) != 1:
- dout = dout.contiguous()
- if residual is not None:
- assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k)
- if residual.stride(-1) != 1 and residual.stride(1) != 1:
- residual = residual.contiguous()
- # Allocates output.
- if out is not None:
- assert out.shape == a.shape
- assert out.stride(-1) == 1 or out.stride(1) == 1
- else:
- out = torch.empty_like(a)
- dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else
- (tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32))
- grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch,
- nchunks if not has_groups else nchunks * ngroups)
- residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2),
- residual.stride(-1))
- if residual is not None else (0, 0, 0, 0))
- with torch.cuda.device(a.device.index):
- _bmm_chunk_bwd_kernel[grid](
- a, dout, out, residual,
- seqlen, chunk_size, k, ngroups if has_groups else 1,
- a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
- dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1),
- out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1),
- residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3],
- dot_dtype,
- HAS_RESIDUAL=residual is not None,
- )
- return out
|