# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- import logging import numpy as np from fusion_attention import AttentionMask, FusionAttention from onnx import helper from onnx_model import OnnxModel logger = logging.getLogger(__name__) class FusionBartAttention(FusionAttention): """ Fuse Bart Attention subgraph into one Attention node. """ def __init__( self, model: OnnxModel, hidden_size: int, num_heads: int, attention_mask: AttentionMask, ): super().__init__(model, hidden_size, num_heads, attention_mask) def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # SkipLayerNormalization has two inputs, and one of them is the root input for attention. qkv_nodes = self.model.match_parent_path( normalize_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [1, 1, 0, 0, 0], ) if qkv_nodes is not None: ( add_out, matmul_out, reshape_qkv, transpose_qkv, matmul_qkv, ) = qkv_nodes else: logger.debug("fuse_attention: failed to match qkv path") return other_inputs = [] for input_ in normalize_node.input: if input_ not in output_name_to_node: continue if input_ == qkv_nodes[0].output[0]: continue other_inputs.append(input_) if len(other_inputs) != 1: return root_input = other_inputs[0] # Sometimes the input name to the attention MatMul nodes does not match the input name to the end # SkipLayerNormalization node (name saved in root_input). We find the true input name to the MatMul # nodes by getting the initial SkipLayerNormalization node and checking how many MatMul nodes are # children nodes for each of its output names. """ root_input +---------------------------------------------------+ | | | | SkipLayerNormalization --> Attention --> MatMul --> SkipLayerNormalization """ skip_layernorm = output_name_to_node[root_input] # For some attention blocks, the end SkipLayerNormalization node may point to another node whose # child is the LayerNormalization node. if skip_layernorm.op_type in {"Add", "Clip"}: skip_layernorm = self.model.get_children(skip_layernorm)[0] for output in skip_layernorm.output: if not output: continue children = input_name_to_nodes[output] children_types = [child.op_type for child in children] if children_types.count("MatMul") >= 1: root_input = output break graph_input_names = {node.name for node in self.model.graph().input} graph_output_names = {node.name for node in self.model.graph().output} v_nodes_past_or_present = self.model.match_parent_path( matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None], ) v_nodes_with_past = self.model.match_parent_path( matmul_qkv, ["Concat", "Transpose", "Reshape", "Add", "MatMul"], [1, 1, 0, 0, None], ) v_nodes_past_only_oai = self.model.match_parent_path( matmul_qkv, ["Transpose", "Reshape", "Reshape", "Transpose"], [1, 0, 0, 0], ) past_v, present_v = "", "" v_nodes, add_v, matmul_v = [], None, None if v_nodes_past_or_present is not None: v_nodes = v_nodes_past_or_present (transpose_v, reshape_v, add_v, matmul_v) = v_nodes # Find past_v input name start_child_nodes = input_name_to_nodes[add_v.output[0]] for start_child_node in start_child_nodes: if start_child_node.op_type == "Concat": concat_v_nodes = self.model.match_parent_path( start_child_node, ["Reshape", "Transpose"], [0, 0], ) if concat_v_nodes is not None: past_v = concat_v_nodes[-1].input[0] start_child_nodes = input_name_to_nodes[start_child_node.output[0]] break # Find present_v output name for start_child_node in start_child_nodes: start_grandchild_nodes = input_name_to_nodes[start_child_node.output[0]] for start_grandchild_node in start_grandchild_nodes: if start_grandchild_node.output[0] in graph_output_names: present_v = start_grandchild_node.output[0] break if present_v != "": break elif v_nodes_with_past is not None: v_nodes = v_nodes_with_past (concat_v, transpose_v, reshape_v, add_v, matmul_v) = v_nodes past_v = concat_v.input[0] present_v = concat_v.output[0] elif matmul_qkv.input[1] in graph_input_names: # Hugging Face's cross-attention where past_v is used directly as value past_v = matmul_qkv.input[1] elif v_nodes_past_only_oai is not None: # OpenAI's cross-attention where past_v is used directly as value v_nodes = v_nodes_past_only_oai past_v = v_nodes[-1].input[0] else: logger.debug("fuse_attention: failed to match v path") return past_v = past_v if past_v in graph_input_names else "" present_v = present_v if present_v in graph_output_names else "" qk_nodes_no_mask = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0]) qk_nodes_with_mask = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0]) qk_nodes, add_qk = [], None if qk_nodes_no_mask is not None: _, matmul_qk = qk_nodes_no_mask qk_nodes = qk_nodes_no_mask elif qk_nodes_with_mask is not None: _, add_qk, matmul_qk = qk_nodes_with_mask qk_nodes = qk_nodes_with_mask else: logger.debug("fuse_attention: failed to match qk path") return q_nodes_hf = self.model.match_parent_path( matmul_qk, ["Transpose", "Reshape", "Mul", "Add", "MatMul"], [0, 0, 0, 0, 1], ) q_nodes_oai = self.model.match_parent_path( matmul_qk, ["Mul", "Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, 0, 1], ) q_nodes = [] if q_nodes_hf is not None: q_nodes = q_nodes_hf (transpose_q, reshape_q, mul_q, add_q, matmul_q) = q_nodes elif q_nodes_oai is not None: q_nodes = q_nodes_oai (mul_q, transpose_q, reshape_q, add_q, matmul_q) = q_nodes else: logger.debug("fuse_attention: failed to match q path") return k_nodes_no_past_hf = self.model.match_parent_path( matmul_qk, ["Transpose", "Reshape", "MatMul"], [1, 0, 0], ) k_nodes_with_past_hf = self.model.match_parent_path( matmul_qk, ["Transpose", "Concat", "Transpose", "Reshape", "MatMul"], [1, 0, 1, 0, 0], ) k_nodes_past_or_present_oai = self.model.match_parent_path( matmul_qk, ["Mul", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0], ) k_nodes_past_only_oai = self.model.match_parent_path( matmul_qk, ["Mul", "Transpose", "Reshape", "Reshape", "Transpose"], [1, 0, 0, 0, 0], ) past_k, present_k = "", "" k_nodes, add_k, matmul_k = [], None, None if k_nodes_no_past_hf is not None: k_nodes = k_nodes_no_past_hf (transpose_k, reshape_k, matmul_k) = k_nodes # Find present_k output name transpose_k_nodes = input_name_to_nodes[reshape_k.output[0]] for transpose_k_node in transpose_k_nodes: if transpose_k_node.output[0] in graph_output_names: present_k = transpose_k_node.output[0] break elif k_nodes_with_past_hf is not None: k_nodes = k_nodes_with_past_hf (_, concat_k, transpose_k, reshape_k, matmul_k) = k_nodes past_k = concat_k.input[0] present_k = concat_k.output[0] elif output_name_to_node[matmul_qk.input[1]].input[0] in graph_input_names: # Hugging Face's cross-attention where past_k is used directly as key k_nodes = [output_name_to_node[matmul_qk.input[1]]] past_k = k_nodes[0].input[0] elif k_nodes_past_or_present_oai is not None: k_nodes = k_nodes_past_or_present_oai (_, transpose_k, reshape_k, matmul_k) = k_nodes # Find past_k input name start_child_nodes = input_name_to_nodes[matmul_k.output[0]] for start_child_node in start_child_nodes: if start_child_node.op_type == "Concat": concat_k_nodes = self.model.match_parent_path( start_child_node, ["Reshape", "Transpose"], [0, 0], ) if concat_k_nodes is not None: past_k = concat_k_nodes[-1].input[0] start_child_nodes = input_name_to_nodes[start_child_node.output[0]] break # Find present_k output name for start_child_node in start_child_nodes: start_grandchild_nodes = input_name_to_nodes[start_child_node.output[0]] for start_grandchild_node in start_grandchild_nodes: if start_grandchild_node.output[0] in graph_output_names: present_k = start_grandchild_node.output[0] break if present_k != "": break elif k_nodes_past_only_oai is not None: # OpenAI's cross-attention where past_k is used directly as key k_nodes = k_nodes_past_only_oai past_k = k_nodes[-1].input[0] else: logger.debug("fuse_attention: failed to match k path") return past_k = past_k if past_k in graph_input_names else "" present_k = present_k if present_k in graph_output_names else "" if matmul_k is not None and add_k is None: # Create empty Add node for attention graph add_v_tensor = self.model.get_initializer(add_v.input[0]) bias_dim = add_v_tensor.dims[0] dtype = add_v_tensor.data_type empty_bias_name = "empty_bias" empty_tensor = self.model.get_initializer(empty_bias_name) if empty_tensor is None: self.add_initializer( empty_bias_name, dtype, dims=[bias_dim], vals=np.array([0.0] * bias_dim, dtype=helper.tensor_dtype_to_np_dtype(dtype)), ) add_name = self.model.create_node_name("Add") add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k.name], add_name) three_root_inputs = bool(past_k) and bool(past_v) and matmul_k is None and matmul_v is None one_root_input = ( not three_root_inputs and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input and matmul_v.input[0] == root_input ) two_root_inputs = ( not three_root_inputs and matmul_q.input[0] == root_input and matmul_k.input[0] == matmul_v.input[0] and matmul_k.input[0] != matmul_q.input[0] ) # There are 5 types of attention: # 1) Encoder attention with one_root_input=True and qk_nodes=qk_nodes_no_mask # 2) Decoder self attention with one_root_input=True and qk_nodes=qk_nodes_with_mask # 3) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_no_mask # 4) Decoder self attention with past with one_root_input=True and qk_nodes=qk_nodes_with_mask and past_k=past_decoder_key and past_v=past_decoder_value # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_no_mask encoder_attention = one_root_input and qk_nodes == qk_nodes_no_mask decoder_self_attention = one_root_input and qk_nodes == qk_nodes_with_mask decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_no_mask decoder_self_attention_with_past = decoder_self_attention and bool(past_k) and bool(past_v) decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_no_mask # For decoder self-attentions, the attention mask needs to be included in the attention node causal_mask = qk_nodes == qk_nodes_with_mask mask_nodes = [] if causal_mask: mask_nodes_bart = self.model.match_parent_path( add_qk, ["Where"], [1], ) mask_nodes_whisper_hf = self.model.match_parent_path( add_qk, ["Slice", "Expand", "Where"], [1, 0, 1], ) mask_nodes_whisper_oai = self.model.match_parent_path( add_qk, ["Slice", "Unsqueeze", "Gather", "Shape", "Add"], [1, 2, 0, 0, 0], ) mask_nodes_whisper_oai_unit_test = self.model.match_parent_path( add_qk, ["Slice", "Slice"], [1, 0], ) if mask_nodes_whisper_hf is not None: mask_nodes = mask_nodes_whisper_hf elif mask_nodes_whisper_oai is not None: mask_nodes = mask_nodes_whisper_oai elif mask_nodes_whisper_oai_unit_test is not None: mask_nodes = mask_nodes_whisper_oai_unit_test elif mask_nodes_bart is not None: mask_nodes = mask_nodes_bart else: logger.debug("fuse_attention: failed to match mask nodes") return assert len(mask_nodes) > 0 if ( encoder_attention or decoder_self_attention or decoder_cross_attention or decoder_self_attention_with_past or decoder_cross_attention_with_past ): attention_last_node = reshape_qkv num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q) if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0: logger.debug("fuse_attention: failed to detect num_heads or hidden_size") return new_node = None if decoder_self_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past: # Note: Decoder attention with past key and past value is fused as multi-head attention # rather than attention because multi-head attention supports separate past key and past # value whereas attention supports concatenated past key and past value. new_node = ( self.create_multihead_attention_node( q_matmul=matmul_q, k_matmul=matmul_k if decoder_cross_attention or decoder_self_attention_with_past else past_k, v_matmul=matmul_v if decoder_cross_attention or decoder_self_attention_with_past else past_v, q_add=add_q, k_add=add_k if decoder_cross_attention or decoder_self_attention_with_past else None, v_add=add_v if decoder_cross_attention or decoder_self_attention_with_past else None, num_heads=num_heads, hidden_size=hidden_size, output=attention_last_node.output[0], unidirectional=causal_mask, past_k=past_k if decoder_self_attention_with_past else "", past_v=past_v if decoder_self_attention_with_past else "", present_k=present_k, present_v=present_v, ) if self.use_multi_head_attention else None ) else: # Temporarily set multi-head attention flag to false use_multi_head_attention_ground_truth = self.use_multi_head_attention self.use_multi_head_attention = False new_node = self.create_attention_node( mask_index=None, q_matmul=matmul_q, k_matmul=matmul_k, v_matmul=matmul_v, q_add=add_q, k_add=add_k, v_add=add_v, num_heads=num_heads, hidden_size=hidden_size, first_input=root_input, output=attention_last_node.output[0], causal=causal_mask, past_k=past_k, past_v=past_v, present_k=present_k, present_v=present_v, ) self.use_multi_head_attention = use_multi_head_attention_ground_truth if new_node is None: logger.debug("fuse_attention: failed to create fused node") return self.nodes_to_add.append(new_node) self.node_name_to_graph_name[new_node.name] = self.this_graph_name self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv]) self.nodes_to_remove.extend(qk_nodes) # When using multi-head attention, keep MatMul nodes in original graph if decoder_self_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past: if len(q_nodes) > 0 and q_nodes[-1].op_type == "MatMul": q_nodes.pop() if len(k_nodes) > 0 and k_nodes[-1].op_type == "MatMul": k_nodes.pop() if len(v_nodes) > 0 and v_nodes[-1].op_type == "MatMul": v_nodes.pop() if self.disable_multi_head_attention_bias: if len(q_nodes) > 0 and q_nodes[-1].op_type == "Add": q_nodes.pop() if len(k_nodes) > 0 and k_nodes[-1].op_type == "Add": k_nodes.pop() if len(v_nodes) > 0 and v_nodes[-1].op_type == "Add": v_nodes.pop() self.nodes_to_remove.extend(q_nodes) self.nodes_to_remove.extend(k_nodes) self.nodes_to_remove.extend(v_nodes) # Use prune graph to remove mask nodes since they are shared by all attention nodes. self.prune_graph = True