| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License.
- # --------------------------------------------------------------------------
- from logging import getLogger
- from fusion_base import Fusion
- from fusion_utils import FusionUtils
- from onnx import NodeProto, TensorProto, helper
- from onnx_model import OnnxModel
- logger = getLogger(__name__)
- class FusionEmbedLayerNoMask(Fusion):
- """
- Fuse embedding layer into one node (EmbedLayerNormalization).
- It supports the following model types: BERT, DistilBert, ALBert.
- """
- def __init__(self, model: OnnxModel, description: str = "no mask"):
- super().__init__(
- model,
- "EmbedLayerNormalization",
- ["LayerNormalization", "SkipLayerNormalization"],
- description,
- )
- self.utils = FusionUtils(model)
- self.shape_infer = None
- self.shape_infer_done = False
- # The following will be reset in each fuse call of FusionEmbedLayerNormalization
- self.attention = None
- self.embed_node = None
- def match_two_gather(self, add: NodeProto) -> None | tuple[NodeProto, NodeProto]:
- gather_0_path = self.model.match_parent_path(add, ["Gather"], [0])
- if gather_0_path is None:
- return None
- gather_1_path = self.model.match_parent_path(add, ["Gather"], [1])
- if gather_1_path is None:
- return None
- return gather_0_path[0], gather_1_path[0]
- def check_attention_subgraph(
- self,
- layernorm: NodeProto,
- input_name_to_nodes: dict[str, list[NodeProto]],
- is_distil_bert: bool,
- ) -> bool:
- """Check that LayerNormalization has a child of Attention node or subgraph like Attention.
- Args:
- layernorm (NodeProto): LayerNormalization node
- input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
- is_distil_bert (bool): whether it is DistilBert or not
- Returns:
- bool: whether there is Attention node or subgraph like Attention
- """
- self.attention = self.model.find_first_child_by_type(
- layernorm, "Attention", input_name_to_nodes, recursive=False
- )
- if self.attention is not None:
- return True
- if layernorm.output[0] not in input_name_to_nodes:
- return False
- children = input_name_to_nodes[layernorm.output[0]]
- children_types = sorted([child.op_type for child in children])
- # Try find MultiHeadAttention
- if children_types == ["MatMul", "MatMul", "MatMul", "SkipLayerNormalization"]:
- for node in children:
- if node.op_type == "SkipLayerNormalization":
- path1 = self.model.match_parent_path(
- node,
- ["Add", "MatMul", "MultiHeadAttention", "MatMul"],
- [None, None, 0, 0],
- )
- if path1 is not None and path1[-1].input[0] == layernorm.output[0]:
- self.cross_attention = path1[2]
- return True
- # In case user disables attention fusion, check whether subgraph looks like Attention.
- # For Albert, there is MatMul+Add after embedding layer before attention.
- if len(children) == 1 and children[0].op_type == "MatMul" and children[0].output[0] in input_name_to_nodes:
- grandchildren = input_name_to_nodes[children[0].output[0]]
- if (
- len(grandchildren) == 1
- and grandchildren[0].op_type == "Add"
- and grandchildren[0].output[0] in input_name_to_nodes
- ):
- nodes = input_name_to_nodes[grandchildren[0].output[0]]
- for node in nodes:
- if node.op_type == "Attention":
- self.attention = node
- return True
- children_types = sorted([child.op_type for child in nodes])
- # Two Shape nodes might be merged by ORT
- if is_distil_bert:
- # SkipLayerNormailization might exist when model has been optimized by ORT first.
- if (
- children_types != ["MatMul", "MatMul", "MatMul", "Shape", "SkipLayerNormalization"]
- and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape", "Shape"]
- and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape"]
- ):
- logger.debug("No Attention like subgraph in children of LayerNormalization")
- return False
- else:
- if children_types != [
- "Add",
- "MatMul",
- "MatMul",
- "MatMul",
- ] and children_types != [
- "MatMul",
- "MatMul",
- "MatMul",
- "SkipLayerNormalization",
- ]:
- logger.debug("No Attention like subgraph in children of LayerNormalization")
- return False
- return True
- def match_position_embedding_distilbert(self, position_embedding_gather, input_ids, output_name_to_node):
- """ Match position embedding path from input_ids to Gather for DistilBert.
- Pattern is like the following:
- (input_ids)
- |
- Shape
- | \
- | Gather (indices=1)
- | |
- | Cast (optional)
- | |
- | Range (start=0, end=*, delta=1)
- | |
- | Unsqueeze
- | /
- Expand
- |
- Gather
- """
- # remove after tests pass
- path1 = self.model.match_parent_path(position_embedding_gather, ["Expand", "Shape"], [1, 1])
- if path1 is None:
- path1 = self.model.match_parent_path(
- position_embedding_gather,
- ["Expand", "Where", "Reshape", "Shape"],
- [1, 1, 2, 0],
- )
- if path1 is None:
- return False
- expand, shape = path1[0], path1[-1]
- if shape.input[0] != input_ids:
- return False
- _, path2, _ = self.model.match_parent_paths(
- expand,
- [
- (["Unsqueeze", "Range", "Cast", "Gather", "Shape"], [0, 0, 1, 0, 0]),
- (["Unsqueeze", "Range", "Gather", "Shape"], [0, 0, 1, 0]),
- ],
- output_name_to_node,
- )
- if path2 is None:
- return False
- range_node = path2[1]
- if not (
- self.utils.check_node_input_value(range_node, 0, 0) and self.utils.check_node_input_value(range_node, 2, 1)
- ):
- return False
- gather_node = path2[-2]
- if not (self.utils.check_node_input_value(gather_node, 1, 1)):
- return False
- shape_node = path2[-1]
- if shape_node.input[0] != input_ids:
- return False
- return True
- def match_position_embedding_roberta(self, position_embedding_gather, input_ids, output_name_to_node):
- """Match position embedding path from input_ids to Gather for Roberta.
- Roberta Embedding Layer Pattern (* is optional since it might be removed by ORT, ? is the padding word id):
- (input_ids) --> Equal(B=?) -- Not -- Cast(to=6) -- CumSum(axis=1) -- Mul -- Cast(to=7) -- Add(B=1) -- Cast(to=7)* --> Gather
- | ^
- V |
- +------------------------------+
- Roberta new pattern from transformers v4.9:
- (input_ids) --> Equal(B=?) -- Not -- Cast(to=6) -- CumSum(axis=1) -- Add(B=0) -- Mul -- Cast(to=7) -- Add(B=1) --> Gather
- | ^
- V |
- +-------------------------------------------+
- start_node = position_embedding_gather
- start_index = 1
- # match optional Cast node.
- parent = self.model.get_parent(start_node, start_index, output_name_to_node)
- if parent is None:
- return
- if parent.op_type == "Cast":
- if OnnxModel.get_node_attribute(parent, "to") != 7:
- return
- start_node = parent
- start_index = 0
- i, path, return_indices = self.model.match_parent_paths(
- start_node,
- [ (['Add', 'Cast', 'Mul', 'CumSum', 'Cast', 'Not', 'Equal'], [start_index, 0, 0, 0, 0, 0, 0]),
- (['Add', 'Cast', 'Mul', 'Add', 'CumSum', 'Cast', 'Not', 'Equal'], [start_index, 0, 0, 0, 0, 0, 0, 0])],
- output_name_to_node)
- if path is not None:
- # constant input of Add shall be 1.
- i, value = self.model.get_constant_input(path[0])
- if value != 1:
- return False
- _, self.padding_word_id = self.model.get_constant_input(path[-1])
- return input_ids == path[-1].input[0]
- """
- return False
- def match_position_embedding_bert(self, position_embedding_gather, input_ids, output_name_to_node):
- """ Match position embedding path from input_ids to Gather for BERT.
- BERT Embedding Layer Pattern:
- (input_ids)
- / \
- / Shape
- / |
- / Gather (indices=1)
- / |
- / Add (optional, B=0)
- / |
- Gather (segment_ids) Unsqueeze (axes=0)
- \\ | |
- \\ Gather Slice (data[1,512], starts=0, ends=*, axes=1, steps=1)
- \\ / |
- Add Gather
- \\ /
- Add
- |
- LayerNormalization
- """
- path = self.model.match_parent_path(
- position_embedding_gather,
- ["Slice", "Unsqueeze"],
- [1, 2],
- output_name_to_node,
- )
- if path is None:
- return False
- slice, unsqueeze = path
- slice_weight = self.model.get_constant_value(slice.input[0])
- if not (
- slice_weight is not None
- and len(slice_weight.shape) == 2
- and slice_weight.shape[0] == 1
- and self.utils.check_node_input_value(slice, 1, [0])
- and self.utils.check_node_input_value(slice, 3, [1])
- and (len(slice.input) == 4 or self.utils.check_node_input_value(slice, 4, [1]))
- ):
- return False
- opset_version = self.model.get_opset_version()
- if opset_version < 13:
- if not FusionUtils.check_node_attribute(unsqueeze, "axes", [0]):
- return False
- else:
- if not self.utils.check_node_input_value(unsqueeze, 1, [0]):
- return False
- node = self.model.get_parent(unsqueeze, 0, output_name_to_node)
- if node is None:
- return False
- if node.op_type == "Add":
- if not self.utils.check_node_input_value(node, 1, 0):
- return False
- gather = self.model.get_parent(node, 0, output_name_to_node)
- else:
- gather = node
- if gather is None or gather.op_type != "Gather":
- return False
- if not (self.utils.check_node_input_value(gather, 1, 1)):
- return False
- shape = self.model.get_parent(gather, 0, output_name_to_node)
- if shape is None or shape.op_type != "Shape":
- return False
- return input_ids == shape.input[0]
- def match_position_embedding(self, position_embedding_gather, input_ids, output_name_to_node):
- if self.match_position_embedding_bert(position_embedding_gather, input_ids, output_name_to_node):
- return True
- # TODO: Support roberta (position starts from 2 instead of 0) in EmbedLayerNormalization kernel
- # related: https://github.com/huggingface/transformers/issues/10736
- # if self.match_position_embedding_roberta(position_embedding_gather, input_ids, output_name_to_node):
- # return True
- if self.match_position_embedding_distilbert(position_embedding_gather, input_ids, output_name_to_node):
- return True
- return False
- def check_embedding(self, word_embedding_gather, segment_embedding_gather, position_embedding_gather):
- """Sanity check of embedding weights, and match hidden_size of weights and shape of inputs."""
- input_ids = word_embedding_gather.input[1]
- segment_ids = segment_embedding_gather.input[1] if segment_embedding_gather else None
- position_ids = position_embedding_gather.input[1]
- if not self.shape_infer_done:
- self.shape_infer = self.model.infer_runtime_shape(update=True)
- self.shape_infer_done = True
- if self.shape_infer is not None:
- input_ids_shape = self.shape_infer.get_edge_shape(input_ids)
- position_ids_shape = self.shape_infer.get_edge_shape(position_ids)
- assert input_ids_shape and position_ids_shape
- if not (
- len(input_ids_shape) == 2
- and len(position_ids_shape) == 2
- and input_ids_shape[1] == position_ids_shape[1]
- ):
- logger.info(
- f"Cannot fuse EmbedLayerNormalization: input_ids and position_ids not matched in 2nd dimension: {input_ids_shape} vs {position_ids_shape}"
- )
- return False
- if segment_ids and not self.shape_infer.compare_shape(input_ids, segment_ids):
- logger.info(
- f"Cannot fuse EmbedLayerNormalization: input_ids and segment_ids does not have same shape: {input_ids_shape} != {self.shape_infer.get_edge_shape(segment_ids)}"
- )
- return False
- word_embedding_table = self.model.get_constant_value(word_embedding_gather.input[0])
- if word_embedding_table is None or len(word_embedding_table.shape) != 2:
- logger.info("Cannot fuse EmbedLayerNormalization: word embedding table is not expected")
- return False
- position_embedding_table = self.model.get_constant_value(position_embedding_gather.input[0])
- if (
- position_embedding_table is None
- or len(position_embedding_table.shape) != 2
- or (word_embedding_table.shape[1] != position_embedding_table.shape[1])
- ):
- logger.info("Cannot fuse EmbedLayerNormalization: position embedding table is not expected")
- return False
- if segment_ids:
- segment_embedding_table = self.model.get_constant_value(segment_embedding_gather.input[0])
- if (
- segment_embedding_table is None
- or len(segment_embedding_table.shape) != 2
- or (word_embedding_table.shape[1] != segment_embedding_table.shape[1])
- ):
- logger.info("Cannot fuse EmbedLayerNormalization: segment embedding table is not expected")
- return False
- # In normal case, word embedding table is the largest, and segment embedding table is the smallest, while position embedding table is in between.
- # TODO: use other information (like initializer names) to identify different embedding weights automatically.
- if word_embedding_table.shape[0] <= position_embedding_table.shape[0]:
- logger.warning(
- f"word_embedding_table ({word_embedding_gather.input[0]}) size {word_embedding_table.shape[0]} <= position_embedding_table ({position_embedding_gather.input[0]}) size {position_embedding_table.shape[0]}"
- )
- if segment_ids:
- if word_embedding_table.shape[0] <= segment_embedding_table.shape[0]:
- logger.warning(
- f"word_embedding_table ({word_embedding_gather.input[0]}) size {word_embedding_table.shape[0]} <= segment_embedding_table ({segment_embedding_gather.input[0]}) size {segment_embedding_table.shape[0]}"
- )
- if position_embedding_table.shape[0] <= segment_embedding_table.shape[0]:
- logger.warning(
- f"position_embedding_table ({position_embedding_gather.input[0]}) size {position_embedding_table.shape[0]} <= segment_embedding_table ({segment_embedding_gather.input[0]}) size {segment_embedding_table.shape[0]}"
- )
- return True
- def cast_to_int32(self, input_name: str) -> tuple[str, None | NodeProto]:
- """Cast a graph input or node input to int32.
- Args:
- input_name (str): name of graph input or node input
- Returns:
- A tuple of casted input name and the cast node.
- int32_output (str): If input is int32, it is the input name, Otherwise it is output name of Cast node.
- input_cast_node (Union[None, NodeProto]): Cast node. It could be None if input is int32.
- """
- input_cast_node = None
- graph_input = self.model.find_graph_input(input_name)
- if graph_input is not None:
- if graph_input.type.tensor_type.elem_type != TensorProto.INT32:
- int32_output, input_cast_node = self.utils.cast_input_to_int32(input_name)
- else:
- int32_output = input_name
- else:
- int32_output, input_cast_node = self.utils.cast_input_to_int32(input_name)
- return int32_output, input_cast_node
- def create_fused_node(
- self,
- input_ids: str,
- layernorm: NodeProto,
- word_embedding_gather: NodeProto,
- position_embedding_gather: NodeProto,
- segment_embedding_gather: None | NodeProto,
- position_ids: str | None = None,
- embedding_sum_output=False,
- embedding_sum_name=None,
- ):
- """Create an EmbedLayerNormalization node. Note that segment embedding is optional.
- Args:
- input_ids (str): input_ids for word embeddings
- layernorm (NodeProto): LayerNormalization or SkipLayerNormalization node.
- word_embedding_gather (NodeProto): the Gather node for word embedding
- position_embedding_gather (NodeProto): the Gather node for position embedding
- segment_embedding_gather (Union[None, NodeProto]): the Gather node for segment embedding, or None.
- Returns:
- NodeProto: the EmbedLayerNormalization node created.
- """
- nodes_to_add = []
- input_ids, _ = self.cast_to_int32(input_ids)
- node_name = self.model.create_node_name("EmbedLayerNormalization")
- if layernorm.op_type == "LayerNormalization":
- gamma = layernorm.input[1]
- beta = layernorm.input[2]
- else: # SkipLayerNormalization
- gamma = layernorm.input[2]
- beta = layernorm.input[3]
- embed_node_inputs = None
- if segment_embedding_gather is not None:
- segment_ids, _ = self.cast_to_int32(segment_embedding_gather.input[1])
- embed_node_inputs = [
- input_ids,
- segment_ids,
- word_embedding_gather.input[0],
- position_embedding_gather.input[0],
- segment_embedding_gather.input[0],
- gamma,
- beta,
- ]
- else: # no segment embedding
- embed_node_inputs = [
- input_ids,
- "",
- word_embedding_gather.input[0],
- position_embedding_gather.input[0],
- "",
- gamma,
- beta,
- ]
- if position_ids is not None:
- # Adding an empty input for mask before position_ids
- embed_node_inputs.append("")
- position_ids, _ = self.cast_to_int32(position_ids)
- embed_node_inputs.append(position_ids)
- embed_node_outputs = [node_name + "_output", node_name + "_dummy_mask_index"]
- if embedding_sum_output:
- name = embedding_sum_name if embedding_sum_name is not None else node_name + "_embedding_sum"
- embed_node_outputs.append(name)
- embed_node = helper.make_node(
- "EmbedLayerNormalization",
- embed_node_inputs,
- outputs=embed_node_outputs,
- name=node_name,
- )
- embed_node.domain = "com.microsoft"
- # Pass attribute "epsilon" from normalize node to EmbedLayerNormalization.
- for att in layernorm.attribute:
- if att.name == "epsilon":
- embed_node.attribute.extend([att])
- # Set default value to 1e-12 if no attribute is found.
- # OnnxRuntime 1.2.0 or older has no epsilon attribute. The optimized model can only work for 1.3.0 or later.
- if len(embed_node.attribute) == 0:
- embed_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])
- # Make sure new EmbedLayerNormalization node is the last one in self.nodes_to_add.
- nodes_to_add.append(embed_node)
- for node in nodes_to_add:
- self.node_name_to_graph_name[node.name] = self.this_graph_name
- self.nodes_to_add.extend(nodes_to_add)
- self.embed_node = embed_node
- return embed_node
- def finish_fusion(self, layernorm, embed_node):
- self.model.replace_input_of_all_nodes(layernorm.output[0], embed_node.output[0])
- # use prune graph to remove nodes that is not needed
- self.prune_graph = True
- def is_skip_layer_norm_with_sum_output(self, node):
- return (node.op_type == "SkipLayerNormalization") and len(node.output) > 3 and len(node.output[3]) > 0
- def fuse_gpt2(
- self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node, optional_segment_gather=None
- ):
- # graph checks
- # gpt2 has optional segment embedding, subgraph pattern is like
- # input_ids position_ids
- # | |
- # token_ids Gather Gather
- # | \ /
- # Gather (optional) Add _ _ _ _ _
- # \ | |
- # LayerNormalization |
- # | |
- # Attention |
- # | |
- # Matmul |
- # | /
- # Add /
- # \ /
- # Add
- two_gather = self.match_two_gather(add_before_layernorm)
- if two_gather is None:
- return False
- word_embedding_gather, position_embedding_gather = two_gather
- input_ids = word_embedding_gather.input[1]
- position_ids = position_embedding_gather.input[1]
- if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=False):
- return False
- if not self.check_embedding(word_embedding_gather, None, position_embedding_gather):
- return False
- # If layernorm node is SkipLayerNormalization, we need look at its optional fourth output.
- # If the add_before_layernorm node is an Add node, then the add_output output is the first output of this node.
- # If the add_before_layernorm node is a SkipLayerNormalization node, then the add_output output
- # is the (optional) fourth index output of this node.
- # When add_before_layernorm is SkipLayerNormalization, add_before_layernorm and layernorm are same node.
- if layernorm.op_type == "SkipLayerNormalization":
- need_embedding_sum_output = self.is_skip_layer_norm_with_sum_output(layernorm)
- sum_output_index = 3
- node_with_sum_output = layernorm
- sum_output = layernorm.output[3] if need_embedding_sum_output else None
- is_sum_graph_output = (sum_output is not None) and (self.model.find_graph_output(sum_output) is not None)
- else: # layernorm.op_type == "LayerNormalization"
- node_with_sum_output = add_before_layernorm
- sum_output_index = 0 if add_before_layernorm.op_type == "Add" else 3
- sum_output = (
- add_before_layernorm.output[sum_output_index]
- if len(add_before_layernorm.output) > sum_output_index
- else None
- )
- is_sum_graph_output = (sum_output is not None) and (self.model.find_graph_output(sum_output) is not None)
- is_sum_used_by_multiple_nodes = (
- sum_output and (sum_output in input_name_to_nodes) and len(input_name_to_nodes[sum_output]) > 1
- )
- need_embedding_sum_output = (sum_output is not None) and (
- add_before_layernorm.op_type != "Add" or is_sum_graph_output or is_sum_used_by_multiple_nodes
- )
- # make the fused node
- embed_node = self.create_fused_node(
- input_ids,
- layernorm,
- word_embedding_gather,
- position_embedding_gather,
- optional_segment_gather,
- position_ids,
- embedding_sum_output=need_embedding_sum_output,
- embedding_sum_name=sum_output if is_sum_graph_output else None,
- )
- if need_embedding_sum_output:
- node_with_sum_output.output[sum_output_index] = "_no_use__to_be_removed_"
- if not is_sum_graph_output:
- self.model.replace_input_of_all_nodes(sum_output, embed_node.output[2])
- self.finish_fusion(layernorm, embed_node)
- return True
- def fuse_distilbert(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
- """Fuse embedding layer for DistilBert
- Args:
- layernorm (NodeProto): node of LayerNormalization or SkipLayerNormalization
- add_before_layernorm (NodeProto): the Add node before LayerNormalization, or the SkipLayerNormalization itself
- input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
- output_name_to_node (Dict[str, List[NodeProto]]): map from output name to nodes
- """
- # DistilBert has no segment embedding, subgraph pattern is like
- # input_ids
- # | \
- # | (position_embedding_subgraph)
- # | |
- # Gather Gather
- # \ /
- # Add
- # |
- # LayerNormalization
- two_gather = self.match_two_gather(add_before_layernorm)
- if two_gather is None:
- return False
- word_embedding_gather, position_embedding_gather = two_gather
- input_ids = word_embedding_gather.input[1]
- if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=True):
- return False
- if not self.match_position_embedding(position_embedding_gather, input_ids, output_name_to_node):
- return False
- if not self.check_embedding(word_embedding_gather, None, position_embedding_gather):
- return False
- embed_node = self.create_fused_node(
- input_ids, layernorm, word_embedding_gather, position_embedding_gather, None
- )
- self.finish_fusion(layernorm, embed_node)
- return True
- def fuse_bert(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
- """Fuse embedding layer for Bert
- Args:
- layernorm (NodeProto): node of LayerNormalization or SkipLayerNormalization
- add_before_layernorm (NodeProto): the Add node before LayerNormalization, or the SkipLayerNormalization itself
- input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
- output_name_to_node (Dict[str, List[NodeProto]]): map from output name to nodes
- """
- add_2_gather = self.model.match_parent_path(add_before_layernorm, ["Add"], [0])
- if add_2_gather is None:
- return False
- two_gather = self.match_two_gather(add_2_gather[0])
- if two_gather is None:
- return False
- word_embedding_gather, segment_embedding_gather = two_gather
- input_ids = word_embedding_gather.input[1]
- if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=False):
- return False
- position_embedding_path = self.model.match_parent_path(add_before_layernorm, ["Gather"], [1])
- if position_embedding_path is None:
- return False
- position_embedding_gather = position_embedding_path[0]
- if not self.match_position_embedding(position_embedding_gather, input_ids, output_name_to_node):
- if not self.match_position_embedding(segment_embedding_gather, input_ids, output_name_to_node):
- return False
- # position and segment are switched
- temp = segment_embedding_gather
- segment_embedding_gather = position_embedding_gather
- position_embedding_gather = temp
- if not self.check_embedding(word_embedding_gather, segment_embedding_gather, position_embedding_gather):
- return False
- embed_node = self.create_fused_node(
- input_ids,
- layernorm,
- word_embedding_gather,
- position_embedding_gather,
- segment_embedding_gather,
- )
- self.finish_fusion(layernorm, embed_node)
- return True
- def fuse(self, node, input_name_to_nodes, output_name_to_node):
- first_add_path = self.model.match_parent_path(node, ["Add"], [0])
- if node.op_type == "LayerNormalization":
- if first_add_path is None:
- return
- add_before_layernorm = first_add_path[0]
- optional_segment_gather = None
- else: # SkipLayerNormalization
- gather_0_path = self.model.match_parent_path(node, ["Gather"], [0])
- gather_1_path = self.model.match_parent_path(node, ["Gather"], [1])
- if gather_0_path is None and gather_1_path is not None:
- if first_add_path is None:
- return
- add_before_layernorm = first_add_path[0]
- optional_segment_gather = gather_1_path[0]
- elif gather_0_path is not None and gather_1_path is None:
- first_add_path = self.model.match_parent_path(node, ["Add"], [1])
- if first_add_path is None:
- return
- add_before_layernorm = first_add_path[0]
- optional_segment_gather = gather_0_path[0]
- else:
- add_before_layernorm = node # Add is fused into SkipLayerNormalization
- optional_segment_gather = None
- if self.fuse_gpt2(
- node, add_before_layernorm, input_name_to_nodes, output_name_to_node, optional_segment_gather
- ):
- return
- if self.fuse_distilbert(node, add_before_layernorm, input_name_to_nodes, output_name_to_node):
- return
- if self.fuse_bert(node, add_before_layernorm, input_name_to_nodes, output_name_to_node):
- return
- class FusionEmbedLayerNormalization(FusionEmbedLayerNoMask):
- def __init__(self, model: OnnxModel, use_mask_index=False):
- super().__init__(model, "with mask")
- self.use_mask_index = use_mask_index
- def replace_mask(self, mask_int32, attention_nodes):
- # Inputs of EmbedLayerNorm: input_ids, segment_ids (optional), word_embedding, position_embedding,
- # segment_embedding (optional), gamma, beta, mask (optional), position_ids (optional)
- embed_node = self.embed_node
- if len(embed_node.input) == 7:
- embed_node.input.append(mask_int32)
- logger.debug("append mask to %s", embed_node.name)
- elif len(embed_node.input) > 7 and not embed_node.input[7]:
- embed_node.input[7] = mask_int32
- logger.debug("replace mask in %s", embed_node.name)
- else:
- logger.debug("skip mask in %s", embed_node.name)
- return
- for attention_node in attention_nodes:
- logger.debug("update mask_index in %s", attention_node.name)
- if attention_node.op_type == "Attention":
- attention_node.input[3] = embed_node.output[1]
- elif attention_node.op_type == "MultiHeadAttention":
- attention_node.input[4] = embed_node.output[1]
- def fuse(self, node, input_name_to_nodes, output_name_to_node):
- # Reset attention and embed_node so that we know fusion is successful when they are not None.
- self.attention = None
- self.cross_attention = None
- self.embed_node = None
- super().fuse(node, input_name_to_nodes, output_name_to_node)
- if self.embed_node is None:
- return
- if not self.use_mask_index:
- logger.debug("--use_mask_index is not set: EmbedLayerNormalization will not have mask")
- self.increase_counter("EmbedLayerNormalization(no mask)")
- return
- if self.attention is None and self.cross_attention is None:
- logger.debug("EmbedLayerNormalization will not have mask since attention node is not found")
- self.increase_counter("EmbedLayerNormalization(no mask)")
- return
- if self.attention:
- mask_int32 = self.attention.input[3]
- else:
- mask_int32 = self.cross_attention.input[4]
- children_nodes = input_name_to_nodes[mask_int32]
- if self.model.find_graph_input(mask_int32):
- attention_nodes = [node for node in children_nodes if node.op_type in ["Attention", "MultiHeadAttention"]]
- self.replace_mask(mask_int32, attention_nodes)
- self.increase_counter("EmbedLayerNormalization(with mask)")
- return
- if mask_int32 not in output_name_to_node:
- logger.debug("EmbedLayerNormalization will not have mask since %s is not a node output", mask_int32)
- self.increase_counter("EmbedLayerNormalization(no mask)")
- return
- node = output_name_to_node[mask_int32]
- if node.op_type in ["ReduceSum", "Cast"]:
- attention_nodes = [node for node in children_nodes if node.op_type in ["Attention", "MultiHeadAttention"]]
- if node.op_type == "ReduceSum":
- mask_int32 = node.input[0]
- if len(children_nodes) == len(attention_nodes):
- self.nodes_to_remove.append(node)
- self.replace_mask(mask_int32, attention_nodes)
- self.increase_counter("EmbedLayerNormalization(with mask)")
|