| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169 |
- # Copyright (c) 2024, Tri Dao, Albert Gu.
- import torch
- import triton
- import triton.language as tl
# SwiGLU forward kernel: OUT[r, c] = X[r, c] * sigmoid(X[r, c]) * Y[r, c].
# Each program handles one BLOCK_N-wide column tile of one row:
# program_id(0) selects the row, program_id(1) selects the column tile.
# Columns are addressed with unit stride; rows via the per-tensor row strides.
# Inputs are upcast to float32 for the math; tl.store casts the result back
# to OUT's element type.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_N': 32}),
        triton.Config({'BLOCK_N': 64}),
        triton.Config({'BLOCK_N': 128}),
        triton.Config({'BLOCK_N': 256}),
        triton.Config({'BLOCK_N': 512}),
        triton.Config({'BLOCK_N': 1024}),
    ],
    key=['ncols'],
)
@triton.jit
def _swiglu_fwd_kernel(
    X,
    Y,
    OUT,
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_out_row,
    ncols,
    BLOCK_N: tl.constexpr,
):
    # Widen the row index to int64 before scaling by the row stride:
    # program_id is int32, so row * stride overflows once the row offset
    # exceeds the int32 range (tensors with more than ~2**31 elements),
    # which would make the kernel address the wrong memory.
    row = tl.program_id(0).to(tl.int64)
    start_col = tl.program_id(1) * BLOCK_N
    X += row * stride_x_row
    Y += row * stride_y_row
    OUT += row * stride_out_row
    cols = start_col + tl.arange(0, BLOCK_N)
    # Mask off the tail of the last tile when ncols is not a multiple of BLOCK_N.
    mask = cols < ncols
    x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
    y = tl.load(Y + cols, mask=mask, other=0.).to(tl.float32)
    out = x * tl.sigmoid(x) * y
    tl.store(OUT + cols, out, mask=mask)
def _swiglu_fwd(xy, out=None):
    """Compute ``silu(x) * y`` with the fused Triton forward kernel.

    Args:
        xy: tensor whose last dimension is ``x`` and ``y`` concatenated
            (``x`` = first half, ``y`` = second half). The last dimension
            must be even.
        out: optional preallocated output. It is reshaped to 2D (leading dims
            flattened), must then match the shape of ``x``, and must have
            unit stride in its last dimension.

    Returns:
        Tensor of shape ``(*xy.shape[:-1], xy.shape[-1] // 2)``.

    Raises:
        ValueError: if the last dimension of ``xy`` is odd. ``chunk(2)`` would
            then make ``y`` one column narrower than ``x``, and the kernel
            (which reads ``x.shape[-1]`` columns from both) would read past
            the end of each ``y`` row.
    """
    if xy.shape[-1] % 2 != 0:
        raise ValueError(
            f"swiglu expects an even last dimension (x and y halves), got {xy.shape[-1]}"
        )
    # The kernel addresses columns with unit stride.
    if xy.stride(-1) != 1:
        xy = xy.contiguous()
    batch_shape = xy.shape[:-1]
    xy = xy.reshape(-1, xy.shape[-1])
    # x and y are strided views into xy; the kernel receives their row strides.
    x, y = xy.chunk(2, dim=-1)
    if out is None:
        out = torch.empty_like(x)
    else:
        # NOTE(review): if `out` is not reshapeable as a view, reshape copies and
        # the caller's buffer is not written; the returned tensor is still correct.
        out = out.reshape(-1, out.shape[-1])
        assert out.shape == x.shape
    assert out.stride(-1) == 1
    M, N = x.shape
    # One program per (row, BLOCK_N-wide column tile).
    grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
    with torch.cuda.device(x.device.index):
        _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)
    return out.reshape(*batch_shape, out.shape[-1])
