flash_paged.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. import torch
  2. from ..generation.continuous_batching import PagedAttentionCache
  3. from ..modeling_flash_attention_utils import lazy_import_paged_flash_attention
  4. @torch.compiler.disable
  5. def paged_attention_forward(
  6. module: torch.nn.Module,
  7. q: torch.Tensor,
  8. k: torch.Tensor,
  9. v: torch.Tensor,
  10. attention_mask: torch.Tensor | None, # Unused in flash
  11. cache: PagedAttentionCache,
  12. cu_seq_lens_q: torch.Tensor,
  13. cu_seq_lens_k: torch.Tensor | dict[str, torch.Tensor],
  14. max_seqlen_q: int,
  15. max_seqlen_k: int | dict[str, int],
  16. block_table: torch.Tensor | None,
  17. **kwargs,
  18. ) -> tuple[torch.Tensor, None]:
  19. """Performs the forward pass of attention with paged key-value cache. This function handles the cache updates and
  20. performs the attention computation. For decode-only batches (when block_table is provided), uses
  21. `flash_attn_with_kvcache` for fused attention + cache update. Otherwise uses `flash_attn_varlen_func`.
  22. See the [paged attention guide](https://huggingface.co/docs/transformers/en/paged_attention) for more details.
  23. Args:
  24. q: (1, nheads, total_q, headdim), where total_q = total number of query tokens in the batch.
  25. k: (1, nheads_k, total_k, headdim), where total_k = total number of key tokens in the batch.
  26. v: (1, nheads_k, total_k, headdim), where total_k = total number of key tokens in the batch.
  27. cu_seq_lens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  28. of the sequences in the batch, used to index into q.
  29. cu_seq_lens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  30. of the sequences in the batch, used to index into kv.
  31. max_seqlen_q: int. Maximum query sequence length in the batch.
  32. max_seqlen_k: int. Maximum key sequence length in the batch.
  33. block_table: (num_groups, batch_size, max_blocks_per_seq), dtype int32. Block table for paged KV cache.
  34. If provided, uses flash_attn_with_kvcache for fused attention + cache update. For each request, the block
  35. table is a vector of size (max_blocks_per_seq,) with indices indicating the physical location of the cache
  36. to read from and write to. The kernel, using the cache_seqlens for that request, knows how much cache to
  37. read and dispatches the read using the block table. Same for the write. If a request has fewer than
  38. max_blocks_per_seq blocks, the block table is padded with -1s to indicate that the block is not allocated.
  39. """
  40. # Retrieve the flash attention functions
  41. flash_attn_varlen_func, flash_attn_with_kvcache = lazy_import_paged_flash_attention(
  42. module.config._attn_implementation
  43. )
  44. # Retrieve the cumulative sequence lengths for the current layer
  45. sliding_window = (-1, -1) if not getattr(module, "sliding_window", False) else (module.sliding_window - 1, 0)
  46. layer_type = "full_attention" if sliding_window == (-1, -1) else "sliding_attention"
  47. if isinstance(cu_seq_lens_k, dict):
  48. cu_seq_lens_k = cu_seq_lens_k[layer_type]
  49. max_seqlen_k = max_seqlen_k[layer_type]
  50. # If no block table is provided, use flash_attn_varlen_func with read/write indices
  51. if block_table is None:
  52. # .update changes the shape of k and v from [1, num_kv_heads, seqlen_kv, head_dim] to [-1, num_kv_heads, head_dim]
  53. k, v = cache.update(
  54. key_states=k,
  55. value_states=v,
  56. layer_idx=module.layer_idx,
  57. read_index=kwargs["read_index"],
  58. write_index=kwargs["write_index"],
  59. )
  60. custom_kwargs = {"s_aux": kwargs.get("s_aux")} if "s_aux" in kwargs else {}
  61. attn_output = flash_attn_varlen_func(
  62. q.transpose(1, 2).squeeze(0).contiguous(),
  63. k.contiguous(),
  64. v.contiguous(),
  65. cu_seq_lens_q.to(torch.int32),
  66. cu_seq_lens_k.to(torch.int32).clone(),
  67. max_seqlen_q,
  68. max_seqlen_k,
  69. softmax_scale=module.scaling,
  70. causal=True, # kind of a must, it automatically aligns the mask for q < k
  71. window_size=sliding_window, # -1 means infinite context window
  72. **custom_kwargs,
  73. )
  74. if isinstance(attn_output, tuple):
  75. attn_output = attn_output[0]
  76. # Otherwise, use flash_attn_with_kvcache which updates the cache in-place and computes attention
  77. else:
  78. # Get layer group index for this layer
  79. group_idx, layer_idx_in_group = cache.layer_index_to_group_indices[module.layer_idx]
  80. # KV cache shape: [num_pages, num_kv_heads, head_dim] -> [num_blocks, block_size, num_kv_heads, head_dim]
  81. k_cache = cache.key_cache[layer_idx_in_group].view(
  82. -1, cache.block_size, cache.num_key_value_heads, cache.head_dim
  83. )
  84. v_cache = cache.value_cache[layer_idx_in_group].view(
  85. -1, cache.block_size, cache.num_key_value_heads, cache.head_dim
  86. )
  87. # Reshape Q, K, V from [1, num_*_heads, batch_size, head_dim] to [batch_size, 1, num_*_heads, head_dim]
  88. q = q.permute(2, 0, 1, 3).contiguous()
  89. k = k.permute(2, 0, 1, 3).contiguous()
  90. v = v.permute(2, 0, 1, 3).contiguous()
  91. # Compute cache_seqlens from cu_seq_lens_k (current cache length BEFORE adding new tokens)
  92. # cu_seq_lens_k is cumulative, so seqlens[i] = cu_seq_lens_k[i+1] - cu_seq_lens_k[i] - 1 (subtract 1 for the new token)
  93. batch_size = k.size(0)
  94. cache_seqlens = (cu_seq_lens_k[1 : batch_size + 1] - cu_seq_lens_k[:batch_size] - 1).to(torch.int32)
  95. # The arg name for the block table is not the same in VLLM's kernel and Tri Dao's kernel, so we need to parse it
  96. flash_kwargs = {cache.get_block_table_key(flash_attn_with_kvcache): block_table[group_idx]}
  97. if "s_aux" in kwargs:
  98. flash_kwargs["s_aux"] = kwargs["s_aux"] # this is only available in VLLM's FA3
  99. # Call flash_attn_with_kvcache - this updates cache in-place and computes attention
  100. attn_output = flash_attn_with_kvcache(
  101. q=q,
  102. k_cache=k_cache,
  103. v_cache=v_cache,
  104. k=k,
  105. v=v,
  106. cache_seqlens=cache_seqlens,
  107. softmax_scale=module.scaling,
  108. causal=True,
  109. window_size=sliding_window,
  110. **flash_kwargs,
  111. )
  112. if isinstance(attn_output, tuple):
  113. attn_output = attn_output[0]
  114. # Reshape output from [batch_size, 1, num_heads, head_dim] to [batch_size, num_heads, head_dim]
  115. attn_output = attn_output.squeeze(1)
  116. return attn_output, None