| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- import os
- import torch
- from ..utils.import_utils import is_torch_npu_available
- if is_torch_npu_available():
- from torch_npu import npu_fusion_attention
# FlashAttention2 is supported on Ascend NPU with a down-right aligned causal mask by default.
# Set the environment variable `NPU_FA2_SPARSE_MODE` to 2 when a top-left aligned causal mask is needed.
# The selected value is passed as `sparse_mode` to `npu_fusion_attention` for causal attention.
TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE = 2
DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE = 3

_SUPPORTED_SPARSE_MODES = (TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE, DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE)


def _parse_sparse_mode(raw_value):
    """Validate and convert the raw `NPU_FA2_SPARSE_MODE` setting.

    Args:
        raw_value: The environment-variable string, or ``None`` when the variable is unset.
            An unset variable selects ``DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE``.

    Returns:
        int: ``TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE`` or ``DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE``.

    Raises:
        ValueError: If the value is not an integer or is not one of the supported modes.
    """
    message = (
        "Environment variable `NPU_FA2_SPARSE_MODE` can only be set as 2 (top-left aligned causal mask) "
        "or 3 (down-right aligned causal mask)."
    )
    if raw_value is None:
        return DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE
    try:
        mode = int(raw_value)
    except (TypeError, ValueError) as err:
        # Replace int()'s bare "invalid literal" error with the actionable message above.
        raise ValueError(f"{message} Got {raw_value!r}.") from err
    if mode not in _SUPPORTED_SPARSE_MODES:
        raise ValueError(f"{message} Got {raw_value!r}.")
    return mode


SPARSE_MODE = _parse_sparse_mode(os.getenv("NPU_FA2_SPARSE_MODE"))
# Causal masks already built for `npu_fusion_attention`, keyed by device, so each device builds its mask once.
ATTN_MASK_NPU_CACHE = {}


def get_attn_mask_npu(device):
    """Return the causal attention mask for ``device``, building and caching it on first use.

    The mask is a 2048 x 2048 boolean tensor that is ``True`` strictly above the main
    diagonal and ``False`` on and below it. Repeated calls with the same ``device``
    return the very same tensor object from ``ATTN_MASK_NPU_CACHE``.
    """
    cached_mask = ATTN_MASK_NPU_CACHE.get(device)
    if cached_mask is None:
        # NOTE(review): the fixed 2048 size presumably matches what `npu_fusion_attention`
        # expects for sparse modes 2/3 — confirm against the Ascend documentation.
        ones = torch.ones([2048, 2048], device=device)
        cached_mask = torch.triu(ones, diagonal=1).bool()
        ATTN_MASK_NPU_CACHE[device] = cached_mask
    return cached_mask
def is_npu_fa2_top_left_aligned_causal_mask():
    """Report whether NPU FlashAttention2 is configured for a top-left aligned causal mask.

    Returns ``False`` whenever no Ascend NPU is available. Otherwise returns ``True``
    only if ``SPARSE_MODE`` equals ``TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE``.
    """
    if not is_torch_npu_available():
        return False
    return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE
def npu_flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    **kwargs,
):
    """FlashAttention-style attention on Ascend NPU via ``npu_fusion_attention`` ("BSND" layout).

    Args:
        q, k, v: Query/key/value tensors passed with input layout ``"BSND"``; the
            head count is read from ``q.shape[2]``.
        dropout_p: Dropout probability, forwarded as ``keep_prob = 1.0 - dropout_p``.
        softmax_scale: Softmax scaling factor; when ``None`` it defaults to
            ``1 / sqrt(q.shape[-1])``.
        causal: When ``True``, the cached mask from ``get_attn_mask_npu`` and the
            configured ``SPARSE_MODE`` are also passed.
        **kwargs: Accepted for signature compatibility with `flash-attn`; ignored.

    Returns:
        The first element of the tuple returned by ``npu_fusion_attention``.
    """
    keep_prob = 1.0 - dropout_p
    scale = 1.0 / math.sqrt(q.shape[-1]) if softmax_scale is None else softmax_scale

    # Only causal attention needs the precomputed mask and the sparse-mode selector.
    mask_kwargs = {}
    if causal:
        mask_kwargs = {
            "atten_mask": get_attn_mask_npu(q.device),
            "sparse_mode": SPARSE_MODE,
        }

    results = npu_fusion_attention(
        q,
        k,
        v,
        q.shape[2],
        "BSND",
        keep_prob=keep_prob,
        scale=scale,
        **mask_kwargs,
    )
    return results[0]
def _cu_seqlens_to_tuple(cu_seqlens):
    """Drop the first entry of ``cu_seqlens`` and return the rest as a host-side tuple of Python ints."""
    # `Tensor.tolist()` already yields Python scalars; the former `.numpy()` round-trip was
    # redundant and depended on PyTorch's NumPy bridge being available.
    return tuple(cu_seqlens[1:].cpu().tolist())


def npu_flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q=None,  # defined for aligning params order with corresponding function in `flash-attn`
    max_seqlen_k=None,  # defined for aligning params order with corresponding function in `flash-attn`
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    **kwargs,
):
    """Variable-length FlashAttention-style attention on Ascend NPU ("TND" layout).

    Args:
        q, k, v: Packed query/key/value tensors passed with input layout ``"TND"``;
            the head count is read from ``q.shape[1]``.
        cu_seqlens_q, cu_seqlens_k: Cumulative sequence-length tensors. Their first
            entry is dropped, presumably a leading 0 as in `flash-attn` (confirm
            with callers). The remaining entries are passed as ``actual_seq_qlen`` /
            ``actual_seq_kvlen``.
        max_seqlen_q, max_seqlen_k: Unused; kept only for positional compatibility.
        dropout_p: Dropout probability, forwarded as ``keep_prob = 1.0 - dropout_p``.
        softmax_scale: Softmax scaling factor; when ``None`` it defaults to
            ``1 / sqrt(q.shape[-1])``.
        causal: When ``True``, the cached mask from ``get_attn_mask_npu`` and the
            configured ``SPARSE_MODE`` are also passed.
        **kwargs: Accepted for signature compatibility with `flash-attn`; ignored.

    Returns:
        The first element of the tuple returned by ``npu_fusion_attention``.
    """
    keep_prob = 1.0 - dropout_p
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(q.shape[-1])

    head_num = q.shape[1]
    # Arguments shared by the causal and non-causal paths.
    attention_kwargs = {
        "pse": None,
        "atten_mask": None,
        "scale": softmax_scale,
        "keep_prob": keep_prob,
        "input_layout": "TND",
        "actual_seq_qlen": _cu_seqlens_to_tuple(cu_seqlens_q),
        "actual_seq_kvlen": _cu_seqlens_to_tuple(cu_seqlens_k),
    }
    if causal:
        # Causal attention adds the cached mask and the configured sparse mode.
        attention_kwargs.update(
            padding_mask=None,
            atten_mask=get_attn_mask_npu(q.device),
            sparse_mode=SPARSE_MODE,
        )

    return npu_fusion_attention(q, k, v, head_num, **attention_kwargs)[0]
# This function is not implemented and should never be called, because the block table is not
# used on NPU. It accepts arbitrary arguments so that an accidental call fails with the
# NotImplementedError below rather than a TypeError about the argument count.
def npu_flash_attn_with_kvcache(*args, **kwargs):
    """Placeholder for ``flash_attn_with_kvcache`` on NPU that always raises.

    Raises:
        NotImplementedError: On every call, regardless of the arguments passed.
    """
    raise NotImplementedError("npu_flash_attn_with_kvcache is not implemented")
|