| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License.
- # --------------------------------------------------------------------------
- import logging
- from fusion_attention import FusionAttention
- from fusion_base import Fusion
- from onnx import FunctionProto, NodeProto, TensorProto, helper, numpy_helper
- from onnx_model import OnnxModel
- logger = logging.getLogger(__name__)
- class FusionRotaryAttention(FusionAttention):
- """
- Fuse Attention subgraph with rotary positional embeddings into one MultiHeadAttention node.
- """
def __init__(
    self,
    model: OnnxModel,
    hidden_size: int,
    num_heads: int,
):
    """Set up the rotary-attention fusion so that it emits MultiHeadAttention nodes.

    Args:
        model: The ONNX model wrapper being optimized.
        hidden_size: Hidden size of the attention block. A value <= 0 appears to
            mean "unknown" (``create_mha_node`` only validates it when > 0).
        num_heads: Number of attention heads.
    """
    # Node types handed to the base class as ``search_op_types``. Presumably these
    # are the anchor nodes that ``fuse`` is invoked on; confirm against
    # FusionAttention/Fusion.
    anchor_op_types = [
        "SimplifiedLayerNormalization",
        "SkipSimplifiedLayerNormalization",
        "LayerNormalization",
        "SkipLayerNormalization",
        "Add",
    ]
    super().__init__(
        model,
        hidden_size,
        num_heads,
        use_multi_head_attention=True,
        search_op_types=anchor_op_types,
    )
def create_mha_node(
    self,
    input: str,
    output: str,
    q_rotary: NodeProto,
    k_rotary: NodeProto,
    v_matmul: NodeProto,
    attn_mask: str = "",
    add_qk: str = "",
    past_k: str = "",
    past_v: str = "",
    present_k: str = "",
    present_v: str = "",
    scale: float | None = None,
) -> NodeProto | None:
    """Build a ``com.microsoft`` MultiHeadAttention node wired to rotary q/k and the v projection.

    Args:
        input: Accepted for caller compatibility but not read by this method.
            NOTE(review): it also shadows the ``input`` builtin.
        output: Name of the node's first output.
        q_rotary: Node whose first output is used as the query input.
        k_rotary: Node whose first output is used as the key input.
        v_matmul: Node whose first output is used as the value input.
        attn_mask: Tensor name placed in the key_padding_mask slot ("" = absent).
        add_qk: Tensor name placed in the attention_bias slot ("" = absent).
        past_k: Past key tensor name ("" = absent).
        past_v: Past value tensor name ("" = absent).
        present_k: Present key output name. It is only added when present_v is also given.
        present_v: Present value output name. It is only added when present_k is also given.
        scale: Optional value for the ``scale`` attribute.

    Returns:
        The new node, or None when a positive hidden_size is not divisible by num_heads.
    """
    assert self.num_heads > 0

    # A non-positive hidden_size skips the divisibility check entirely.
    heads_divide_hidden = self.hidden_size <= 0 or self.hidden_size % self.num_heads == 0
    if not heads_divide_hidden:
        logger.debug(
            f"fuse_rotary_attention: input hidden size {self.hidden_size} is not a multiple of num of heads {self.num_heads}"
        )
        return None

    node_name = self.model.create_node_name("MultiHeadAttention")

    # Positional input slots of MultiHeadAttention; empty strings mark omitted optionals.
    node_inputs = [
        q_rotary.output[0],  # query
        k_rotary.output[0],  # key
        v_matmul.output[0],  # value
        "",  # bias
        attn_mask,  # key_padding_mask
        add_qk,  # attention_bias
        past_k,
        past_v,
    ]
    # KV-cache outputs are emitted only as a complete pair.
    node_outputs = [output, present_k, present_v] if present_k and present_v else [output]

    mha = helper.make_node(
        "MultiHeadAttention",
        inputs=node_inputs,
        outputs=node_outputs,
        name=node_name,
        domain="com.microsoft",
    )

    # Attribute order: num_heads, then optional scale, then optional mask_filter_value.
    node_attributes = [helper.make_attribute("num_heads", self.num_heads)]
    if scale is not None:
        node_attributes.append(helper.make_attribute("scale", scale))
    if self.mask_filter_value is not None:
        node_attributes.append(helper.make_attribute("mask_filter_value", float(self.mask_filter_value)))
    mha.attribute.extend(node_attributes)

    self.increase_counter("MultiHeadAttention")
    return mha
