k_activations.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. # Copyright (c) 2024, Tri Dao, Albert Gu.
  2. import torch
  3. import triton
  4. import triton.language as tl
  5. @triton.autotune(
  6. configs=[
  7. triton.Config({'BLOCK_N': 32}),
  8. triton.Config({'BLOCK_N': 64}),
  9. triton.Config({'BLOCK_N': 128}),
  10. triton.Config({'BLOCK_N': 256}),
  11. triton.Config({'BLOCK_N': 512}),
  12. triton.Config({'BLOCK_N': 1024}),
  13. ],
  14. key=['ncols'],
  15. )
  16. @triton.jit
  17. def _swiglu_fwd_kernel(
  18. X,
  19. Y,
  20. OUT,
  21. stride_x_row, # how much to increase the pointer when moving by 1 row
  22. stride_y_row,
  23. stride_out_row,
  24. ncols,
  25. BLOCK_N: tl.constexpr,
  26. ):
  27. # Map the program id to the row of X and Y it should compute.
  28. row = tl.program_id(0)
  29. start_col = tl.program_id(1) * BLOCK_N
  30. X += row * stride_x_row
  31. Y += row * stride_y_row
  32. OUT += row * stride_out_row
  33. cols = start_col + tl.arange(0, BLOCK_N)
  34. x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
  35. y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
  36. out = x * tl.sigmoid(x) * y
  37. tl.store(OUT + cols, out, mask=cols < ncols)
  38. def _swiglu_fwd(xy, out=None):
  39. if xy.stride(-1) != 1:
  40. xy = xy.contiguous()
  41. batch_shape = xy.shape[:-1]
  42. xy = xy.reshape(-1, xy.shape[-1])
  43. x, y = xy.chunk(2, dim=-1)
  44. if out is None:
  45. out = torch.empty_like(x)
  46. else:
  47. out = out.reshape(-1, out.shape[-1])
  48. assert out.shape == x.shape
  49. assert out.stride(-1) == 1
  50. M, N = x.shape
  51. grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
  52. with torch.cuda.device(x.device.index):
  53. _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)
  54. return out.reshape(*batch_shape, out.shape[-1])
  55. @triton.autotune(
  56. configs=[
  57. triton.Config({'BLOCK_N': 32}),
  58. triton.Config({'BLOCK_N': 64}),
  59. triton.Config({'BLOCK_N': 128}),
  60. triton.Config({'BLOCK_N': 256}),
  61. triton.Config({'BLOCK_N': 512}),
  62. triton.Config({'BLOCK_N': 1024}),
  63. ],
  64. key=['ncols'],
  65. )
  66. @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None})
  67. @triton.jit
  68. def _swiglu_bwd_kernel(
  69. X,
  70. Y,
  71. DOUT,
  72. OUT,
  73. DX,
  74. DY,
  75. stride_x_row, # how much to increase the pointer when moving by 1 row
  76. stride_y_row,
  77. stride_dout_row,
  78. stride_out_row,
  79. stride_dx_row,
  80. stride_dy_row,
  81. ncols,
  82. BLOCK_N: tl.constexpr,
  83. RECOMPUTE_OUTPUT: tl.constexpr,
  84. ):
  85. # Map the program id to the row of X and Y it should compute.
  86. row = tl.program_id(0)
  87. start_col = tl.program_id(1) * BLOCK_N
  88. X += row * stride_x_row
  89. Y += row * stride_y_row
  90. DOUT += row * stride_dout_row
  91. if RECOMPUTE_OUTPUT:
  92. OUT += row * stride_out_row
  93. DX += row * stride_dx_row
  94. DY += row * stride_dy_row
  95. cols = start_col + tl.arange(0, BLOCK_N)
  96. x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
  97. y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
  98. dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)
  99. x_sigmoid = tl.sigmoid(x)
  100. dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout
  101. dy = x * x_sigmoid * dout
  102. tl.store(DX + cols, dx, mask=cols < ncols)
  103. tl.store(DY + cols, dy, mask=cols < ncols)
  104. if RECOMPUTE_OUTPUT:
  105. out = x * x_sigmoid * y
  106. tl.store(OUT + cols, out, mask=cols < ncols)
  107. def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):
  108. if xy.stride(-1) != 1:
  109. xy = xy.contiguous()
  110. if dout.stride(-1) != 1:
  111. dout = dout.contiguous()
  112. batch_shape = xy.shape[:-1]
  113. xy = xy.reshape(-1, xy.shape[-1])
  114. x, y = xy.chunk(2, dim=-1)
  115. dout = dout.reshape(-1, dout.shape[-1])
  116. assert dout.shape == x.shape
  117. if dxy is None:
  118. dxy = torch.empty_like(xy)
  119. else:
  120. dxy = dxy.reshape(-1, dxy.shape[-1])
  121. assert dxy.shape == xy.shape
  122. dx, dy = dxy.chunk(2, dim=-1)
  123. assert dx.stride(-1) == 1
  124. assert dy.stride(-1) == 1
  125. if recompute_output:
  126. if out is None:
  127. out = torch.empty_like(x)
  128. else:
  129. out = out.reshape(-1, out.shape[-1])
  130. assert out.shape == x.shape
  131. assert out.stride(-1) == 1
  132. M, N = x.shape
  133. grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
  134. with torch.cuda.device(x.device.index):
  135. _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,
  136. x.stride(0), y.stride(0), dout.stride(0),
  137. out.stride(0) if recompute_output else 0,
  138. dx.stride(0), dy.stride(0),
  139. N)
  140. if not recompute_output:
  141. return dxy.reshape(*batch_shape, dxy.shape[-1])
  142. else:
  143. return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])
  144. class SwiGLU(torch.autograd.Function):
  145. @staticmethod
  146. def forward(ctx, xy):
  147. ctx.save_for_backward(xy)
  148. return _swiglu_fwd(xy)
  149. @staticmethod
  150. def backward(ctx, dout):
  151. xy, = ctx.saved_tensors
  152. return _swiglu_bwd(xy, dout)
  153. swiglu = SwiGLU.apply