sdpa_attention.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import torch
  2. from ..utils import is_torch_npu_available, is_torch_xpu_available, logging
  3. from ..utils.import_utils import is_torch_greater_or_equal
# Module-level logger, obtained from the package's logging utilities.
logger = logging.get_logger(__name__)
# Torch version and backend checks. They are evaluated once, at import time, and the cached results are
# what `use_gqa_in_sdpa` and `sdpa_attention_forward` consult on every call.
# NOTE(review): `accept_dev=True` presumably lets dev/nightly builds satisfy the version check - confirm
# against `is_torch_greater_or_equal`.
_is_torch_greater_or_equal_than_2_5 = is_torch_greater_or_equal("2.5", accept_dev=True)
_is_torch_greater_or_equal_than_2_8 = is_torch_greater_or_equal("2.8", accept_dev=True)
# Results of the XPU / NPU availability helpers (exact detection logic lives in `..utils`).
_is_torch_xpu_available = is_torch_xpu_available()
_is_torch_npu_available = is_torch_npu_available()
  9. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  10. """
  11. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  12. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  13. """
  14. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  15. if n_rep == 1:
  16. return hidden_states
  17. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  18. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  19. def use_gqa_in_sdpa(attention_mask: torch.Tensor | None, key: torch.Tensor) -> bool:
  20. # GQA can only be used under the following conditions
  21. # 1.cuda or Ascend NPU
  22. # - torch version >= 2.5
  23. # - attention_mask is None (otherwise it will fall back to the math kernel)
  24. # 2.xpu
  25. # - torch version >= 2.8
  26. if _is_torch_xpu_available:
  27. return _is_torch_greater_or_equal_than_2_8
  28. return _is_torch_greater_or_equal_than_2_5 and attention_mask is None
  29. def sdpa_attention_forward(
  30. module: torch.nn.Module,
  31. query: torch.Tensor,
  32. key: torch.Tensor,
  33. value: torch.Tensor,
  34. attention_mask: torch.Tensor | None,
  35. dropout: float = 0.0,
  36. scaling: float | None = None,
  37. is_causal: bool | None = None,
  38. **kwargs,
  39. ) -> tuple[torch.Tensor, None]:
  40. if kwargs.get("output_attentions", False):
  41. logger.warning_once(
  42. "`sdpa` attention does not support `output_attentions=True`."
  43. " Please set your attention to `eager` if you want any of these features."
  44. )
  45. sdpa_kwargs = {}
  46. if hasattr(module, "num_key_value_groups"):
  47. if not use_gqa_in_sdpa(attention_mask, key):
  48. key = repeat_kv(key, module.num_key_value_groups)
  49. value = repeat_kv(value, module.num_key_value_groups)
  50. else:
  51. sdpa_kwargs = {"enable_gqa": True}
  52. # Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented
  53. is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
  54. # SDPA's Flash Attention (and cuDNN) kernels rely on the `is_causal` flag. However, there are certain conditions:
  55. # - Not in decoding phase (otherwise we want full attention on the single query token)
  56. # - Attention mask is not to be provided (even if it is a causal pattern)
  57. # - Internally, we marked this as compatible with causal, i.e. it is a decoder attention type
  58. #
  59. # Quirks on the conditionals:
  60. # - We avoid inline passing this to the SDPA function directly to support both torch.compile's dynamic shapes and
  61. # full graph options. Otherwise, dynamic shapes are prevented from compiling.
  62. # - It is important to check first for the shape, otherwise compile will fail with
  63. # `argument 'is_causal' must be bool, not SymBool`.
  64. is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
  65. # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
  66. # We convert it to a bool for the SDPA kernel that only accepts bools.
  67. if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
  68. is_causal = is_causal.item()
  69. # When `is_causal = False` and the `attention_mask` is not of boolean type, the Ascend NPU's SDPA interface cannot utilize the FlashAttentionScore operator,
  70. # and falls back to small-operator concatenation. To invoke the FlashAttentionScore, the attention_mask must be converted to boolean type.
  71. # This adaptation ensures the `attention_mask` meets the requirement for using FlashAttentionScore.
  72. if _is_torch_npu_available:
  73. if attention_mask is not None and attention_mask.dtype != torch.bool:
  74. # Convert to boolean type, making sdpa to force call FlashAttentionScore to improve performance.
  75. attention_mask = torch.logical_not(attention_mask.bool()).to(query.device)
  76. attn_output = torch.nn.functional.scaled_dot_product_attention(
  77. query,
  78. key,
  79. value,
  80. attn_mask=attention_mask,
  81. dropout_p=dropout,
  82. scale=scaling,
  83. is_causal=is_causal,
  84. **sdpa_kwargs,
  85. )
  86. attn_output = attn_output.transpose(1, 2).contiguous()
  87. return attn_output, None