def check_runtime_shape_paths_for_function(
    self,
    reshape_qkv_2,  # Reshape after Transpose
    reshape_qkv_1,  # Reshape before Transpose
    reshape_q_2,  # Reshape after RotaryEmbedding
    reshape_k_2,  # Reshape after RotaryEmbedding
    reshape_v_2,  # Reshape after Transpose
    reshape_v_1,  # Reshape before Transpose
    add_qk,  # Add before Softmax
    root_input,  # Root input to attention subgraph
):
    """Validate the runtime shape subgraphs of the function-style rotary attention pattern.

    Each Reshape's shape input (input 1) must come from a Concat whose entries trace
    back through Unsqueeze (optionally via Mul or Add) and Gather to a Shape node.
    The two Gather nodes found under reshape_qkv_2's Concat (entries 0 and 1) act
    as references: both must read Shape(root_input), and every other path must reuse
    the same Gather nodes (compared by name). The q/k/v Mul nodes must consume the
    first reference Gather's output.

    The attention-mask chain feeding ``add_qk`` is also checked:
    [Cast ->] Concat -> Slice -> Slice over an input named "attn_mask" or
    "attention_mask". Both Slices' end bounds must share one Add -> Gather pair, and
    that Add's first input must equal the first Slice's start-bound source.

    Returns:
        bool: True when every expected path exists and is consistent, otherwise False.
    """
    match = self.model.match_parent_path
    to_shape = ["Unsqueeze", "Gather", "Shape"]
    via_mul = ["Unsqueeze", "Mul", "Gather", "Shape"]
    via_add = ["Unsqueeze", "Add", "Gather", "Shape"]

    # (1) qkv Reshapes: establish the two reference Gather nodes over Shape(root_input).
    qkv_out_concat = match(reshape_qkv_2, ["Concat"], [1])
    qkv_in_concat = match(reshape_qkv_1, ["Concat"], [1])
    if qkv_out_concat is None or qkv_in_concat is None:
        return False
    qkv_out_concat, qkv_in_concat = qkv_out_concat[0], qkv_in_concat[0]

    qkv_out_a = match(qkv_out_concat, to_shape, [0, 0, 0])
    qkv_out_b = match(qkv_out_concat, to_shape, [1, 0, 0])
    qkv_in_a = match(qkv_in_concat, to_shape, [0, 0, 0])
    qkv_in_b = match(qkv_in_concat, to_shape, [2, 0, 0])
    if any(path is None for path in (qkv_out_a, qkv_out_b, qkv_in_a, qkv_in_b)):
        return False

    _, ref_gather_a, ref_shape_a = qkv_out_a
    _, ref_gather_b, ref_shape_b = qkv_out_b
    if ref_shape_a.input[0] != root_input or ref_shape_b.input[0] != root_input:
        return False
    if qkv_in_a[1].name != ref_gather_a.name or qkv_in_b[1].name != ref_gather_b.name:
        return False

    # (2) v Reshapes: the post-Transpose one goes through Mul/Add, the pre-Transpose one does not.
    v_out_concat = match(reshape_v_2, ["Concat"], [1])
    v_in_concat = match(reshape_v_1, ["Concat"], [1])
    if v_out_concat is None or v_in_concat is None:
        return False
    v_out_concat, v_in_concat = v_out_concat[0], v_in_concat[0]

    v_mul_path = match(v_out_concat, via_mul, [0, 0, 0, 0])
    v_add_path = match(v_out_concat, via_add, [1, 0, 0, 0])
    v_in_a = match(v_in_concat, to_shape, [0, 0, 0])
    v_in_b = match(v_in_concat, to_shape, [1, 0, 0])
    if any(path is None for path in (v_mul_path, v_add_path, v_in_a, v_in_b)):
        return False
    if (
        v_mul_path[2].name != ref_gather_a.name
        or v_add_path[2].name != ref_gather_b.name
        or v_in_a[1].name != ref_gather_a.name
        or v_in_b[1].name != ref_gather_b.name
    ):
        return False

    # (3) k Reshape: Mul path on Concat entry 0, Add path on Concat entry 2.
    k_concat = match(reshape_k_2, ["Concat"], [1])
    if k_concat is None:
        return False
    k_mul_path = match(k_concat[0], via_mul, [0, 0, 0, 0])
    k_add_path = match(k_concat[0], via_add, [2, 0, 0, 0])
    if k_mul_path is None or k_add_path is None:
        return False
    if k_mul_path[2].name != ref_gather_a.name or k_add_path[2].name != ref_gather_b.name:
        return False

    # (4) q Reshape: Mul path on Concat entry 0, plain Gather path on entry 1.
    q_concat = match(reshape_q_2, ["Concat"], [1])
    if q_concat is None:
        return False
    q_mul_path = match(q_concat[0], via_mul, [0, 0, 0, 0])
    q_plain_path = match(q_concat[0], to_shape, [1, 0, 0])
    if q_mul_path is None or q_plain_path is None:
        return False
    if q_mul_path[2].name != ref_gather_a.name or q_plain_path[1].name != ref_gather_b.name:
        return False

    # (5) The q, k and v Mul nodes must all consume the first reference Gather's output.
    ref_gather_a_output = ref_gather_a.output[0]
    if any(path[1].input[0] != ref_gather_a_output for path in (q_mul_path, k_mul_path, v_mul_path)):
        return False

    # (6) Attention-mask Slice chain, with or without a Cast in front of the Concat.
    mask_path = match(add_qk, ["Concat", "Slice", "Slice"], [1, 0, 0])
    cast_mask_path = match(add_qk, ["Cast", "Concat", "Slice", "Slice"], [1, 0, 0, 0])
    if mask_path is not None:
        _, outer_slice, inner_slice = mask_path
    elif cast_mask_path is not None:
        _, _, outer_slice, inner_slice = cast_mask_path
    else:
        return False

    # The innermost Slice must read the mask graph input directly (3D (B,S,T) mask per the original note).
    if inner_slice.input[0] not in {"attn_mask", "attention_mask"}:
        return False

    outer_end = match(outer_slice, via_add, [2, 0, 1, 0])
    inner_end = match(inner_slice, via_add, [2, 0, 1, 0])
    inner_start = match(inner_slice, ["Unsqueeze"], [1])
    if outer_end is None or inner_end is None or inner_start is None:
        return False

    # Both Slice end bounds must be computed by the same Add and Gather nodes.
    if outer_end[1].name != inner_end[1].name or outer_end[2].name != inner_end[2].name:
        return False

    # The shared Add's first input must match the inner Slice's start-bound source
    # (the original comment describes this as position ids).
    return inner_end[1].input[0] == inner_start[0].input[0]
def check_runtime_shape_paths_for_nodes(
    self,
    reshape_qkv,  # Final reshape before o_proj MatMul
    reshape_q,  # Reshape before q_proj MatMul
    reshape_k,  # Reshape before k_proj MatMul
    reshape_v,  # Reshape before v_proj MatMul
    root_input,  # Root input to attention subgraph
):
    """Validate that the q/k/v/qkv Reshape target shapes share one runtime shape source.

    The shape input (input 1) of ``reshape_qkv`` must be a Concat whose entries 0
    and 1 come from Unsqueeze(Gather(Shape(root_input))). The Reshapes for v, k and
    q (checked in that order) must each have a Concat whose entries 0 and 1 reuse
    those same two Gather nodes (compared by name).

    Returns:
        bool: True if all paths exist and share the reference Gather nodes, else False.
    """
    match = self.model.match_parent_path
    to_shape = ["Unsqueeze", "Gather", "Shape"]

    # Reference paths come from the final qkv Reshape.
    qkv_concat = match(reshape_qkv, ["Concat"], [1])
    if qkv_concat is None:
        return False
    qkv_concat = qkv_concat[0]

    first_entry = match(qkv_concat, to_shape, [0, 0, 0])
    second_entry = match(qkv_concat, to_shape, [1, 0, 0])
    if first_entry is None or second_entry is None:
        return False
    _, ref_gather_1, ref_shape_1 = first_entry
    _, ref_gather_2, ref_shape_2 = second_entry

    # Both reference Shape nodes must read the attention subgraph's root input.
    if any(shape.input[0] != root_input for shape in (ref_shape_1, ref_shape_2)):
        return False

    # v, k and q must rebuild their shapes from the very same Gather nodes.
    for projection_reshape in (reshape_v, reshape_k, reshape_q):
        concat_match = match(projection_reshape, ["Concat"], [1])
        if concat_match is None:
            return False
        concat = concat_match[0]
        entry_0 = match(concat, to_shape, [0, 0, 0])
        entry_1 = match(concat, to_shape, [1, 0, 0])
        if entry_0 is None or entry_1 is None:
            return False
        if entry_0[1].name != ref_gather_1.name or entry_1[1].name != ref_gather_2.name:
            return False

    return True