# SwiGLU backward kernel. With out = x * sigmoid(x) * y:
#   DX = sigmoid(x) * (1 + x * (1 - sigmoid(x))) * y * DOUT
#   DY = x * sigmoid(x) * DOUT
# Optionally (RECOMPUTE_OUTPUT) it also rewrites OUT = x * sigmoid(x) * y in
# the same pass. Each program handles one BLOCK_N-wide column tile of one row:
# program_id(0) selects the row, program_id(1) selects the column tile.
# Inputs are upcast to float32; tl.store casts back to each output's dtype.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_N': 32}),
        triton.Config({'BLOCK_N': 64}),
        triton.Config({'BLOCK_N': 128}),
        triton.Config({'BLOCK_N': 256}),
        triton.Config({'BLOCK_N': 512}),
        triton.Config({'BLOCK_N': 1024}),
    ],
    key=['ncols'],
)
# RECOMPUTE_OUTPUT is a compile-time flag: True iff an OUT pointer was passed.
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None})
@triton.jit
def _swiglu_bwd_kernel(
    X,
    Y,
    DOUT,
    OUT,
    DX,
    DY,
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_dout_row,
    stride_out_row,
    stride_dx_row,
    stride_dy_row,
    ncols,
    BLOCK_N: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Widen the row index to int64 before scaling by the row strides:
    # program_id is int32, so row * stride overflows once the row offset
    # exceeds the int32 range (tensors with more than ~2**31 elements).
    row = tl.program_id(0).to(tl.int64)
    start_col = tl.program_id(1) * BLOCK_N
    X += row * stride_x_row
    Y += row * stride_y_row
    DOUT += row * stride_dout_row
    if RECOMPUTE_OUTPUT:
        OUT += row * stride_out_row
    DX += row * stride_dx_row
    DY += row * stride_dy_row
    cols = start_col + tl.arange(0, BLOCK_N)
    # Mask off the tail of the last tile when ncols is not a multiple of BLOCK_N.
    mask = cols < ncols
    x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
    y = tl.load(Y + cols, mask=mask, other=0.).to(tl.float32)
    dout = tl.load(DOUT + cols, mask=mask, other=0.).to(tl.float32)
    x_sigmoid = tl.sigmoid(x)
    # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout
    dy = x * x_sigmoid * dout
    tl.store(DX + cols, dx, mask=mask)
    tl.store(DY + cols, dy, mask=mask)
    if RECOMPUTE_OUTPUT:
        out = x * x_sigmoid * y
        tl.store(OUT + cols, out, mask=mask)
def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):
    """Compute the gradient of ``silu(x) * y`` w.r.t. ``xy`` with a fused kernel.

    Args:
        xy: forward input; the last dimension is ``x`` (first half) followed
            by ``y`` (second half) and must be even.
        dout: upstream gradient; after flattening leading dims it must match
            the shape of ``x``.
        dxy: optional preallocated gradient buffer; after flattening leading
            dims it must match the shape of ``xy``.
        recompute_output: if True, the same kernel launch also recomputes the
            forward output ``silu(x) * y``.
        out: optional preallocated buffer for the recomputed output; only used
            when ``recompute_output`` is True.

    Returns:
        ``dxy`` with ``xy``'s original leading dims, or the tuple
        ``(dxy, out)`` when ``recompute_output`` is True.

    Raises:
        ValueError: if the last dimension of ``xy`` is odd. ``chunk(2)`` would
            then make ``y``/``dy`` one column narrower than ``x``/``dx``, and the
            kernel (which writes ``x.shape[-1]`` columns of ``dy``) would
            overwrite the start of the next row's ``dx``.
    """
    if xy.shape[-1] % 2 != 0:
        raise ValueError(
            f"swiglu expects an even last dimension (x and y halves), got {xy.shape[-1]}"
        )
    # The kernel addresses columns with unit stride.
    if xy.stride(-1) != 1:
        xy = xy.contiguous()
    if dout.stride(-1) != 1:
        dout = dout.contiguous()
    batch_shape = xy.shape[:-1]
    xy = xy.reshape(-1, xy.shape[-1])
    # x and y are strided views into xy; the kernel receives their row strides.
    x, y = xy.chunk(2, dim=-1)
    dout = dout.reshape(-1, dout.shape[-1])
    assert dout.shape == x.shape
    if dxy is None:
        dxy = torch.empty_like(xy)
    else:
        dxy = dxy.reshape(-1, dxy.shape[-1])
        assert dxy.shape == xy.shape
    # dx and dy are the two halves of dxy, written in place by the kernel.
    dx, dy = dxy.chunk(2, dim=-1)
    assert dx.stride(-1) == 1
    assert dy.stride(-1) == 1
    if recompute_output:
        if out is None:
            out = torch.empty_like(x)
        else:
            out = out.reshape(-1, out.shape[-1])
            assert out.shape == x.shape
        assert out.stride(-1) == 1
    M, N = x.shape
    # One program per (row, BLOCK_N-wide column tile).
    grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
    with torch.cuda.device(x.device.index):
        # Passing OUT=None disables recomputation via the kernel's heuristic.
        _swiglu_bwd_kernel[grid](
            x, y, dout, out if recompute_output else None, dx, dy,
            x.stride(0), y.stride(0), dout.stride(0),
            out.stride(0) if recompute_output else 0,
            dx.stride(0), dy.stride(0),
            N,
        )
    dxy = dxy.reshape(*batch_shape, dxy.shape[-1])
    if not recompute_output:
        return dxy
    return dxy, out.reshape(*batch_shape, out.shape[-1])
class SwiGLU(torch.autograd.Function):
    """Autograd Function for the fused SwiGLU op.

    The input's last dimension holds ``x`` (first half) and ``y`` (second
    half); the output is ``silu(x) * y``. Only the input tensor is saved for
    backward, which recomputes everything it needs from it.
    """

    @staticmethod
    def forward(ctx, xy):
        # Keep the raw input; backward re-derives the activations from it.
        ctx.save_for_backward(xy)
        result = _swiglu_fwd(xy)
        return result

    @staticmethod
    def backward(ctx, dout):
        (saved_xy,) = ctx.saved_tensors
        # Single differentiable input, so a single gradient is returned.
        grad_xy = _swiglu_bwd(saved_xy, dout)
        return grad_xy


# Functional entry point: swiglu(xy) -> silu(x) * y.
swiglu = SwiGLU.apply
|