# flash_attention.py (3.1 KB)
  1. import torch
  2. from ..modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
  3. from ..utils import logging
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
# Evaluated once at import time; `flash_attention_forward` forwards this value as the
# `use_top_left_mask` argument of `_flash_attention_forward` on every call.
_use_top_left_mask = flash_attn_supports_top_left_mask()
  6. def get_target_dtype(query: torch.Tensor, module: torch.nn.Module) -> torch.dtype:
  7. """If the query is in float32, return a target dtype compatible with flash attention. Return None otherwise."""
  8. if query.dtype == torch.float32:
  9. if torch.is_autocast_enabled("cuda"):
  10. return torch.get_autocast_dtype("cuda")
  11. # Handle the case where the model is quantized
  12. elif hasattr(module.config, "_is_quantized"):
  13. return module.config.dtype
  14. else:
  15. return next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
  16. return None
  17. def flash_attention_forward(
  18. module: torch.nn.Module,
  19. query: torch.Tensor,
  20. key: torch.Tensor,
  21. value: torch.Tensor,
  22. attention_mask: torch.Tensor | None,
  23. dropout: float = 0.0,
  24. scaling: float | None = None,
  25. sliding_window: int | None = None,
  26. softcap: float | None = None,
  27. is_causal: bool | None = None,
  28. **kwargs,
  29. ) -> tuple[torch.Tensor, None]:
  30. if kwargs.get("output_attentions", False):
  31. logger.warning_once(
  32. "Flash Attention does not support `output_attentions=True`."
  33. " Please set your attention to `eager` if you want any of these features."
  34. )
  35. # This is before the transpose
  36. seq_len = query.shape[2]
  37. if any(dim == 0 for dim in query.shape):
  38. raise ValueError(
  39. "Tensor query has shape with a zero dimension.\n"
  40. "FlashAttention does not support inputs with dim=0.\n"
  41. "Please check your input shapes or use SDPA instead."
  42. )
  43. # FA2 uses non-transposed inputs
  44. query = query.transpose(1, 2)
  45. key = key.transpose(1, 2)
  46. value = value.transpose(1, 2)
  47. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  48. # therefore the input hidden states gets silently casted in float32. Hence, we need
  49. # cast them back in the correct dtype just to be sure everything works as expected.
  50. # This might slowdown training & inference so it is recommended to not cast the LayerNorms
  51. # in fp32. (usually our RMSNorm modules handle it correctly)
  52. target_dtype = get_target_dtype(query, module)
  53. # Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented
  54. is_causal = is_causal if is_causal is not None else module.is_causal
  55. attn_output = _flash_attention_forward(
  56. query,
  57. key,
  58. value,
  59. attention_mask,
  60. query_length=seq_len,
  61. is_causal=is_causal,
  62. dropout=dropout,
  63. softmax_scale=scaling,
  64. sliding_window=sliding_window,
  65. softcap=softcap,
  66. use_top_left_mask=_use_top_left_mask,
  67. target_dtype=target_dtype,
  68. attn_implementation=module.config._attn_implementation,
  69. layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None,
  70. **kwargs,
  71. )
  72. return attn_output, None