- def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
- if normalize_node.op_type not in {"SkipSimplifiedLayerNormalization", "SkipLayerNormalization", "Add"}:
- return
- # qkv_nodes_1 is for LLaMA-2 Microsoft
- # qkv_nodes_2 is for LLaMA-2 Hugging Face
- # qkv_nodes_3 is for LLaMA-2 distribute Hugging Face model
- qkv_nodes = None
- qkv_nodes_1 = self.model.match_parent_path(
- normalize_node,
- ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
- [1, 0, 0, 0, 0],
- )
- qkv_nodes_2 = self.model.match_parent_path(
- normalize_node,
- ["MatMul", "Reshape", "Transpose", "MatMul"],
- [1, 0, 0, 0],
- )
- qkv_nodes_3 = self.model.match_parent_path(
- normalize_node,
- ["AllReduce", "MatMul", "Reshape", "Transpose", "MatMul"],
- [1, 0, 0, 0, 0],
- )
- if qkv_nodes_1 is not None:
- _, reshape_qkv_2, _, reshape_qkv_1, matmul_qkv = qkv_nodes_1
- qkv_nodes = qkv_nodes_1
- elif qkv_nodes_2 is not None:
- _, reshape_qkv, _, matmul_qkv = qkv_nodes_2
- qkv_nodes = qkv_nodes_2
- elif qkv_nodes_3 is not None:
- _, _, reshape_qkv, _, matmul_qkv = qkv_nodes_3
- qkv_nodes = qkv_nodes_3
- else:
- logger.debug("fuse_rotary_attention: failed to match qkv nodes")
- return
- # v_nodes_1 is for LLaMA-2 Microsoft
- # v_nodes_3 is for LLaMA-2 Hugging Face
- # v_nodes_4 is for LLaMA-2 70B model
- # v_nodes_5 is for Phi-2 DirectML
- past_v, present_v, past_seq_len = "", "", ""
- v_nodes = None
- add_v = None
- v_nodes_1 = self.model.match_parent_path(
- matmul_qkv,
- ["Reshape", "Transpose", "Concat", "Transpose", "Reshape", "MatMul"],
- [1, 0, 0, 1, 0, 0],
- )
- v_nodes_2 = self.model.match_parent_path(
- matmul_qkv,
- ["Concat", "Transpose", "Reshape", "MatMul"],
- [1, 1, 0, 0],
- )
- v_nodes_3 = self.model.match_parent_path(
- matmul_qkv,
- ["Transpose", "Reshape", "MatMul"],
- [1, 0, 0],
- )
- _, v_nodes_4, _ = self.model.match_parent_paths_all(
- matmul_qkv,
- [
- (
- ["Reshape", "Expand", "Unsqueeze", "Concat", "Transpose", "Reshape", "MatMul"],
- [1, 0, 0, 0, 1, 0, 0],
- ),
- (
- [
- "Reshape",
- "Expand",
- "Where",
- "Equal",
- "Reshape",
- "Concat",
- "Unsqueeze",
- "Gather",
- "Shape",
- "Concat",
- "Transpose",
- "Reshape",
- "MatMul",
- ],
- [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
- ),
- (
- [
- "Reshape",
- "Expand",
- "Where",
- "Equal",
- "Mul",
- "ConstantOfShape",
- "Shape",
- "Reshape",
- "Concat",
- "Unsqueeze",
- "Gather",
- "Shape",
- "Concat",
- "Transpose",
- "Reshape",
- "MatMul",
- ],
- [1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0],
- ),
- (
- [
- "Reshape",
- "Expand",
- "Where",
- "ConstantOfShape",
- "Shape",
- "Reshape",
- "Concat",
- "Unsqueeze",
- "Gather",
- "Shape",
- "Concat",
- "Transpose",
- "Reshape",
- "MatMul",
- ],
- [1, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0],
- ),
- (
- [
- "Reshape",
- "Expand",
- "Where",
- "Reshape",
- "Concat",
- "Unsqueeze",
- "Gather",
- "Shape",
- "Concat",
- "Transpose",
- "Reshape",
- "MatMul",
- ],
- [1, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0],
- ),
- (
- ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
- [1, 1, 0, 0, 0, 0, 1, 0, 0],
- ),
- (
- [
- "Reshape",
- "Concat",
- "Unsqueeze",
- "Mul",
- "Gather",
- "Shape",
- "Concat",
- "Transpose",
- "Reshape",
- "MatMul",
- ],
- [1, 1, 1, 0, 0, 0, 0, 1, 0, 0],
- ),
- (
- ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
- [1, 1, 2, 0, 0, 0, 1, 0, 0],
- ),
- (
- ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
- [1, 1, 3, 0, 0, 0, 1, 0, 0],
- ),
- ],
- output_name_to_node=None,
- )
- v_nodes_5 = self.model.match_parent_path(
- matmul_qkv,
- ["Concat", "Transpose", "Reshape", "Add", "MatMul"],
- [1, 1, 0, 0, 1],
- )
- if v_nodes_1 is not None:
- reshape_v_2, _, concat_v, _, reshape_v_1, matmul_v = v_nodes_1
- v_nodes = v_nodes_1
- concat_v_path = self.model.match_parent_path(
- concat_v,
- ["Slice", "Unsqueeze"],
- [0, 2],
- )
- if concat_v_path is None:
- logger.debug("fuse_rotary_attention: failed to match past/present concat in v path")
- return
- past_v = concat_v_path[0].input[0]
- past_seq_len = concat_v_path[-1].input[0]
- present_v = concat_v.output[0]
- elif v_nodes_2 is not None:
- concat_v, transpose_v, reshape_v, matmul_v = v_nodes_2
- v_nodes = v_nodes_2
- past_v = concat_v.input[0]
- present_v = concat_v.output[0]
- elif v_nodes_3 is not None:
- transpose_v, reshape_v, matmul_v = v_nodes_3
- v_nodes = v_nodes_3
- present_v = transpose_v.output[0]
- elif v_nodes_4 is not None and len(v_nodes_4) == 9:
- concat_v, transpose_v, reshape_v, matmul_v = v_nodes_4[0][-4:]
- v_nodes = v_nodes_4
- past_v = concat_v.input[0]
- present_v = concat_v.output[0]
- elif v_nodes_5 is not None:
- concat_v, transpose_v, reshape_v, add_v, matmul_v = v_nodes_5
- matmul_v = add_v
- v_nodes = v_nodes_5
- past_v = concat_v.input[0]
- present_v = concat_v.output[0]
- else:
- logger.debug("fuse_rotary_attention: failed to match v path")
- return
- qk_nodes = self.model.match_parent_path(
- matmul_qkv,
- ["Softmax", "Add", "Div", "MatMul"],
- [0, 0, 0, 0],
- )
- add_qk, matmul_qk = None, None
- if qk_nodes is not None:
- _, add_qk, _, matmul_qk = qk_nodes
- else:
- logger.debug("fuse_rotary_attention: failed to match qk nodes")
- return
- # attn_mask_nodes_1, attn_mask_nodes_2 are for LLaMA-2 Microsoft's 3D attention mask
- # attn_mask_nodes_3, attn_mask_nodes_4 are for LLaMA-2 Hugging Face's 2D attention mask
- # attn_mask_nodes_5, attn_mask_nodes_6 are for LLaMA-2 Microsoft's model for the DML EP
- # attn_mask_nodes_7 is for LLaMA-2 Hugging Face's changes to the attention mask
- attn_mask, add_qk_str = "", ""
- attn_mask_nodes_1 = self.model.match_parent_path(
- add_qk,
- ["Concat", "Slice", "Slice"],
- [1, 0, 0],
- )
- attn_mask_nodes_2 = self.model.match_parent_path(
- add_qk,
- ["Cast", "Concat", "Slice", "Slice"],
- [1, 0, 0, 0],
- )
- attn_mask_nodes_3 = self.model.match_parent_path(
- add_qk,
- ["Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
- [1, 0, 2, 1, 0, 0, 0],
- )
- attn_mask_nodes_4 = self.model.match_parent_path(
- add_qk,
- ["Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
- [1, 2, 1, 0, 0, 0],
- )
- attn_mask_nodes_5 = self.model.match_parent_path(
- add_qk,
- ["Expand", "Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
- [1, 0, 0, 2, 1, 0, 0, 0],
- )
- attn_mask_nodes_6 = self.model.match_parent_path(
- add_qk,
- ["Expand", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
- [1, 0, 2, 1, 0, 0, 0],
- )
- attn_mask_nodes_7 = self.model.match_parent_path(
- add_qk,
- ["Where", "Cast", "Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
- [1, 0, 0, 0, 0, 1, 0, 0, 0],
- )
- if attn_mask_nodes_1 is not None:
- _, slice_mask_1, slice_mask_2 = attn_mask_nodes_1
- attn_mask = slice_mask_1.output[0]
- elif attn_mask_nodes_2 is not None:
- _, _, slice_mask_1, slice_mask_2 = attn_mask_nodes_2
- attn_mask = slice_mask_1.output[0]
- elif attn_mask_nodes_3 is not None:
- # Reshape from (B,1,S,T) to (B,N,S,T)
- add_qk_str = self.reshape_add_qk(attn_mask_nodes_3[0].output[0])
- elif attn_mask_nodes_4 is not None:
- # Reshape from (B,1,S,T) to (B,N,S,T)
- add_qk_str = self.reshape_add_qk(attn_mask_nodes_4[0].output[0])
- elif attn_mask_nodes_5 is not None:
- # The mask has already been reshaped to (B,N,S,T)
- add_qk_str = attn_mask_nodes_5[0].output[0]
- elif attn_mask_nodes_6 is not None:
- # The mask has already been reshaped to (B,N,S,T)
- add_qk_str = attn_mask_nodes_6[0].output[0]
- elif attn_mask_nodes_7 is not None:
- # Reshape from (B,1,S,T) to (B,N,S,T)
- add_qk_str = self.reshape_add_qk(attn_mask_nodes_7[0].output[0])
- else:
- logger.debug("fuse_rotary_attention: failed to match attention mask nodes")
- return
- # k_nodes_1 is for LLaMA-2 Microsoft
- # k_nodes_2 is for LLaMA-2 Hugging Face
- # k_nodes_4 is for LLaMA-2 70B Hugging Face
- past_k, present_k = "", ""
- k_nodes = None
- slice_k = None
- concat_k_half = None
- k_nodes_1 = self.model.match_parent_path(
- matmul_qk,
- ["Reshape", "Transpose", "Concat", "Transpose", "RotaryEmbedding", "MatMul"],
- [1, 0, 0, 1, 0, 0],
- )
- k_nodes_2 = self.model.match_parent_path(
- matmul_qk,
- ["Transpose", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
- [1, 0, 0, 0, 0],
- )
- k_nodes_3 = self.model.match_parent_path(
- matmul_qk,
- ["Transpose", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
- [1, 0, 1, 0, 0, 0],
- )
- _, k_nodes_4, _ = self.model.match_parent_paths_all(
- matmul_qk,
- [
- (
- [
- "Transpose",
- "Reshape",
- "Expand",
- "Unsqueeze",
- "Concat",
- "RotaryEmbedding",
- "Transpose",
- "Reshape",
- "MatMul",
- ],
- [1, 0, 0, 0, 0, 1, 0, 0, 0],
- ),
- (
- [
- "Transpose",
- "Reshape",
- "Expand",
- "Where",
- "Equal",
- "Reshape",
- "Concat",
- "Unsqueeze",
- "Gather",
- "Shape",
- "Concat",
- "RotaryEmbedding",
- "Transpose",
- "Reshape",
- "MatMul",
- ],
- [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
- ),
- (
- [
- "Transpose",
- "Reshape",
- "Expand",
- "Where",
- "Equal",
- "Mul",
- "ConstantOfShape",
- "Shape",
- "Reshape",
- "Concat",
- "Unsqueeze",
- "Gather",
- "Shape",
- "Concat",
- "RotaryEmbedding",
- "Transpose",
- "Reshape",
- "MatMul",
- ],
- [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
- ),
- (
- [
- "Transpose",
- "Reshape",
- "Expand",
- "Where",
- "ConstantOfShape",
- "Shape",
- "Reshape",
- "Concat",
- "Unsqueeze",
- "Gather",
- "Shape",
- "Concat",
- "RotaryEmbedding",
- "Transpose",
- "Reshape",
- "MatMul",
- ],
- [1, 0, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0],
- ),
- (
- [
- "Transpose",
- "Reshape",
- "Expand",
- "Where",
- "Reshape",
- "Concat",
- "Unsqueeze",
- "Gather",
- "Shape",
- "Concat",
- "RotaryEmbedding",
- "Transpose",
- "Reshape",
- "MatMul",
- ],
- [1, 0, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0, 0],
- ),
- (
- [
- "Transpose",
- "Reshape",
- "Concat",
- "Unsqueeze",
- "Gather",
- "Shape",
- "Concat",
- "RotaryEmbedding",
- "Transpose",
- "Reshape",
- "MatMul",
- ],
- [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
- ),
- (
- [
- "Transpose",
- "Reshape",
- "Concat",
- "Unsqueeze",
- "Mul",
- "Gather",
- "Shape",
- "Concat",
- "RotaryEmbedding",
- "Transpose",
- "Reshape",
- "MatMul",
- ],
- [1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0],
- ),
- (
- [
- "Transpose",
- "Reshape",
- "Concat",
- "Unsqueeze",
- "Gather",
- "Shape",
- "Concat",
- "RotaryEmbedding",
- "Transpose",
- "Reshape",
- "MatMul",
- ],
- [1, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0],
- ),
- (
- [
- "Transpose",
- "Reshape",
- "Concat",
- "Unsqueeze",
- "Gather",
- "Shape",
- "Concat",
- "RotaryEmbedding",
- "Transpose",
- "Reshape",
- "MatMul",
- ],
- [1, 0, 1, 3, 0, 0, 0, 1, 0, 0, 0],
- ),
- ],
- output_name_to_node=None,
- )
- k_nodes_5 = self.model.match_parent_path(
- matmul_qk,
- ["Transpose", "Concat", "Concat", "RotaryEmbedding", "Slice", "Transpose", "Reshape", "Add", "MatMul"],
- [1, 0, 1, 0, 0, 0, 0, 0, 1],
- )
- if k_nodes_1 is not None:
- reshape_k_2, _, concat_k, _, rotary_k, matmul_k = k_nodes_1
- k_nodes = k_nodes_1
- concat_k_path = self.model.match_parent_path(
- concat_k,
- ["Slice", "Unsqueeze"],
- [0, 2],
- )
- if concat_k_path is None:
- logger.debug("fuse_rotary_attention: failed to match past/present concat in k path")
- return
- past_k = concat_k_path[0].input[0]
- shared_past_seq_len = concat_k_path[-1].input[0]
- present_k = concat_k.output[0]
- assert past_seq_len == shared_past_seq_len
- elif k_nodes_2 is not None:
- _, rotary_k, _, reshape_k, matmul_k = k_nodes_2
- k_nodes = k_nodes_2
- present_k = rotary_k.output[0]
- elif k_nodes_3 is not None:
- _, concat_k, rotary_k, _, reshape_k, matmul_k = k_nodes_3
- k_nodes = k_nodes_3
- past_k = concat_k.input[0]
- present_k = concat_k.output[0]
- elif k_nodes_4 is not None and len(k_nodes_4) == 9:
- reshape_k, matmul_k = k_nodes_4[0][-2:]
- concat_k, rotary_k = k_nodes_4[0][-5:-3]
- k_nodes = k_nodes_4
- past_k = concat_k.input[0]
- present_k = concat_k.output[0]
- elif k_nodes_5 is not None:
- _, concat_k, concat_k_half, rotary_k, slice_k, _, reshape_k, _, matmul_k = k_nodes_5
- k_nodes = k_nodes_5
- past_k = concat_k.input[0]
- present_k = concat_k.output[0]
- else:
- logger.debug("fuse_rotary_attention: failed to match k nodes")
- return
- # q_nodes_1 is for LLaMA-2 Microsoft
- # q_nodes_2 is for LLaMA-2 Hugging Face
- # q_nodes_3 is for Phi-2 DirectML
- q_nodes = None
- slice_q = None
- concat_q_half = None
- q_nodes_1 = self.model.match_parent_path(
- matmul_qk,
- ["Reshape", "Transpose", "RotaryEmbedding", "MatMul"],
- [0, 0, 0, 0],
- )
- q_nodes_2 = self.model.match_parent_path(
- matmul_qk,
- ["RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
- [0, 0, 0, 0],
- )
- q_nodes_3 = self.model.match_parent_path(
- matmul_qk,
- ["Concat", "RotaryEmbedding", "Slice", "Transpose", "Reshape", "Add", "MatMul"],
- [0, 0, 0, 0, 0, 0, 1],
- )
- if q_nodes_1 is not None:
- reshape_q_2, _, rotary_q, matmul_q = q_nodes_1
- q_nodes = q_nodes_1
- elif q_nodes_2 is not None:
- rotary_q, _, reshape_q, matmul_q = q_nodes_2
- q_nodes = q_nodes_2
- elif q_nodes_3 is not None:
- concat_q_half, rotary_q, slice_q, _, reshape_q, _, matmul_q = q_nodes_3
- q_nodes = q_nodes_3
- else:
- logger.debug("fuse_rotary_attention: failed to match q nodes")
- return
- if matmul_q.input[0] != matmul_k.input[0] and matmul_k.input[0] != matmul_v.input[0]:
- logger.debug("fuse_rotary_attention: failed to find the same root_input for q, k, v paths")
- return
- root_output = ""
- if qkv_nodes == qkv_nodes_1:
- if not self.check_runtime_shape_paths_for_function(
- reshape_qkv_2,
- reshape_qkv_1,
- reshape_q_2,
- reshape_k_2,
- reshape_v_2,
- reshape_v_1,
- add_qk,
- matmul_q.input[0],
- ):
- logger.debug("fuse_rotary_attention: failed to verify runtime shape paths")
- return
- root_output = reshape_qkv_2.output[0]
- elif qkv_nodes in (qkv_nodes_2, qkv_nodes_3):
- if not self.check_runtime_shape_paths_for_nodes(
- reshape_qkv,
- reshape_q,
- reshape_k,
- reshape_v,
- matmul_q.input[0],
- ):
- logger.debug("fuse_rotary_attention: failed to verify runtime shape paths")
- return
- root_output = reshape_qkv.output[0]
- # Rename inputs of rotary_q/k so it connects with output of matmul_q/k
- # Before: MatMul --> Reshape --> Transpose --> RotaryEmbedding
- # After: MatMul --> RotaryEmbedding
- rotary_q.input[0] = slice_q.output[0] if slice_q else matmul_q.output[0]
- rotary_k.input[0] = slice_k.output[0] if slice_k else matmul_k.output[0]
- # Rename current output of rotary_k (present_key) so it doesn't match output of MHA (present_key)
- if concat_q_half is None:
- rotary_k.output[0] = rotary_k.name + "_output_0"
- if qkv_nodes == qkv_nodes_3:
- qkv_nodes = qkv_nodes[1:]
- def create_hidden_size_concat_node(reshape_q):
- """Detect num_heads and hidden_size for ONNX model from phi-2
- Args:
- reshape_q (NodeProto): reshape node for q
- Returns:
- hidden_size_concat_node(NodeProto): Concat node to be used by reshape
- """
- concat = self.model.match_parent(reshape_q, "Concat", 1)
- if concat is None:
- logger.debug("fuse_rotary_attention: failed to trace the concat node from reshape_q")
- return None
- # The shape is a tensor like [?, ?, num_heads, head_size]
- num_head_constant_node = self.model.get_constant_value(concat.input[2])
- head_size_constant_node = self.model.get_constant_value(concat.input[3])
- if num_head_constant_node is None or head_size_constant_node is None:
- logger.debug("fuse_rotary_attention: failed to get constant nodes of num_heads or head_size")
- return None
- num_head_value = num_head_constant_node[0]
- head_size_value = head_size_constant_node[0]
- hidden_size = num_head_value * head_size_value
- hidden_size_initilizer = self.model.create_node_name("Initializer", name_prefix="hidden_size")
- if self.model.get_initializer(hidden_size_initilizer) is None:
- self.add_initializer(
- name=hidden_size_initilizer,
- data_type=TensorProto.INT64,
- dims=[1],
- vals=[hidden_size],
- raw=False,
- )
- hidden_size_reshape_node_name = self.model.create_node_name("Concat", name_prefix="hidden_size_concat")
- hidden_size_concat_node = helper.make_node(
- "Concat",
- inputs=[
- concat.input[0],
- concat.input[1],
- hidden_size_initilizer,
- ],
- outputs=[hidden_size_reshape_node_name + "output_0"],
- name=hidden_size_reshape_node_name,
- )
- hidden_size_concat_node.attribute.extend([helper.make_attribute("axis", 0)])
- return hidden_size_concat_node
- # Add Tranpose and Reshape nodes for patial rotary embedding applied in phi-2 before passing into MHA
- if concat_q_half and concat_k_half:
- # Transpose the key output of rotary Embedding
- k_transpose_node_name = self.model.create_node_name("Transpose")
- k_tranpose_output_name = k_transpose_node_name + "_output_0"
- k_transpose_node = helper.make_node(
- "Transpose",
- inputs=[concat_k_half.output[0]],
- outputs=[k_tranpose_output_name],
- name=k_transpose_node_name,
- )
- k_transpose_node.attribute.extend([helper.make_attribute("perm", [0, 2, 1, 3])])
- # Transpose the query output of rotary Embedding
- q_transpose_node_name = self.model.create_node_name("Transpose")
- q_tranpose_output_name = q_transpose_node_name + "_output_0"
- q_transpose_node = helper.make_node(
- "Transpose",
- inputs=[concat_q_half.output[0]],
- outputs=[q_tranpose_output_name],
- name=q_transpose_node_name,
- )
- q_transpose_node.attribute.extend([helper.make_attribute("perm", [0, 2, 1, 3])])
- hidden_size_concat_node = create_hidden_size_concat_node(reshape_k)
- if hidden_size_concat_node is None:
- logger.debug("fuse_rotary_attention: failed to create hidden_size_concat_node")
- return
- # Reshape the Rotary Embedding output for key for 4D to 3D
- concat_k_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="concat_k_half")
- concat_k_reshape_node = helper.make_node(
- "Reshape",
- inputs=[k_transpose_node.output[0], hidden_size_concat_node.output[0]],
- outputs=[concat_k_reshape_node_name + "_output_0"],
- name=concat_k_reshape_node_name,
- )
- # Reshape the Rotary Embedding output for query from 4D to 3D
- concat_q_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="concat_q_half")
- concat_q_reshape_node = helper.make_node(
- "Reshape",
- inputs=[q_transpose_node.output[0], hidden_size_concat_node.output[0]],
- outputs=[concat_q_reshape_node_name + "_output_0"],
- name=concat_q_reshape_node_name,
- )
- rotary_k = concat_k_reshape_node
- rotary_q = concat_q_reshape_node
- self.nodes_to_add.append(hidden_size_concat_node)
- self.nodes_to_add.append(k_transpose_node)
- self.nodes_to_add.append(q_transpose_node)
- self.nodes_to_add.append(concat_k_reshape_node)
- self.nodes_to_add.append(concat_q_reshape_node)
- self.node_name_to_graph_name[hidden_size_concat_node.name] = self.this_graph_name
- self.node_name_to_graph_name[k_transpose_node.name] = self.this_graph_name
- self.node_name_to_graph_name[q_transpose_node.name] = self.this_graph_name
- self.node_name_to_graph_name[concat_k_reshape_node.name] = self.this_graph_name
- self.node_name_to_graph_name[concat_q_reshape_node.name] = self.this_graph_name
- new_node = self.create_mha_node(
- matmul_q.input[0],
- root_output,
- rotary_q,
- rotary_k,
- matmul_v,
- attn_mask,
- add_qk_str,
- past_k,
- past_v,
- present_k,
- present_v,
- )
- if new_node is None:
- logger.debug("fuse_rotary_attention: failed to create multi-head attention with rotary embeddings")
- return
- self.nodes_to_add.append(new_node)
- self.node_name_to_graph_name[new_node.name] = self.this_graph_name
- self.nodes_to_remove.extend(qkv_nodes[1:])
- if v_nodes != v_nodes_4:
- self.nodes_to_remove.extend(v_nodes[:-1] if add_v is None else v_nodes[:-2])
- else:
- nodes_to_keep = [v_nodes[0][-1]]
- for temp_path in v_nodes:
- self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep)
- self.nodes_to_remove.extend(qk_nodes)
- if k_nodes == k_nodes_1:
- self.nodes_to_remove.extend(k_nodes[:-2])
- elif k_nodes == k_nodes_2:
- self.nodes_to_remove.append(k_nodes[0])
- self.nodes_to_remove.append(k_nodes[2])
- self.nodes_to_remove.append(k_nodes[3])
- elif k_nodes == k_nodes_3:
- self.nodes_to_remove.append(k_nodes[0])
- self.nodes_to_remove.append(k_nodes[1])
- self.nodes_to_remove.append(k_nodes[3])
- self.nodes_to_remove.append(k_nodes[4])
- elif k_nodes == k_nodes_5:
- self.nodes_to_remove.append(k_nodes[0])
- self.nodes_to_remove.append(k_nodes[1])
- elif k_nodes == k_nodes_4:
- nodes_to_keep = [k_nodes[0][-1], k_nodes[0][-4]]
- for temp_path in k_nodes:
- self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep)
- if q_nodes == q_nodes_1:
- self.nodes_to_remove.extend(q_nodes[:-2])
- elif q_nodes == q_nodes_2:
- self.nodes_to_remove.append(q_nodes[1])
- self.nodes_to_remove.append(q_nodes[2])
- self.prune_graph = True
class FusionRotaryEmbeddings(Fusion):
    """Fuse rotary-embedding subgraphs into a single com.microsoft RotaryEmbedding node.

    `fuse` handles two source forms:
      * a call to a "RotaryEmbedding" ONNX function (e.g. exported with
        `export_modules_as_functions`), fused with `interleaved=1`;
      * an expanded `x * cos + rotate_half(x) * sin` subgraph rooted at an Add,
        fused with `interleaved=0`.
    """

    def __init__(self, model: OnnxModel):
        self.base_name = "RotaryEmbedding"
        super().__init__(model, self.base_name, [self.base_name, self.base_name + ".1", "Add"])

    # The RotaryEmbedding function can have multiple extraneous constant outputs even though the function is supposed to produce only one output.
    # This is a byproduct of a potential CSE bug when using `export_modules_as_functions` in the TorchScript exporter.
    # To work around this issue, we set the extraneous constant values from the RotaryEmbedding function as initializers in the locations where they are actually used.
    def reassign_extra_outputs(self, rot_emb_node: NodeProto, function: FunctionProto):
        """Replace constant-valued outputs of a RotaryEmbedding function call with initializers.

        Args:
            rot_emb_node: the call site of `function` in the main graph.
            function: the FunctionProto whose body returns input-less Constant values as outputs.

        Returns:
            The names of `rot_emb_node` outputs whose consumers were rewired to new initializers.
        """
        function_outputs = list(function.output)

        # Find extra outputs and the Constant nodes attached to those outputs
        extra_constants, extra_outputs = [], []
        for fn_node in function.node:
            if fn_node.op_type == "Constant" and not fn_node.input and fn_node.output[0] in function_outputs:
                extra_constants.append(fn_node)
                output_index = function_outputs.index(fn_node.output[0])
                extra_outputs.append(rot_emb_node.output[output_index])

        # Set extra Constant node outputs as initializers (each gets a fresh unique name)
        extra_initializers = []
        for extra_constant in extra_constants:
            constant_tensorproto = extra_constant.attribute[0].t
            constant_tensorproto.name = self.model.create_node_name("Constant")
            self.model.add_initializer(constant_tensorproto)
            extra_initializers.append(constant_tensorproto.name)

        # Update references of Constant node outputs to initializer references
        for extra_output, extra_initializer in zip(extra_outputs, extra_initializers, strict=False):
            nodes_to_update = [entry for entry in self.model.model.graph.node if extra_output in entry.input]
            for node_to_update in nodes_to_update:
                OnnxModel.replace_node_input(node_to_update, extra_output, extra_initializer)

        return extra_outputs

    def _find_constant_node(self, output_name: str) -> NodeProto | None:
        """Return the single main-graph Constant node whose first output is `output_name`, else None."""
        matches = [
            graph_node
            for graph_node in self.model.model.graph.node
            if graph_node.op_type == "Constant" and graph_node.output and graph_node.output[0] == output_name
        ]
        return matches[0] if len(matches) == 1 else None

    def _create_cos_sin_cache_initializers(
        self,
        cos_cache_input: str,
        sin_cache_input: str,
        halve_last_dim: bool = False,
    ) -> tuple[str, str] | None:
        """Make the shared "cos_cache"/"sin_cache" initializers available for a fused node.

        If both inputs are produced by Constant nodes and the shared initializers do not exist
        yet, their values are converted to FLOAT initializers and the Constant nodes are queued
        for removal. Otherwise, previously created initializers (e.g. from an earlier layer) are
        reused as-is.

        Args:
            cos_cache_input: name of the tensor feeding the cos cache.
            sin_cache_input: name of the tensor feeding the sin cache.
            halve_last_dim: keep only the first half of the last dimension, (M, H) -> (M, H/2).

        Returns:
            ("cos_cache", "sin_cache") when both initializers exist afterwards, else None
            (in which case the graph has not been modified).
        """
        cos_cache_name, sin_cache_name = "cos_cache", "sin_cache"
        cos_cache_node = self._find_constant_node(cos_cache_input)
        sin_cache_node = self._find_constant_node(sin_cache_input)

        if (
            cos_cache_node is not None
            and sin_cache_node is not None
            and self.model.get_initializer(cos_cache_name) is None
            and self.model.get_initializer(sin_cache_name) is None
        ):
            cos_cache = numpy_helper.to_array(cos_cache_node.attribute[0].t).squeeze()
            sin_cache = numpy_helper.to_array(sin_cache_node.attribute[0].t).squeeze()
            if halve_last_dim:
                head_size = cos_cache.shape[1]
                cos_cache = cos_cache[:, : (head_size // 2)]
                sin_cache = sin_cache[:, : (head_size // 2)]

            # NOTE(review): caches are always stored as FLOAT regardless of the source dtype.
            for cache_name, cache in ((cos_cache_name, cos_cache), (sin_cache_name, sin_cache)):
                cache_tensor = helper.make_tensor(
                    name=cache_name,
                    data_type=TensorProto.FLOAT,
                    dims=list(cache.shape),
                    vals=cache.flatten().tolist(),
                )
                self.model.add_initializer(cache_tensor, self.this_graph_name)
            self.nodes_to_remove.extend([cos_cache_node, sin_cache_node])

        # Without the shared initializers the fused node would reference nonexistent tensors.
        if self.model.get_initializer(cos_cache_name) is None or self.model.get_initializer(sin_cache_name) is None:
            logger.debug("fuse_rotary_embeddings: failed to find or create cos_cache and sin_cache")
            return None
        return cos_cache_name, sin_cache_name

    def create_rotary_embeddings_from_function(self, node: NodeProto):
        """Build a RotaryEmbedding node (interleaved=1) replacing a RotaryEmbedding function call.

        Args:
            node: the function call; input 0 comes from Reshape <- MatMul, input 1 is the
                position ids, inputs 2 and 3 are the cos and sin caches.

        Returns:
            The new node (not yet added to the graph), or None if the pattern does not match.
            On success the Reshape feeding the call is queued for removal.
        """
        rotary_emb_node_name = self.model.create_node_name(self.base_name)

        matmul_path = self.model.match_parent_path(
            node,
            ["Reshape", "MatMul"],
            [0, 0],
        )
        if matmul_path is None:
            logger.debug("fuse_rotary_embeddings: failed to match MatMul")
            return None
        reshape_node, matmul_node = matmul_path

        rotary_emb_outputs = list(node.output)
        if len(rotary_emb_outputs) > 1:
            # Re-assign extraneous constant outputs in RotaryEmbedding functions as initializers
            func = [fn for fn in self.model.model.functions if fn.name == node.op_type]
            if len(func) != 1:
                logger.debug("fuse_rotary_embeddings: failed to find the RotaryEmbedding function definition")
                return None
            extra_outputs = self.reassign_extra_outputs(node, func[0])
            rotary_emb_outputs = [name for name in rotary_emb_outputs if name not in extra_outputs]
            if len(rotary_emb_outputs) != 1:
                logger.debug("fuse_rotary_embeddings: failed to reduce RotaryEmbedding function outputs to one")
                return None

        # Convert cos_cache and sin_cache from node attributes to model initializers
        cache_names = self._create_cos_sin_cache_initializers(node.input[2], node.input[3])
        if cache_names is None:
            return None

        rotary_emb_node = helper.make_node(
            self.base_name,
            inputs=[
                matmul_node.output[0],  # x is of shape (B,S,D) instead of (B,S,N,H)
                node.input[1],  # position_ids
                *cache_names,
            ],
            outputs=rotary_emb_outputs,
            name=rotary_emb_node_name,
            interleaved=1,
        )
        rotary_emb_node.domain = "com.microsoft"

        self.nodes_to_remove.append(reshape_node)

        return rotary_emb_node

    def create_rotary_embeddings_from_nodes(
        self,
        root_input: str,
        position_ids: str,
        cos_slice: str,
        sin_slice: str,
        output: str,
    ):
        """Build a RotaryEmbedding node (interleaved=0) for an expanded rotary subgraph.

        Args:
            root_input: tensor the rotation is applied to.
            position_ids: tensor used to index the caches.
            cos_slice: name of the tensor sliced to obtain the cos cache.
            sin_slice: name of the tensor sliced to obtain the sin cache.
            output: output name of the fused node (reuses the matched Add's output).

        Returns:
            The new node (not yet added to the graph), or None if the caches are unavailable.
        """
        rotary_emb_node_name = self.model.create_node_name(self.base_name)

        # Convert cos_cache and sin_cache from node attributes to model initializers,
        # keeping only the first half of the last dimension: (M, H) -> (M, H/2)
        cache_names = self._create_cos_sin_cache_initializers(cos_slice, sin_slice, halve_last_dim=True)
        if cache_names is None:
            return None
        cos_cache_name, sin_cache_name = cache_names

        rotary_emb_node = helper.make_node(
            self.base_name,
            inputs=[root_input, position_ids, cos_cache_name, sin_cache_name],
            outputs=[output],
            name=rotary_emb_node_name,
            interleaved=0,
        )
        rotary_emb_node.domain = "com.microsoft"
        return rotary_emb_node

    def _match_cache_path(
        self, node: NodeProto, mul_input_index: int
    ) -> tuple[int, list[NodeProto], str, str] | None:
        """Match the `cache[position_ids].unsqueeze(1)` subgraph feeding one Mul of apply_rope.

        Args:
            node: the Add summing (x * cos) and (rotate_half(x) * sin).
            mul_input_index: 0 for the cos branch, 1 for the sin branch.

        Returns:
            (pattern_id, path, cache_name, position_ids) for the first matching pattern, where
            position_ids is "" when the pattern does not expose it directly; or None.
        """
        # (pattern id, parent op types after the Mul, their input indices, position of the cache Slice)
        patterns = (
            (1, ["Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Gather", "Shape"], [1, 0, 0, 0, 0, 2, 0, 0], -4),
            (2, ["Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Add"], [1, 0, 0, 0, 0, 2, 0], -3),
            (3, ["Unsqueeze", "Gather", "Slice", "Unsqueeze", "Gather", "Shape"], [1, 0, 0, 2, 0, 0], -4),
            (4, ["Unsqueeze", "Gather", "Slice", "Unsqueeze", "Add"], [1, 0, 0, 2, 0], -3),
        )
        for pattern_id, op_types, input_indices, cache_slice_index in patterns:
            path = self.model.match_parent_path(node, ["Mul", *op_types], [mul_input_index, *input_indices])
            if path is None:
                continue
            cache = path[cache_slice_index].input[0]
            # In patterns 3 and 4 the Gather's indices input is used directly as position_ids;
            # for patterns 1 and 2 the caller traces it through a Reshape instead.
            position_ids = path[2].input[1] if pattern_id in (3, 4) else ""
            return pattern_id, path, cache, position_ids
        return None

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Try to fuse the rotary embedding rooted at `node` (a RotaryEmbedding function call or an Add)."""
        # Node is either RotaryEmbedding function or Add
        if self.base_name not in node.op_type and node.op_type != "Add":
            return

        # Check if node is "RotaryEmbedding nn.Module" exported as a function
        # (e.g. export_modules_as_functions={RotaryEmbedding} in torch.onnx.export)
        rotary_emb_node = None
        if node.op_type != "Add":
            # Verify that function has the correct inputs
            if len(node.input) not in {4, 5} or node.input[1] not in {
                "pos",
                "pos_id",
                "position_id",
                "pos_ids",
                "position_ids",
            }:
                logger.debug("fuse_rotary_embeddings: failed to verify inputs for RotaryEmbedding function")
                return

            rotary_emb_node = self.create_rotary_embeddings_from_function(node)
            if rotary_emb_node is None:
                logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node")
                return

            # Remove RotaryEmbedding function
            self.nodes_to_remove.append(node)

            # Remove RotaryEmbedding function's shape inference stored in value_info
            # The new shape will be calculated during symbolic shape inference
            graph_value_info = self.model.model.graph.value_info
            stale_value_infos = [info for info in graph_value_info if info.name == rotary_emb_node.output[0]]
            for stale_value_info in stale_value_infos:
                graph_value_info.remove(stale_value_info)

        else:
            # Rotary embeddings are defined using the below functions:
            #
            # def rotate_half(x):
            #     """Rotates half the hidden dims of the input."""
            #     x1 = x[..., : x.shape[-1] // 2]
            #     x2 = x[..., x.shape[-1] // 2 :]
            #     return torch.cat((-x2, x1), dim=-1)
            #
            # def apply_rope(x, cos, sin, position_ids):
            #     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
            #     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
            #     cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
            #     sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
            #     x_embed = (x * cos) + (rotate_half(x) * sin)
            #     return x_embed

            # Check paths for rotate_half(x)
            rotate_half_x2_path_1_1 = self.model.match_parent_path(
                node,
                ["Mul", "Concat", "Neg", "Slice", "Transpose"],
                [1, 0, 0, 0, 0],
            )
            rotate_half_x2_path_1_2 = self.model.match_parent_path(
                node,
                ["Mul", "Concat", "Neg", "Slice", "Slice"],
                [1, 0, 0, 0, 0],
            )
            rotate_half_x2_path_1 = rotate_half_x2_path_1_1 or rotate_half_x2_path_1_2

            rotate_half_x2_path_2_1 = self.model.match_parent_path(
                node,
                ["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"],
                [1, 0, 0, 0, 1, 0, 0, 0, 0],
            )
            rotate_half_x2_path_2_2 = self.model.match_parent_path(
                node,
                ["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Slice"],
                [1, 0, 0, 0, 1, 0, 0, 0, 0],
            )
            rotate_half_x2_path_2 = rotate_half_x2_path_2_1 or rotate_half_x2_path_2_2

            if rotate_half_x2_path_1 is None or rotate_half_x2_path_2 is None:
                logger.debug("fuse_rotary_embeddings: failed to match x2 in rotate_half")
                return

            rotate_half_x1_path_1_1 = self.model.match_parent_path(
                node,
                ["Mul", "Concat", "Slice", "Transpose"],
                [1, 0, 1, 0],
            )
            rotate_half_x1_path_1_2 = self.model.match_parent_path(
                node,
                ["Mul", "Concat", "Slice", "Slice"],
                [1, 0, 1, 0],
            )
            rotate_half_x1_path_1 = rotate_half_x1_path_1_1 or rotate_half_x1_path_1_2

            rotate_half_x1_path_2_1 = self.model.match_parent_path(
                node,
                ["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"],
                [1, 0, 1, 2, 0, 0, 0, 0],
            )
            rotate_half_x1_path_2_2 = self.model.match_parent_path(
                node,
                ["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Slice"],
                [1, 0, 1, 2, 0, 0, 0, 0],
            )
            rotate_half_x1_path_2 = rotate_half_x1_path_2_1 or rotate_half_x1_path_2_2

            if rotate_half_x1_path_1 is None or rotate_half_x1_path_2 is None:
                logger.debug("fuse_rotary_embeddings: failed to match x1 in rotate_half")
                return

            if (
                rotate_half_x1_path_1[-1].name != rotate_half_x1_path_2[-1].name
                or rotate_half_x2_path_1[-1].name != rotate_half_x2_path_2[-1].name
                or rotate_half_x1_path_1[-1].name != rotate_half_x2_path_1[-1].name
                or rotate_half_x1_path_2[-1].name != rotate_half_x2_path_2[-1].name
            ):
                logger.debug("fuse_rotary_embeddings: failed to match common input in rotate_half")
                return

            # Check path for x
            x_path_1 = self.model.match_parent_path(
                node,
                ["Mul", "Transpose"],
                [0, 0],
            )
            x_path_2 = self.model.match_parent_path(
                node,
                ["Mul", "Slice"],
                [0, 0],
            )
            x_path = x_path_1 or x_path_2
            if x_path is None:
                logger.debug("fuse_rotary_embeddings: failed to match x in rotate_half")
                return

            # The fused node rotates rotate_half's input, so `x * cos` must use that same tensor.
            if x_path[-1].name != rotate_half_x1_path_1[-1].name:
                logger.debug("fuse_rotary_embeddings: failed to match common input for x and rotate_half(x)")
                return

            # Check path for sin
            sin_match = self._match_cache_path(node, 1)
            if sin_match is None:
                logger.debug("fuse_rotary_embeddings: failed to match sin path in apply_rope")
                return
            sin_pattern, sin_path, sin_cache, sin_position_ids = sin_match

            # Check path for cos
            cos_match = self._match_cache_path(node, 0)
            if cos_match is None:
                logger.debug("fuse_rotary_embeddings: failed to match cos path in apply_rope")
                return
            cos_pattern, cos_path, cos_cache, cos_position_ids = cos_match

            # Position ids found on the cos path take precedence over those on the sin path
            position_ids = cos_position_ids or sin_position_ids

            # Check path for position ids
            if position_ids == "":
                position_ids_from_sin_path = self.model.match_parent_path(
                    sin_path[2],
                    ["Reshape"],
                    [1],
                )
                position_ids_from_cos_path = self.model.match_parent_path(
                    cos_path[2],
                    ["Reshape"],
                    [1],
                )
                if (
                    position_ids_from_sin_path is None
                    or position_ids_from_cos_path is None
                    or position_ids_from_sin_path[0].name != position_ids_from_cos_path[0].name
                ):
                    logger.debug("fuse_rotary_embeddings: failed to match position ids path in apply_rope")
                    return
                position_ids = position_ids_from_cos_path[0].input[0]
            else:
                position_ids_from_sin_path = []
                position_ids_from_cos_path = []

            past_seq_len_path, curr_seq_len_path = None, None
            if sin_pattern == cos_pattern and sin_pattern in (1, 3):
                if sin_path[-2].name != cos_path[-2].name or sin_path[-1].name != cos_path[-1].name:
                    logger.debug(
                        "fuse_rotary_embeddings: failed to match common Gather node and Shape node in sin cache and cos cache"
                    )
                    return
            elif sin_pattern == cos_pattern and sin_pattern in (2, 4):
                if sin_path[-1].name != cos_path[-1].name:
                    logger.debug("fuse_rotary_embeddings: failed to match common Add node in sin cache and cos cache")
                    return
                # Match past sequence length path: past_key --> Shape --> Gather --> Add
                past_seq_len_path = self.model.match_parent_path(
                    sin_path[-1],
                    ["Gather", "Shape"],
                    [1, 0],
                )
                # Match current sequence length path: transpose_k --> Shape --> Gather --> Add
                curr_seq_len_path = self.model.match_parent_path(
                    sin_path[-1],
                    ["Gather", "Shape", "Transpose"],
                    [0, 0, 0],
                )
                if (
                    past_seq_len_path is None
                    or curr_seq_len_path is None
                    or self.model.find_graph_input(past_seq_len_path[-1].input[0]) is None
                    or curr_seq_len_path[-1].op_type != "Transpose"
                ):
                    logger.debug("fuse_rotary_embeddings: failed to match past_seq_len and curr_seq_len paths")
                    return
            else:
                # sin and cos were built by different patterns; fusing them would be unsafe.
                logger.debug("fuse_rotary_embeddings: failed to match common cache paths")
                return

            rotary_emb_node = self.create_rotary_embeddings_from_nodes(
                rotate_half_x1_path_1[-1].output[0],
                position_ids,
                cos_cache,
                sin_cache,
                node.output[0],
            )
            if rotary_emb_node is None:
                logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node")
                return

            # Remove rotary embedding nodes
            self.add_nodes_to_remove([node])
            self.add_nodes_to_remove(rotate_half_x1_path_1[:-1])
            self.add_nodes_to_remove(rotate_half_x1_path_2[:-1])
            self.add_nodes_to_remove(rotate_half_x2_path_1[:-1])
            self.add_nodes_to_remove(rotate_half_x2_path_2[:-1])
            self.add_nodes_to_remove(x_path[:-1])
            self.add_nodes_to_remove(sin_path)
            self.add_nodes_to_remove(cos_path)
            self.add_nodes_to_remove(position_ids_from_sin_path[:-1])
            self.add_nodes_to_remove(position_ids_from_cos_path[:-1])

            if past_seq_len_path is not None and len(self.model.get_children(past_seq_len_path[0])) == 1:
                # In merged HF model, output of Gather in past_seq_len_path is used twice
                # for past_key_values.0.key and once for other past_key_values
                self.add_nodes_to_remove(past_seq_len_path)
            if curr_seq_len_path is not None:
                self.add_nodes_to_remove(curr_seq_len_path[:-1])

        self.increase_counter(self.base_name)
        self.node_name_to_graph_name[rotary_emb_node.name] = self.this_graph_name
        self.nodes_to_add.append(rotary_emb_node)
        self.prune_graph = True
|