npu_flash_attention.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. # Licensed under the Apache License, Version 2.0 (the "License");
  2. # you may not use this file except in compliance with the License.
  3. # You may obtain a copy of the License at
  4. #
  5. # http://www.apache.org/licenses/LICENSE-2.0
  6. #
  7. # Unless required by applicable law or agreed to in writing, software
  8. # distributed under the License is distributed on an "AS IS" BASIS,
  9. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. # See the License for the specific language governing permissions and
  11. # limitations under the License.
  12. import math
  13. import os
  14. import torch
  15. from ..utils.import_utils import is_torch_npu_available
  16. if is_torch_npu_available():
  17. from torch_npu import npu_fusion_attention
  18. # FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
  19. # Set environment variable `NPU_FA2_SPARSE_MODE` to 2 when using top-left aligned causal mask.
  20. TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE = 2
  21. DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE = 3
  22. SPARSE_MODE = int(os.getenv("NPU_FA2_SPARSE_MODE", default=DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE))
  23. if SPARSE_MODE not in [TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE, DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE]:
  24. raise ValueError(
  25. "Environment variable `NPU_FA2_SPARSE_MODE` can only be set as 2 (top-left aligned causal mask) "
  26. "or 3 (down-right aligned causal mask)."
  27. )
  28. ATTN_MASK_NPU_CACHE = {}
  29. def get_attn_mask_npu(device):
  30. """Get or create attention mask for the specified device."""
  31. if device not in ATTN_MASK_NPU_CACHE:
  32. ATTN_MASK_NPU_CACHE[device] = torch.triu(torch.ones([2048, 2048], device=device), diagonal=1).bool()
  33. return ATTN_MASK_NPU_CACHE[device]
  34. def is_npu_fa2_top_left_aligned_causal_mask():
  35. return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False
  36. def npu_flash_attn_func(
  37. q,
  38. k,
  39. v,
  40. dropout_p=0.0,
  41. softmax_scale=None,
  42. causal=False,
  43. **kwargs,
  44. ):
  45. keep_prob = 1.0 - dropout_p
  46. if softmax_scale is None:
  47. softmax_scale = 1.0 / math.sqrt(q.shape[-1])
  48. if not causal:
  49. head_num = q.shape[2]
  50. output = npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
  51. else:
  52. attn_mask_npu = get_attn_mask_npu(q.device)
  53. head_num = q.shape[2]
  54. output = npu_fusion_attention(
  55. q,
  56. k,
  57. v,
  58. head_num,
  59. "BSND",
  60. keep_prob=keep_prob,
  61. scale=softmax_scale,
  62. atten_mask=attn_mask_npu,
  63. sparse_mode=SPARSE_MODE,
  64. )[0]
  65. return output
  66. def npu_flash_attn_varlen_func(
  67. q,
  68. k,
  69. v,
  70. cu_seqlens_q,
  71. cu_seqlens_k,
  72. max_seqlen_q=None, # defined for aligning params order with corresponding function in `flash-attn`
  73. max_seqlen_k=None, # defined for aligning params order with corresponding function in `flash-attn`
  74. dropout_p=0.0,
  75. softmax_scale=None,
  76. causal=False,
  77. **kwargs,
  78. ):
  79. keep_prob = 1.0 - dropout_p
  80. if softmax_scale is None:
  81. softmax_scale = 1.0 / math.sqrt(q.shape[-1])
  82. if not causal:
  83. head_num = q.shape[1]
  84. output = npu_fusion_attention(
  85. q,
  86. k,
  87. v,
  88. head_num,
  89. pse=None,
  90. atten_mask=None,
  91. scale=softmax_scale,
  92. keep_prob=keep_prob,
  93. input_layout="TND",
  94. actual_seq_qlen=tuple(cu_seqlens_q[1:].cpu().numpy().tolist()),
  95. actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()),
  96. )[0]
  97. else:
  98. attn_mask_npu = get_attn_mask_npu(q.device)
  99. head_num = q.shape[1]
  100. output = npu_fusion_attention(
  101. q,
  102. k,
  103. v,
  104. head_num,
  105. pse=None,
  106. padding_mask=None,
  107. atten_mask=attn_mask_npu,
  108. scale=softmax_scale,
  109. keep_prob=keep_prob,
  110. input_layout="TND",
  111. actual_seq_qlen=tuple(cu_seqlens_q[1:].cpu().numpy().tolist()),
  112. actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()),
  113. sparse_mode=SPARSE_MODE,
  114. )[0]
  115. return output
  116. # This function is not implemented but should never be called because block table is not used on NPU
  117. def npu_flash_attn_with_kvcache():
  118. raise NotImplementedError("npu_flash_attn_with_kvcache is not implemented")