| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License.
- # --------------------------------------------------------------------------
- import logging
- import numpy as np
- from fusion_attention import AttentionMask, FusionAttention
- from onnx import helper
- from onnx_model import OnnxModel
# Logger named after this module; the fusion code reports match failures through it.
logger: logging.Logger = logging.getLogger(__name__)
class FusionBartAttention(FusionAttention):
    """
    Fuse Bart Attention subgraph into one Attention node.

    Besides BART, :meth:`fuse` recognizes the Hugging Face ("hf") and OpenAI ("oai") Whisper export
    patterns. Encoder attention and decoder self-attention without past are built with
    ``create_attention_node`` (with ``use_multi_head_attention`` temporarily disabled). Decoder
    self-attention with past and decoder cross-attention (with or without past) are built with
    ``create_multihead_attention_node``, and only when ``use_multi_head_attention`` is enabled.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        """Forward all arguments unchanged to ``FusionAttention.__init__``."""
        super().__init__(model, hidden_size, num_heads, attention_mask)

    def _find_concat_past_and_present(self, start_output, input_name_to_nodes, graph_output_names):
        """Locate the past/present cache tensor names around an OpenAI-style KV Concat.

        Among the consumers of ``start_output``, the first Concat whose input 0 comes from a
        Reshape <- Transpose chain supplies the past name (the Transpose's input 0). The search for
        the present name then continues from that Concat's consumers (or from the consumers of
        ``start_output`` when no Concat exists): the first grandchild output that is a graph output
        is taken as the present name.

        Args:
            start_output: Name of the projected K or V tensor where the search starts.
            input_name_to_nodes: Mapping from a tensor name to the nodes that consume it.
            graph_output_names: Set of graph output names.

        Returns:
            Tuple ``(past, present)``; either entry is ``""`` when not found.
        """
        past, present = "", ""
        start_child_nodes = input_name_to_nodes.get(start_output, [])
        for start_child_node in start_child_nodes:
            if start_child_node.op_type == "Concat":
                concat_nodes = self.model.match_parent_path(
                    start_child_node,
                    ["Reshape", "Transpose"],
                    [0, 0],
                )
                if concat_nodes is not None:
                    past = concat_nodes[-1].input[0]
                start_child_nodes = input_name_to_nodes.get(start_child_node.output[0], [])
                break

        for start_child_node in start_child_nodes:
            # Tensors without consumers are skipped instead of raising KeyError.
            for start_grandchild_node in input_name_to_nodes.get(start_child_node.output[0], []):
                if start_grandchild_node.output[0] in graph_output_names:
                    present = start_grandchild_node.output[0]
                    break
            if present != "":
                break
        return past, present

    def _get_add_bias_initializer(self, add_node):
        """Return the first initializer among ``add_node``'s inputs, or None if it has none.

        Input 0 is checked first, which is where these exports usually place the bias.
        """
        for input_name in add_node.input:
            tensor = self.model.get_initializer(input_name)
            if tensor is not None:
                return tensor
        return None

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        """Try to fuse the attention subgraph that ends at ``normalize_node``.

        Args:
            normalize_node: The (Skip)LayerNormalization node closing the attention block. Its
                input 1 is expected to come from the output-projection Add.
            input_name_to_nodes: Mapping from a tensor name to the nodes that consume it.
            output_name_to_node: Mapping from a tensor name to the node that produces it.

        On success the fused node is queued in ``self.nodes_to_add``, the replaced nodes are queued
        in ``self.nodes_to_remove``, and ``self.prune_graph`` is set. If any sub-pattern fails to
        match, the method returns early, in most cases after a debug log. Nothing is returned.
        """
        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
        qkv_nodes = self.model.match_parent_path(
            normalize_node,
            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 1, 0, 0, 0],
        )
        if qkv_nodes is None:
            logger.debug("fuse_attention: failed to match qkv path")
            return
        (_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes

        other_inputs = [
            input_
            for input_ in normalize_node.input
            if input_ in output_name_to_node and input_ != qkv_nodes[0].output[0]
        ]
        if len(other_inputs) != 1:
            return
        root_input = other_inputs[0]

        # Sometimes the input name to the attention MatMul nodes does not match the input name to the end
        # SkipLayerNormalization node (name saved in root_input). We find the true input name to the MatMul
        # nodes by getting the initial SkipLayerNormalization node and checking whether any MatMul node
        # consumes each of its output names.
        #
        #                                 root_input
        #             +---------------------------------------------------+
        #             |                                                   |
        #             |                                                   |
        # SkipLayerNormalization --> Attention --> MatMul --> SkipLayerNormalization
        skip_layernorm = output_name_to_node[root_input]
        # For some attention blocks, the end SkipLayerNormalization node may point to another node whose
        # child is the LayerNormalization node.
        if skip_layernorm.op_type in {"Add", "Clip"}:
            skip_layernorm = self.model.get_children(skip_layernorm)[0]
        for output in skip_layernorm.output:
            if not output:
                continue
            # Outputs with no consumers are simply skipped.
            children = input_name_to_nodes.get(output, [])
            if any(child.op_type == "MatMul" for child in children):
                root_input = output
                break

        graph_input_names = {node.name for node in self.model.graph().input}
        graph_output_names = {node.name for node in self.model.graph().output}

        # ---- V path ----
        v_nodes_past_or_present = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "Add", "MatMul"],
            [1, 0, 0, None],
        )
        v_nodes_with_past = self.model.match_parent_path(
            matmul_qkv,
            ["Concat", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 1, 0, 0, None],
        )
        v_nodes_past_only_oai = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "Reshape", "Transpose"],
            [1, 0, 0, 0],
        )
        past_v, present_v = "", ""
        v_nodes, add_v, matmul_v = [], None, None
        if v_nodes_past_or_present is not None:
            v_nodes = v_nodes_past_or_present
            (_, _, add_v, matmul_v) = v_nodes
            past_v, present_v = self._find_concat_past_and_present(
                add_v.output[0], input_name_to_nodes, graph_output_names
            )
        elif v_nodes_with_past is not None:
            v_nodes = v_nodes_with_past
            (concat_v, _, _, add_v, matmul_v) = v_nodes
            past_v = concat_v.input[0]
            present_v = concat_v.output[0]
        elif matmul_qkv.input[1] in graph_input_names:
            # Hugging Face's cross-attention where past_v is used directly as value
            past_v = matmul_qkv.input[1]
        elif v_nodes_past_only_oai is not None:
            # OpenAI's cross-attention where past_v is used directly as value
            v_nodes = v_nodes_past_only_oai
            past_v = v_nodes[-1].input[0]
        else:
            logger.debug("fuse_attention: failed to match v path")
            return
        past_v = past_v if past_v in graph_input_names else ""
        present_v = present_v if present_v in graph_output_names else ""

        # ---- QK path ----
        qk_nodes_no_mask = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
        qk_nodes_with_mask = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
        add_qk = None
        if qk_nodes_no_mask is not None:
            qk_nodes = qk_nodes_no_mask
            (_, matmul_qk) = qk_nodes
        elif qk_nodes_with_mask is not None:
            qk_nodes = qk_nodes_with_mask
            (_, add_qk, matmul_qk) = qk_nodes
        else:
            logger.debug("fuse_attention: failed to match qk path")
            return
        # True exactly when the scores pass through an additive mask before Softmax.
        qk_has_mask = add_qk is not None

        # ---- Q path ----
        q_nodes_hf = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "Mul", "Add", "MatMul"],
            [0, 0, 0, 0, 1],
        )
        q_nodes_oai = self.model.match_parent_path(
            matmul_qk,
            ["Mul", "Transpose", "Reshape", "Add", "MatMul"],
            [0, 0, 0, 0, 1],
        )
        if q_nodes_hf is not None:
            q_nodes = q_nodes_hf
            (_, reshape_q, _, add_q, matmul_q) = q_nodes
        elif q_nodes_oai is not None:
            q_nodes = q_nodes_oai
            (_, _, reshape_q, add_q, matmul_q) = q_nodes
        else:
            logger.debug("fuse_attention: failed to match q path")
            return

        # ---- K path ----
        k_nodes_no_past_hf = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "MatMul"],
            [1, 0, 0],
        )
        k_nodes_with_past_hf = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Concat", "Transpose", "Reshape", "MatMul"],
            [1, 0, 1, 0, 0],
        )
        k_nodes_past_or_present_oai = self.model.match_parent_path(
            matmul_qk,
            ["Mul", "Transpose", "Reshape", "MatMul"],
            [1, 0, 0, 0],
        )
        k_nodes_past_only_oai = self.model.match_parent_path(
            matmul_qk,
            ["Mul", "Transpose", "Reshape", "Reshape", "Transpose"],
            [1, 0, 0, 0, 0],
        )
        # Producer of the QK MatMul's K operand; None when that operand is a graph input or
        # initializer (indexing the map directly would raise KeyError in that case).
        k_operand_parent = output_name_to_node.get(matmul_qk.input[1])
        past_k, present_k = "", ""
        k_nodes, add_k, matmul_k = [], None, None
        if k_nodes_no_past_hf is not None:
            k_nodes = k_nodes_no_past_hf
            (_, reshape_k, matmul_k) = k_nodes
            # Find present_k output name
            for transpose_k_node in input_name_to_nodes.get(reshape_k.output[0], []):
                if transpose_k_node.output[0] in graph_output_names:
                    present_k = transpose_k_node.output[0]
                    break
        elif k_nodes_with_past_hf is not None:
            k_nodes = k_nodes_with_past_hf
            (_, concat_k, _, reshape_k, matmul_k) = k_nodes
            past_k = concat_k.input[0]
            present_k = concat_k.output[0]
        elif (
            k_operand_parent is not None
            and len(k_operand_parent.input) > 0
            and k_operand_parent.input[0] in graph_input_names
        ):
            # Hugging Face's cross-attention where past_k is used directly as key
            k_nodes = [k_operand_parent]
            past_k = k_operand_parent.input[0]
        elif k_nodes_past_or_present_oai is not None:
            k_nodes = k_nodes_past_or_present_oai
            (_, _, reshape_k, matmul_k) = k_nodes
            past_k, present_k = self._find_concat_past_and_present(
                matmul_k.output[0], input_name_to_nodes, graph_output_names
            )
        elif k_nodes_past_only_oai is not None:
            # OpenAI's cross-attention where past_k is used directly as key
            k_nodes = k_nodes_past_only_oai
            past_k = k_nodes[-1].input[0]
        else:
            logger.debug("fuse_attention: failed to match k path")
            return
        past_k = past_k if past_k in graph_input_names else ""
        present_k = present_k if present_k in graph_output_names else ""

        if matmul_k is not None and add_k is None:
            # The K projection has no bias Add in these patterns. Synthesize a zero bias whose length
            # and dtype match V's bias initializer so a K bias can be handed to the node creators.
            add_v_tensor = self._get_add_bias_initializer(add_v) if add_v is not None else None
            if add_v_tensor is None:
                logger.debug("fuse_attention: failed to find v bias initializer")
                return
            bias_dim = add_v_tensor.dims[0]
            dtype = add_v_tensor.data_type
            empty_bias_name = "empty_bias"
            # NOTE(review): an existing "empty_bias" initializer is reused as-is, even if a
            # different layer needs another length or dtype.
            if self.model.get_initializer(empty_bias_name) is None:
                self.add_initializer(
                    empty_bias_name,
                    dtype,
                    dims=[bias_dim],
                    vals=np.zeros(bias_dim, dtype=helper.tensor_dtype_to_np_dtype(dtype)),
                )

            # This Add node is not queued in nodes_to_add here; it is only passed below as k_add.
            add_name = self.model.create_node_name("Add")
            add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k.name], add_name)

        three_root_inputs = bool(past_k) and bool(past_v) and matmul_k is None and matmul_v is None
        # Both projection MatMuls must exist before their inputs can be compared. Some k/v pattern
        # combinations leave exactly one of them as None, which used to raise AttributeError.
        kv_matmuls_found = matmul_k is not None and matmul_v is not None
        one_root_input = (
            not three_root_inputs
            and kv_matmuls_found
            and matmul_q.input[0] == root_input
            and matmul_k.input[0] == root_input
            and matmul_v.input[0] == root_input
        )
        two_root_inputs = (
            not three_root_inputs
            and kv_matmuls_found
            and matmul_q.input[0] == root_input
            and matmul_k.input[0] == matmul_v.input[0]
            and matmul_k.input[0] != matmul_q.input[0]
        )

        # There are 5 types of attention:
        # 1) Encoder attention with one_root_input=True and qk_nodes=qk_nodes_no_mask
        # 2) Decoder self attention with one_root_input=True and qk_nodes=qk_nodes_with_mask
        # 3) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_no_mask
        # 4) Decoder self attention with past with one_root_input=True and qk_nodes=qk_nodes_with_mask
        #    and past_k=past_decoder_key and past_v=past_decoder_value
        # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_no_mask
        encoder_attention = one_root_input and not qk_has_mask
        decoder_self_attention = one_root_input and qk_has_mask
        decoder_cross_attention = two_root_inputs and not qk_has_mask
        decoder_self_attention_with_past = decoder_self_attention and bool(past_k) and bool(past_v)
        decoder_cross_attention_with_past = three_root_inputs and not qk_has_mask

        # For decoder self-attentions, the attention mask needs to be included in the attention node
        causal_mask = qk_has_mask
        if causal_mask:
            mask_nodes_bart = self.model.match_parent_path(
                add_qk,
                ["Where"],
                [1],
            )
            mask_nodes_whisper_hf = self.model.match_parent_path(
                add_qk,
                ["Slice", "Expand", "Where"],
                [1, 0, 1],
            )
            mask_nodes_whisper_oai = self.model.match_parent_path(
                add_qk,
                ["Slice", "Unsqueeze", "Gather", "Shape", "Add"],
                [1, 2, 0, 0, 0],
            )
            mask_nodes_whisper_oai_unit_test = self.model.match_parent_path(
                add_qk,
                ["Slice", "Slice"],
                [1, 0],
            )
            # The mask subgraph is not removed explicitly (it is shared by all attention nodes and
            # handled by graph pruning); it only has to be recognized here.
            if all(
                mask_nodes is None
                for mask_nodes in (
                    mask_nodes_whisper_hf,
                    mask_nodes_whisper_oai,
                    mask_nodes_whisper_oai_unit_test,
                    mask_nodes_bart,
                )
            ):
                logger.debug("fuse_attention: failed to match mask nodes")
                return

        if not (
            encoder_attention
            or decoder_self_attention
            or decoder_cross_attention
            or decoder_self_attention_with_past
            or decoder_cross_attention_with_past
        ):
            logger.debug("fuse_attention: no supported attention pattern matched")
            return

        attention_last_node = reshape_qkv
        num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
        if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
            logger.debug("fuse_attention: failed to detect num_heads or hidden_size")
            return

        fuse_as_multihead = (
            decoder_self_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past
        )
        if fuse_as_multihead:
            # Note: Decoder attention with past key and past value is fused as multi-head attention
            # rather than attention because multi-head attention supports separate past key and past
            # value whereas attention supports concatenated past key and past value.
            # K/V projections exist only for these two cases; otherwise the past tensors are fed directly.
            has_kv_projection = decoder_cross_attention or decoder_self_attention_with_past
            new_node = (
                self.create_multihead_attention_node(
                    q_matmul=matmul_q,
                    k_matmul=matmul_k if has_kv_projection else past_k,
                    v_matmul=matmul_v if has_kv_projection else past_v,
                    q_add=add_q,
                    k_add=add_k if has_kv_projection else None,
                    v_add=add_v if has_kv_projection else None,
                    num_heads=num_heads,
                    hidden_size=hidden_size,
                    output=attention_last_node.output[0],
                    unidirectional=causal_mask,
                    past_k=past_k if decoder_self_attention_with_past else "",
                    past_v=past_v if decoder_self_attention_with_past else "",
                    present_k=present_k,
                    present_v=present_v,
                )
                if self.use_multi_head_attention
                else None
            )
        else:
            # Temporarily set multi-head attention flag to false; restore it even if node creation raises.
            use_multi_head_attention_ground_truth = self.use_multi_head_attention
            self.use_multi_head_attention = False
            try:
                new_node = self.create_attention_node(
                    mask_index=None,
                    q_matmul=matmul_q,
                    k_matmul=matmul_k,
                    v_matmul=matmul_v,
                    q_add=add_q,
                    k_add=add_k,
                    v_add=add_v,
                    num_heads=num_heads,
                    hidden_size=hidden_size,
                    first_input=root_input,
                    output=attention_last_node.output[0],
                    causal=causal_mask,
                    past_k=past_k,
                    past_v=past_v,
                    present_k=present_k,
                    present_v=present_v,
                )
            finally:
                self.use_multi_head_attention = use_multi_head_attention_ground_truth

        if new_node is None:
            logger.debug("fuse_attention: failed to create fused node")
            return

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
        self.nodes_to_remove.extend(qk_nodes)

        # When using multi-head attention, keep MatMul nodes in original graph
        if fuse_as_multihead:
            if len(q_nodes) > 0 and q_nodes[-1].op_type == "MatMul":
                q_nodes.pop()
            if len(k_nodes) > 0 and k_nodes[-1].op_type == "MatMul":
                k_nodes.pop()
            if len(v_nodes) > 0 and v_nodes[-1].op_type == "MatMul":
                v_nodes.pop()
            # With bias disabled on the fused node, the bias Add nodes must also stay in the graph.
            if self.disable_multi_head_attention_bias:
                if len(q_nodes) > 0 and q_nodes[-1].op_type == "Add":
                    q_nodes.pop()
                if len(k_nodes) > 0 and k_nodes[-1].op_type == "Add":
                    k_nodes.pop()
                if len(v_nodes) > 0 and v_nodes[-1].op_type == "Add":
                    v_nodes.pop()

        self.nodes_to_remove.extend(q_nodes)
        self.nodes_to_remove.extend(k_nodes)
        self.nodes_to_remove.extend(v_nodes)

        # Use prune graph to remove mask nodes since they are shared by all attention nodes.
        self.prune_graph = True
|