ssd_minimal.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. # Copyright (c) 2024, Albert Gu and Tri Dao.
  2. """Minimal implementation of SSD.
  3. This is the same as Listing 1 from the paper.
  4. """
  5. import torch
  6. import torch.nn.functional as F
  7. from einops import rearrange, repeat
  8. from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
  9. def segsum_unstable(x):
  10. """Naive segment sum calculation."""
  11. T = x.size(-1)
  12. x_cumsum = torch.cumsum(x, dim=-1)
  13. x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
  14. mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
  15. x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
  16. return x_segsum
  17. def segsum(x):
  18. """More stable segment sum calculation."""
  19. T = x.size(-1)
  20. x = repeat(x, "... d -> ... d e", e=T)
  21. mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
  22. x = x.masked_fill(~mask, 0)
  23. x_segsum = torch.cumsum(x, dim=-2)
  24. mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
  25. x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
  26. return x_segsum
  27. def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
  28. """
  29. Arguments:
  30. X: (batch, length, n_heads, d_head)
  31. A: (batch, length, n_heads)
  32. B: (batch, length, n_heads, d_state)
  33. C: (batch, length, n_heads, d_state)
  34. Return:
  35. Y: (batch, length, n_heads, d_head)
  36. """
  37. assert X.dtype == A.dtype == B.dtype == C.dtype
  38. assert X.shape[1] % block_len == 0
  39. # Rearrange into blocks/chunks
  40. X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]
  41. A = rearrange(A, "b c l h -> b h c l")
  42. A_cumsum = torch.cumsum(A, dim=-1)
  43. # 1. Compute the output for each intra-chunk (diagonal blocks)
  44. L = torch.exp(segsum(A))
  45. Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
  46. # 2. Compute the state for each intra-chunk
  47. # (right term of low-rank factorization of off-diagonal blocks; B terms)
  48. decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
  49. states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
  50. # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
  51. # (middle term of factorization of off-diag blocks; A terms)
  52. if initial_states is None:
  53. initial_states = torch.zeros_like(states[:, :1])
  54. states = torch.cat([initial_states, states], dim=1)
  55. decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
  56. new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
  57. states, final_state = new_states[:, :-1], new_states[:, -1]
  58. # 4. Compute state -> output conversion per chunk
  59. # (left term of low-rank factorization of off-diagonal blocks; C terms)
  60. state_decay_out = torch.exp(A_cumsum)
  61. Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)
  62. # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
  63. Y = rearrange(Y_diag+Y_off, "b c l h p -> b (c l) h p")
  64. return Y, final_state
  65. # Simple test
  66. def test_correctness():
  67. torch.manual_seed(42)
  68. ## Dimensions
  69. # Denoted (B, T, Q, D, P) in the paper
  70. batch, seqlen, chunk_size, dim, headdim = 1, 2048, 64, 2048, 64
  71. nheads = dim // headdim # (H) in the paper
  72. ngroups = 1 # (G) in the paper
  73. dstate = 64 # (N) in the paper
  74. dtype = torch.float32
  75. device = "cuda"
  76. x = torch.randn(batch, seqlen, nheads, headdim, dtype=dtype, device=device)
  77. dt = F.softplus(torch.randn(batch, seqlen, nheads, dtype=torch.float32, device=device) - 4).requires_grad_()
  78. A = (-torch.exp(torch.rand(nheads, dtype=torch.float32, device=device))).requires_grad_()
  79. B = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
  80. C = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
  81. D = torch.randn(nheads, dtype=dtype, device=device)
  82. # Comparing fused version and minimal version
  83. y = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None)
  84. y_min, _ = ssd_minimal_discrete(x*dt.unsqueeze(-1), A*dt, B, C, chunk_size)