fusion_rotary_attention.py 65 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import logging
  6. from fusion_attention import FusionAttention
  7. from fusion_base import Fusion
  8. from onnx import FunctionProto, NodeProto, TensorProto, helper, numpy_helper
  9. from onnx_model import OnnxModel
  10. logger = logging.getLogger(__name__)
  11. class FusionRotaryAttention(FusionAttention):
  12. """
  13. Fuse Attention subgraph with rotary positional embeddings into one MultiHeadAttention node.
  14. """
  def __init__(
      self,
      model: OnnxModel,
      hidden_size: int,
      num_heads: int,
  ):
      """Set up the rotary-attention fusion so that it emits MultiHeadAttention nodes.

      Args:
          model: Model wrapper whose graph will be searched and rewritten.
          hidden_size: Hidden size forwarded unchanged to FusionAttention.
          num_heads: Number of attention heads forwarded unchanged to FusionAttention.
      """
      # Op types handed to the base class as `search_op_types` (presumably the anchor nodes that
      # `fuse` gets invoked on -- confirm in fusion_base).
      anchor_op_types = [
          "SimplifiedLayerNormalization",
          "SkipSimplifiedLayerNormalization",
          "LayerNormalization",
          "SkipLayerNormalization",
          "Add",
      ]
      super().__init__(
          model,
          hidden_size,
          num_heads,
          search_op_types=anchor_op_types,
          use_multi_head_attention=True,
      )
  def create_mha_node(
      self,
      input: str,
      output: str,
      q_rotary: NodeProto,
      k_rotary: NodeProto,
      v_matmul: NodeProto,
      attn_mask: str = "",
      add_qk: str = "",
      past_k: str = "",
      past_v: str = "",
      present_k: str = "",
      present_v: str = "",
      scale: float | None = None,
  ) -> NodeProto | None:
      """Build a com.microsoft MultiHeadAttention node from rotated Q/K and projected V.

      Args:
          input: Root input of the attention subgraph. Not referenced by this method; kept so
              existing callers (positional or keyword) keep working.
          output: Name of the attention output tensor.
          q_rotary: Node whose first output becomes the query input.
          k_rotary: Node whose first output becomes the key input.
          v_matmul: Node whose first output becomes the value input.
          attn_mask: Tensor wired to the key_padding_mask input ("" means absent).
          add_qk: Tensor wired to the attention_bias input ("" means absent).
          past_k: Past key tensor name ("" means absent).
          past_v: Past value tensor name ("" means absent).
          present_k: Present key output name. Only attached when present_v is also non-empty.
          present_v: Present value output name. Only attached when present_k is also non-empty.
          scale: Optional value for the "scale" attribute; always emitted as a float.

      Returns:
          The new node (not yet added to the graph by this method), or None when hidden_size is
          known (> 0) and is not a multiple of num_heads.

      Raises:
          ValueError: If self.num_heads is not positive.
      """
      # Explicit check instead of `assert`: asserts are stripped under `python -O`, and the modulo
      # below would then fail with ZeroDivisionError for num_heads == 0.
      if self.num_heads <= 0:
          raise ValueError(f"fuse_rotary_attention: num_heads must be positive, got {self.num_heads}")
      if self.hidden_size > 0 and (self.hidden_size % self.num_heads) != 0:
          logger.debug(
              "fuse_rotary_attention: input hidden size %s is not a multiple of num of heads %s",
              self.hidden_size,
              self.num_heads,
          )
          return None

      mha_node_name = self.model.create_node_name("MultiHeadAttention")
      mha_inputs = [
          q_rotary.output[0],  # query
          k_rotary.output[0],  # key
          v_matmul.output[0],  # value
          "",  # bias
          attn_mask,  # key_padding_mask
          add_qk,  # attention_bias
          past_k,
          past_v,
      ]
      mha_outputs = [output]
      # Present key/value outputs are only attached as a pair.
      if present_k and present_v:
          mha_outputs.extend([present_k, present_v])

      mha_node = helper.make_node(
          "MultiHeadAttention",
          inputs=mha_inputs,
          outputs=mha_outputs,
          name=mha_node_name,
      )
      mha_node.domain = "com.microsoft"
      mha_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)])
      if scale is not None:
          # helper.make_attribute infers the attribute type from the Python value. Coerce to float
          # so an int/numpy-integer scale is not emitted as an INT attribute. This matches the
          # treatment of mask_filter_value just below.
          mha_node.attribute.extend([helper.make_attribute("scale", float(scale))])
      if self.mask_filter_value is not None:
          mha_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
      self.increase_counter("MultiHeadAttention")
      return mha_node
  def check_runtime_shape_paths_for_function(
      self,
      reshape_qkv_2,  # Reshape after Transpose
      reshape_qkv_1,  # Reshape before Transpose
      reshape_q_2,  # Reshape after RotaryEmbedding
      reshape_k_2,  # Reshape after RotaryEmbedding
      reshape_v_2,  # Reshape after Transpose
      reshape_v_1,  # Reshape before Transpose
      add_qk,  # Add before Softmax
      root_input,  # Root input to attention subgraph
  ):
      """Verify the runtime shape subgraphs that feed the Reshape nodes of this attention pattern.

      Each listed Reshape must take its shape (input 1) from a Concat. The pieces of that Concat
      must trace back through Unsqueeze (and, for some inputs, a Mul or Add) to one of two Gather
      nodes. Those two Gather nodes come from the Concat of `reshape_qkv_2`, and each must read a
      Shape of `root_input`. The Mul nodes on the q/k/v paths must all consume the first Gather's
      output.

      The attention-mask branch on `add_qk` is then checked. It must be [Cast ->] Concat fed by
      Slice <- Slice, where the inner Slice's first input is named "attn_mask" or "attention_mask".
      Both Slices must get their input 2 (ends) from the same Add/Gather pair. The inner Slice's
      input 1 (starts) Unsqueeze must read the same tensor as that Add's first input.

      Returns:
          True if every expected path exists and is consistent, False at the first mismatch.
      """
      match = self.model.match_parent_path
      unsqueeze_gather_shape = ["Unsqueeze", "Gather", "Shape"]
      via_mul = ["Unsqueeze", "Mul", "Gather", "Shape"]
      via_add = ["Unsqueeze", "Add", "Gather", "Shape"]

      def missing(*paths):
          # True when any requested parent path failed to match.
          return any(path is None for path in paths)

      def mismatched(*pairs):
          # True when any (candidate, expected) pair names two different nodes.
          return any(candidate.name != expected.name for candidate, expected in pairs)

      # Stage 1 (qkv): find the two shared Gather nodes; both must read Shape(root_input).
      concat_qkv_2_path = match(reshape_qkv_2, ["Concat"], [1])
      concat_qkv_1_path = match(reshape_qkv_1, ["Concat"], [1])
      if missing(concat_qkv_2_path, concat_qkv_1_path):
          return False
      concat_qkv_2 = concat_qkv_2_path[0]
      concat_qkv_1 = concat_qkv_1_path[0]

      qkv_2_first = match(concat_qkv_2, unsqueeze_gather_shape, [0, 0, 0])
      qkv_2_second = match(concat_qkv_2, unsqueeze_gather_shape, [1, 0, 0])
      qkv_1_first = match(concat_qkv_1, unsqueeze_gather_shape, [0, 0, 0])
      qkv_1_second = match(concat_qkv_1, unsqueeze_gather_shape, [2, 0, 0])
      if missing(qkv_2_first, qkv_2_second, qkv_1_first, qkv_1_second):
          return False

      _, gather_1, shape_1 = qkv_2_first
      _, gather_2, shape_2 = qkv_2_second
      if any(shape.input[0] != root_input for shape in (shape_1, shape_2)):
          return False
      if mismatched((qkv_1_first[1], gather_1), (qkv_1_second[1], gather_2)):
          return False

      # Stage 2 (v): both v Reshape nodes must reuse the same two Gather nodes.
      concat_v_2_path = match(reshape_v_2, ["Concat"], [1])
      concat_v_1_path = match(reshape_v_1, ["Concat"], [1])
      if missing(concat_v_2_path, concat_v_1_path):
          return False
      concat_v_2 = concat_v_2_path[0]
      concat_v_1 = concat_v_1_path[0]

      v_2_mul = match(concat_v_2, via_mul, [0, 0, 0, 0])
      v_2_add = match(concat_v_2, via_add, [1, 0, 0, 0])
      v_1_first = match(concat_v_1, unsqueeze_gather_shape, [0, 0, 0])
      v_1_second = match(concat_v_1, unsqueeze_gather_shape, [1, 0, 0])
      if missing(v_2_mul, v_2_add, v_1_first, v_1_second):
          return False
      if mismatched(
          (v_2_mul[2], gather_1),
          (v_2_add[2], gather_2),
          (v_1_first[1], gather_1),
          (v_1_second[1], gather_2),
      ):
          return False

      # Stage 3 (k): Mul-based piece at Concat input 0, Add-based piece at Concat input 2.
      concat_k_2_path = match(reshape_k_2, ["Concat"], [1])
      if concat_k_2_path is None:
          return False
      concat_k_2 = concat_k_2_path[0]
      k_mul = match(concat_k_2, via_mul, [0, 0, 0, 0])
      k_add = match(concat_k_2, via_add, [2, 0, 0, 0])
      if missing(k_mul, k_add):
          return False
      if mismatched((k_mul[2], gather_1), (k_add[2], gather_2)):
          return False

      # Stage 4 (q): Mul-based piece at Concat input 0, direct Gather piece at Concat input 1.
      concat_q_2_path = match(reshape_q_2, ["Concat"], [1])
      if concat_q_2_path is None:
          return False
      concat_q_2 = concat_q_2_path[0]
      q_mul = match(concat_q_2, via_mul, [0, 0, 0, 0])
      q_second = match(concat_q_2, unsqueeze_gather_shape, [1, 0, 0])
      if missing(q_mul, q_second):
          return False
      if mismatched((q_mul[2], gather_1), (q_second[1], gather_2)):
          return False

      # Stage 5: the q, k and v Mul nodes must all take gather_1's output as their first input.
      gather_1_out = gather_1.output[0]
      if any(mul_node.input[0] != gather_1_out for mul_node in (q_mul[1], k_mul[1], v_2_mul[1])):
          return False

      # Stage 6: attention-mask slicing, with or without a Cast in front of the Concat.
      plain_mask_path = match(add_qk, ["Concat", "Slice", "Slice"], [1, 0, 0])
      cast_mask_path = match(add_qk, ["Cast", "Concat", "Slice", "Slice"], [1, 0, 0, 0])
      mask_path = plain_mask_path if plain_mask_path is not None else cast_mask_path
      if mask_path is None:
          return False
      # The last two nodes are the outer Slice (feeding Concat) and the inner Slice it reads.
      slice_qk_2, slice_qk_1 = mask_path[-2:]

      if slice_qk_1.input[0] not in {"attn_mask", "attention_mask"}:
          return False

      slice_qk_2_ends = match(slice_qk_2, via_add, [2, 0, 1, 0])
      slice_qk_1_ends = match(slice_qk_1, via_add, [2, 0, 1, 0])
      slice_qk_1_starts = match(slice_qk_1, ["Unsqueeze"], [1])
      if missing(slice_qk_2_ends, slice_qk_1_ends, slice_qk_1_starts):
          return False
      # Both Slices must share the same Add and Gather for their ends.
      if mismatched((slice_qk_2_ends[1], slice_qk_1_ends[1]), (slice_qk_2_ends[2], slice_qk_1_ends[2])):
          return False
      # The starts Unsqueeze must read the same tensor as the shared Add's first input.
      return slice_qk_1_ends[1].input[0] == slice_qk_1_starts[0].input[0]
  def check_runtime_shape_paths_for_nodes(
      self,
      reshape_qkv,  # Final reshape before o_proj MatMul
      reshape_q,  # Reshape before q_proj MatMul
      reshape_k,  # Reshape before k_proj MatMul
      reshape_v,  # Reshape before v_proj MatMul
      root_input,  # Root input to attention subgraph
  ):
      """Verify that the q/k/v/qkv Reshape shapes are all built from the same two Gather nodes.

      Each Reshape must take its shape (input 1) from a Concat. Inputs 0 and 1 of that Concat must
      each come from an Unsqueeze <- Gather <- Shape chain. The two Gather nodes found under
      `reshape_qkv` must read a Shape of `root_input`. The v, k and q Reshape nodes must reuse those
      exact Gather nodes, compared by name.

      Returns:
          True when every path matches and is consistent, False otherwise.
      """
      match = self.model.match_parent_path
      dim_path = ["Unsqueeze", "Gather", "Shape"]

      def shape_sources(reshape):
          # (Unsqueeze, Gather, Shape) chains feeding Concat inputs 0 and 1 of `reshape`'s shape
          # Concat, or None when any link is missing.
          concat_path = match(reshape, ["Concat"], [1])
          if concat_path is None:
              return None
          concat = concat_path[0]
          first = match(concat, dim_path, [0, 0, 0])
          second = match(concat, dim_path, [1, 0, 0])
          if first is None or second is None:
              return None
          return first, second

      qkv_sources = shape_sources(reshape_qkv)
      if qkv_sources is None:
          return False
      (_, gather_1, shape_1), (_, gather_2, shape_2) = qkv_sources
      if shape_1.input[0] != root_input or shape_2.input[0] != root_input:
          return False

      # v, k and q (in that order) must reuse exactly these two Gather nodes.
      for reshape in (reshape_v, reshape_k, reshape_q):
          sources = shape_sources(reshape)
          if sources is None:
              return False
          first, second = sources
          if first[1].name != gather_1.name or second[1].name != gather_2.name:
              return False
      return True
  278. def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
  279. if normalize_node.op_type not in {"SkipSimplifiedLayerNormalization", "SkipLayerNormalization", "Add"}:
  280. return
  281. # qkv_nodes_1 is for LLaMA-2 Microsoft
  282. # qkv_nodes_2 is for LLaMA-2 Hugging Face
  283. # qkv_nodes_3 is for LLaMA-2 distribute Hugging Face model
  284. qkv_nodes = None
  285. qkv_nodes_1 = self.model.match_parent_path(
  286. normalize_node,
  287. ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
  288. [1, 0, 0, 0, 0],
  289. )
  290. qkv_nodes_2 = self.model.match_parent_path(
  291. normalize_node,
  292. ["MatMul", "Reshape", "Transpose", "MatMul"],
  293. [1, 0, 0, 0],
  294. )
  295. qkv_nodes_3 = self.model.match_parent_path(
  296. normalize_node,
  297. ["AllReduce", "MatMul", "Reshape", "Transpose", "MatMul"],
  298. [1, 0, 0, 0, 0],
  299. )
  300. if qkv_nodes_1 is not None:
  301. _, reshape_qkv_2, _, reshape_qkv_1, matmul_qkv = qkv_nodes_1
  302. qkv_nodes = qkv_nodes_1
  303. elif qkv_nodes_2 is not None:
  304. _, reshape_qkv, _, matmul_qkv = qkv_nodes_2
  305. qkv_nodes = qkv_nodes_2
  306. elif qkv_nodes_3 is not None:
  307. _, _, reshape_qkv, _, matmul_qkv = qkv_nodes_3
  308. qkv_nodes = qkv_nodes_3
  309. else:
  310. logger.debug("fuse_rotary_attention: failed to match qkv nodes")
  311. return
  312. # v_nodes_1 is for LLaMA-2 Microsoft
  313. # v_nodes_3 is for LLaMA-2 Hugging Face
  314. # v_nodes_4 is for LLaMA-2 70B model
  315. # v_nodes_5 is for Phi-2 DirectML
  316. past_v, present_v, past_seq_len = "", "", ""
  317. v_nodes = None
  318. add_v = None
  319. v_nodes_1 = self.model.match_parent_path(
  320. matmul_qkv,
  321. ["Reshape", "Transpose", "Concat", "Transpose", "Reshape", "MatMul"],
  322. [1, 0, 0, 1, 0, 0],
  323. )
  324. v_nodes_2 = self.model.match_parent_path(
  325. matmul_qkv,
  326. ["Concat", "Transpose", "Reshape", "MatMul"],
  327. [1, 1, 0, 0],
  328. )
  329. v_nodes_3 = self.model.match_parent_path(
  330. matmul_qkv,
  331. ["Transpose", "Reshape", "MatMul"],
  332. [1, 0, 0],
  333. )
  334. _, v_nodes_4, _ = self.model.match_parent_paths_all(
  335. matmul_qkv,
  336. [
  337. (
  338. ["Reshape", "Expand", "Unsqueeze", "Concat", "Transpose", "Reshape", "MatMul"],
  339. [1, 0, 0, 0, 1, 0, 0],
  340. ),
  341. (
  342. [
  343. "Reshape",
  344. "Expand",
  345. "Where",
  346. "Equal",
  347. "Reshape",
  348. "Concat",
  349. "Unsqueeze",
  350. "Gather",
  351. "Shape",
  352. "Concat",
  353. "Transpose",
  354. "Reshape",
  355. "MatMul",
  356. ],
  357. [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
  358. ),
  359. (
  360. [
  361. "Reshape",
  362. "Expand",
  363. "Where",
  364. "Equal",
  365. "Mul",
  366. "ConstantOfShape",
  367. "Shape",
  368. "Reshape",
  369. "Concat",
  370. "Unsqueeze",
  371. "Gather",
  372. "Shape",
  373. "Concat",
  374. "Transpose",
  375. "Reshape",
  376. "MatMul",
  377. ],
  378. [1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0],
  379. ),
  380. (
  381. [
  382. "Reshape",
  383. "Expand",
  384. "Where",
  385. "ConstantOfShape",
  386. "Shape",
  387. "Reshape",
  388. "Concat",
  389. "Unsqueeze",
  390. "Gather",
  391. "Shape",
  392. "Concat",
  393. "Transpose",
  394. "Reshape",
  395. "MatMul",
  396. ],
  397. [1, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0],
  398. ),
  399. (
  400. [
  401. "Reshape",
  402. "Expand",
  403. "Where",
  404. "Reshape",
  405. "Concat",
  406. "Unsqueeze",
  407. "Gather",
  408. "Shape",
  409. "Concat",
  410. "Transpose",
  411. "Reshape",
  412. "MatMul",
  413. ],
  414. [1, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0],
  415. ),
  416. (
  417. ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
  418. [1, 1, 0, 0, 0, 0, 1, 0, 0],
  419. ),
  420. (
  421. [
  422. "Reshape",
  423. "Concat",
  424. "Unsqueeze",
  425. "Mul",
  426. "Gather",
  427. "Shape",
  428. "Concat",
  429. "Transpose",
  430. "Reshape",
  431. "MatMul",
  432. ],
  433. [1, 1, 1, 0, 0, 0, 0, 1, 0, 0],
  434. ),
  435. (
  436. ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
  437. [1, 1, 2, 0, 0, 0, 1, 0, 0],
  438. ),
  439. (
  440. ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
  441. [1, 1, 3, 0, 0, 0, 1, 0, 0],
  442. ),
  443. ],
  444. output_name_to_node=None,
  445. )
  446. v_nodes_5 = self.model.match_parent_path(
  447. matmul_qkv,
  448. ["Concat", "Transpose", "Reshape", "Add", "MatMul"],
  449. [1, 1, 0, 0, 1],
  450. )
  451. if v_nodes_1 is not None:
  452. reshape_v_2, _, concat_v, _, reshape_v_1, matmul_v = v_nodes_1
  453. v_nodes = v_nodes_1
  454. concat_v_path = self.model.match_parent_path(
  455. concat_v,
  456. ["Slice", "Unsqueeze"],
  457. [0, 2],
  458. )
  459. if concat_v_path is None:
  460. logger.debug("fuse_rotary_attention: failed to match past/present concat in v path")
  461. return
  462. past_v = concat_v_path[0].input[0]
  463. past_seq_len = concat_v_path[-1].input[0]
  464. present_v = concat_v.output[0]
  465. elif v_nodes_2 is not None:
  466. concat_v, transpose_v, reshape_v, matmul_v = v_nodes_2
  467. v_nodes = v_nodes_2
  468. past_v = concat_v.input[0]
  469. present_v = concat_v.output[0]
  470. elif v_nodes_3 is not None:
  471. transpose_v, reshape_v, matmul_v = v_nodes_3
  472. v_nodes = v_nodes_3
  473. present_v = transpose_v.output[0]
  474. elif v_nodes_4 is not None and len(v_nodes_4) == 9:
  475. concat_v, transpose_v, reshape_v, matmul_v = v_nodes_4[0][-4:]
  476. v_nodes = v_nodes_4
  477. past_v = concat_v.input[0]
  478. present_v = concat_v.output[0]
  479. elif v_nodes_5 is not None:
  480. concat_v, transpose_v, reshape_v, add_v, matmul_v = v_nodes_5
  481. matmul_v = add_v
  482. v_nodes = v_nodes_5
  483. past_v = concat_v.input[0]
  484. present_v = concat_v.output[0]
  485. else:
  486. logger.debug("fuse_rotary_attention: failed to match v path")
  487. return
  488. qk_nodes = self.model.match_parent_path(
  489. matmul_qkv,
  490. ["Softmax", "Add", "Div", "MatMul"],
  491. [0, 0, 0, 0],
  492. )
  493. add_qk, matmul_qk = None, None
  494. if qk_nodes is not None:
  495. _, add_qk, _, matmul_qk = qk_nodes
  496. else:
  497. logger.debug("fuse_rotary_attention: failed to match qk nodes")
  498. return
  499. # attn_mask_nodes_1, attn_mask_nodes_2 are for LLaMA-2 Microsoft's 3D attention mask
  500. # attn_mask_nodes_3, attn_mask_nodes_4 are for LLaMA-2 Hugging Face's 2D attention mask
  501. # attn_mask_nodes_5, attn_mask_nodes_6 are for LLaMA-2 Microsoft's model for the DML EP
  502. # attn_mask_nodes_7 is for LLaMA-2 Hugging Face's changes to the attention mask
  503. attn_mask, add_qk_str = "", ""
  504. attn_mask_nodes_1 = self.model.match_parent_path(
  505. add_qk,
  506. ["Concat", "Slice", "Slice"],
  507. [1, 0, 0],
  508. )
  509. attn_mask_nodes_2 = self.model.match_parent_path(
  510. add_qk,
  511. ["Cast", "Concat", "Slice", "Slice"],
  512. [1, 0, 0, 0],
  513. )
  514. attn_mask_nodes_3 = self.model.match_parent_path(
  515. add_qk,
  516. ["Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
  517. [1, 0, 2, 1, 0, 0, 0],
  518. )
  519. attn_mask_nodes_4 = self.model.match_parent_path(
  520. add_qk,
  521. ["Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
  522. [1, 2, 1, 0, 0, 0],
  523. )
  524. attn_mask_nodes_5 = self.model.match_parent_path(
  525. add_qk,
  526. ["Expand", "Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
  527. [1, 0, 0, 2, 1, 0, 0, 0],
  528. )
  529. attn_mask_nodes_6 = self.model.match_parent_path(
  530. add_qk,
  531. ["Expand", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
  532. [1, 0, 2, 1, 0, 0, 0],
  533. )
  534. attn_mask_nodes_7 = self.model.match_parent_path(
  535. add_qk,
  536. ["Where", "Cast", "Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
  537. [1, 0, 0, 0, 0, 1, 0, 0, 0],
  538. )
  539. if attn_mask_nodes_1 is not None:
  540. _, slice_mask_1, slice_mask_2 = attn_mask_nodes_1
  541. attn_mask = slice_mask_1.output[0]
  542. elif attn_mask_nodes_2 is not None:
  543. _, _, slice_mask_1, slice_mask_2 = attn_mask_nodes_2
  544. attn_mask = slice_mask_1.output[0]
  545. elif attn_mask_nodes_3 is not None:
  546. # Reshape from (B,1,S,T) to (B,N,S,T)
  547. add_qk_str = self.reshape_add_qk(attn_mask_nodes_3[0].output[0])
  548. elif attn_mask_nodes_4 is not None:
  549. # Reshape from (B,1,S,T) to (B,N,S,T)
  550. add_qk_str = self.reshape_add_qk(attn_mask_nodes_4[0].output[0])
  551. elif attn_mask_nodes_5 is not None:
  552. # The mask has already been reshaped to (B,N,S,T)
  553. add_qk_str = attn_mask_nodes_5[0].output[0]
  554. elif attn_mask_nodes_6 is not None:
  555. # The mask has already been reshaped to (B,N,S,T)
  556. add_qk_str = attn_mask_nodes_6[0].output[0]
  557. elif attn_mask_nodes_7 is not None:
  558. # Reshape from (B,1,S,T) to (B,N,S,T)
  559. add_qk_str = self.reshape_add_qk(attn_mask_nodes_7[0].output[0])
  560. else:
  561. logger.debug("fuse_rotary_attention: failed to match attention mask nodes")
  562. return
  563. # k_nodes_1 is for LLaMA-2 Microsoft
  564. # k_nodes_2 is for LLaMA-2 Hugging Face
  565. # k_nodes_4 is for LLaMA-2 70B Hugging Face
  566. past_k, present_k = "", ""
  567. k_nodes = None
  568. slice_k = None
  569. concat_k_half = None
  570. k_nodes_1 = self.model.match_parent_path(
  571. matmul_qk,
  572. ["Reshape", "Transpose", "Concat", "Transpose", "RotaryEmbedding", "MatMul"],
  573. [1, 0, 0, 1, 0, 0],
  574. )
  575. k_nodes_2 = self.model.match_parent_path(
  576. matmul_qk,
  577. ["Transpose", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
  578. [1, 0, 0, 0, 0],
  579. )
  580. k_nodes_3 = self.model.match_parent_path(
  581. matmul_qk,
  582. ["Transpose", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
  583. [1, 0, 1, 0, 0, 0],
  584. )
  585. _, k_nodes_4, _ = self.model.match_parent_paths_all(
  586. matmul_qk,
  587. [
  588. (
  589. [
  590. "Transpose",
  591. "Reshape",
  592. "Expand",
  593. "Unsqueeze",
  594. "Concat",
  595. "RotaryEmbedding",
  596. "Transpose",
  597. "Reshape",
  598. "MatMul",
  599. ],
  600. [1, 0, 0, 0, 0, 1, 0, 0, 0],
  601. ),
  602. (
  603. [
  604. "Transpose",
  605. "Reshape",
  606. "Expand",
  607. "Where",
  608. "Equal",
  609. "Reshape",
  610. "Concat",
  611. "Unsqueeze",
  612. "Gather",
  613. "Shape",
  614. "Concat",
  615. "RotaryEmbedding",
  616. "Transpose",
  617. "Reshape",
  618. "MatMul",
  619. ],
  620. [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
  621. ),
  622. (
  623. [
  624. "Transpose",
  625. "Reshape",
  626. "Expand",
  627. "Where",
  628. "Equal",
  629. "Mul",
  630. "ConstantOfShape",
  631. "Shape",
  632. "Reshape",
  633. "Concat",
  634. "Unsqueeze",
  635. "Gather",
  636. "Shape",
  637. "Concat",
  638. "RotaryEmbedding",
  639. "Transpose",
  640. "Reshape",
  641. "MatMul",
  642. ],
  643. [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
  644. ),
  645. (
  646. [
  647. "Transpose",
  648. "Reshape",
  649. "Expand",
  650. "Where",
  651. "ConstantOfShape",
  652. "Shape",
  653. "Reshape",
  654. "Concat",
  655. "Unsqueeze",
  656. "Gather",
  657. "Shape",
  658. "Concat",
  659. "RotaryEmbedding",
  660. "Transpose",
  661. "Reshape",
  662. "MatMul",
  663. ],
  664. [1, 0, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0],
  665. ),
  666. (
  667. [
  668. "Transpose",
  669. "Reshape",
  670. "Expand",
  671. "Where",
  672. "Reshape",
  673. "Concat",
  674. "Unsqueeze",
  675. "Gather",
  676. "Shape",
  677. "Concat",
  678. "RotaryEmbedding",
  679. "Transpose",
  680. "Reshape",
  681. "MatMul",
  682. ],
  683. [1, 0, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0, 0],
  684. ),
  685. (
  686. [
  687. "Transpose",
  688. "Reshape",
  689. "Concat",
  690. "Unsqueeze",
  691. "Gather",
  692. "Shape",
  693. "Concat",
  694. "RotaryEmbedding",
  695. "Transpose",
  696. "Reshape",
  697. "MatMul",
  698. ],
  699. [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
  700. ),
  701. (
  702. [
  703. "Transpose",
  704. "Reshape",
  705. "Concat",
  706. "Unsqueeze",
  707. "Mul",
  708. "Gather",
  709. "Shape",
  710. "Concat",
  711. "RotaryEmbedding",
  712. "Transpose",
  713. "Reshape",
  714. "MatMul",
  715. ],
  716. [1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0],
  717. ),
  718. (
  719. [
  720. "Transpose",
  721. "Reshape",
  722. "Concat",
  723. "Unsqueeze",
  724. "Gather",
  725. "Shape",
  726. "Concat",
  727. "RotaryEmbedding",
  728. "Transpose",
  729. "Reshape",
  730. "MatMul",
  731. ],
  732. [1, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0],
  733. ),
  734. (
  735. [
  736. "Transpose",
  737. "Reshape",
  738. "Concat",
  739. "Unsqueeze",
  740. "Gather",
  741. "Shape",
  742. "Concat",
  743. "RotaryEmbedding",
  744. "Transpose",
  745. "Reshape",
  746. "MatMul",
  747. ],
  748. [1, 0, 1, 3, 0, 0, 0, 1, 0, 0, 0],
  749. ),
  750. ],
  751. output_name_to_node=None,
  752. )
  753. k_nodes_5 = self.model.match_parent_path(
  754. matmul_qk,
  755. ["Transpose", "Concat", "Concat", "RotaryEmbedding", "Slice", "Transpose", "Reshape", "Add", "MatMul"],
  756. [1, 0, 1, 0, 0, 0, 0, 0, 1],
  757. )
  758. if k_nodes_1 is not None:
  759. reshape_k_2, _, concat_k, _, rotary_k, matmul_k = k_nodes_1
  760. k_nodes = k_nodes_1
  761. concat_k_path = self.model.match_parent_path(
  762. concat_k,
  763. ["Slice", "Unsqueeze"],
  764. [0, 2],
  765. )
  766. if concat_k_path is None:
  767. logger.debug("fuse_rotary_attention: failed to match past/present concat in k path")
  768. return
  769. past_k = concat_k_path[0].input[0]
  770. shared_past_seq_len = concat_k_path[-1].input[0]
  771. present_k = concat_k.output[0]
  772. assert past_seq_len == shared_past_seq_len
  773. elif k_nodes_2 is not None:
  774. _, rotary_k, _, reshape_k, matmul_k = k_nodes_2
  775. k_nodes = k_nodes_2
  776. present_k = rotary_k.output[0]
  777. elif k_nodes_3 is not None:
  778. _, concat_k, rotary_k, _, reshape_k, matmul_k = k_nodes_3
  779. k_nodes = k_nodes_3
  780. past_k = concat_k.input[0]
  781. present_k = concat_k.output[0]
  782. elif k_nodes_4 is not None and len(k_nodes_4) == 9:
  783. reshape_k, matmul_k = k_nodes_4[0][-2:]
  784. concat_k, rotary_k = k_nodes_4[0][-5:-3]
  785. k_nodes = k_nodes_4
  786. past_k = concat_k.input[0]
  787. present_k = concat_k.output[0]
  788. elif k_nodes_5 is not None:
  789. _, concat_k, concat_k_half, rotary_k, slice_k, _, reshape_k, _, matmul_k = k_nodes_5
  790. k_nodes = k_nodes_5
  791. past_k = concat_k.input[0]
  792. present_k = concat_k.output[0]
  793. else:
  794. logger.debug("fuse_rotary_attention: failed to match k nodes")
  795. return
  796. # q_nodes_1 is for LLaMA-2 Microsoft
  797. # q_nodes_2 is for LLaMA-2 Hugging Face
  798. # q_nodes_3 is for Phi-2 DirectML
  799. q_nodes = None
  800. slice_q = None
  801. concat_q_half = None
  802. q_nodes_1 = self.model.match_parent_path(
  803. matmul_qk,
  804. ["Reshape", "Transpose", "RotaryEmbedding", "MatMul"],
  805. [0, 0, 0, 0],
  806. )
  807. q_nodes_2 = self.model.match_parent_path(
  808. matmul_qk,
  809. ["RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
  810. [0, 0, 0, 0],
  811. )
  812. q_nodes_3 = self.model.match_parent_path(
  813. matmul_qk,
  814. ["Concat", "RotaryEmbedding", "Slice", "Transpose", "Reshape", "Add", "MatMul"],
  815. [0, 0, 0, 0, 0, 0, 1],
  816. )
  817. if q_nodes_1 is not None:
  818. reshape_q_2, _, rotary_q, matmul_q = q_nodes_1
  819. q_nodes = q_nodes_1
  820. elif q_nodes_2 is not None:
  821. rotary_q, _, reshape_q, matmul_q = q_nodes_2
  822. q_nodes = q_nodes_2
  823. elif q_nodes_3 is not None:
  824. concat_q_half, rotary_q, slice_q, _, reshape_q, _, matmul_q = q_nodes_3
  825. q_nodes = q_nodes_3
  826. else:
  827. logger.debug("fuse_rotary_attention: failed to match q nodes")
  828. return
  829. if matmul_q.input[0] != matmul_k.input[0] and matmul_k.input[0] != matmul_v.input[0]:
  830. logger.debug("fuse_rotary_attention: failed to find the same root_input for q, k, v paths")
  831. return
  832. root_output = ""
  833. if qkv_nodes == qkv_nodes_1:
  834. if not self.check_runtime_shape_paths_for_function(
  835. reshape_qkv_2,
  836. reshape_qkv_1,
  837. reshape_q_2,
  838. reshape_k_2,
  839. reshape_v_2,
  840. reshape_v_1,
  841. add_qk,
  842. matmul_q.input[0],
  843. ):
  844. logger.debug("fuse_rotary_attention: failed to verify runtime shape paths")
  845. return
  846. root_output = reshape_qkv_2.output[0]
  847. elif qkv_nodes in (qkv_nodes_2, qkv_nodes_3):
  848. if not self.check_runtime_shape_paths_for_nodes(
  849. reshape_qkv,
  850. reshape_q,
  851. reshape_k,
  852. reshape_v,
  853. matmul_q.input[0],
  854. ):
  855. logger.debug("fuse_rotary_attention: failed to verify runtime shape paths")
  856. return
  857. root_output = reshape_qkv.output[0]
  858. # Rename inputs of rotary_q/k so it connects with output of matmul_q/k
  859. # Before: MatMul --> Reshape --> Transpose --> RotaryEmbedding
  860. # After: MatMul --> RotaryEmbedding
  861. rotary_q.input[0] = slice_q.output[0] if slice_q else matmul_q.output[0]
  862. rotary_k.input[0] = slice_k.output[0] if slice_k else matmul_k.output[0]
  863. # Rename current output of rotary_k (present_key) so it doesn't match output of MHA (present_key)
  864. if concat_q_half is None:
  865. rotary_k.output[0] = rotary_k.name + "_output_0"
  866. if qkv_nodes == qkv_nodes_3:
  867. qkv_nodes = qkv_nodes[1:]
  868. def create_hidden_size_concat_node(reshape_q):
  869. """Detect num_heads and hidden_size for ONNX model from phi-2
  870. Args:
  871. reshape_q (NodeProto): reshape node for q
  872. Returns:
  873. hidden_size_concat_node(NodeProto): Concat node to be used by reshape
  874. """
  875. concat = self.model.match_parent(reshape_q, "Concat", 1)
  876. if concat is None:
  877. logger.debug("fuse_rotary_attention: failed to trace the concat node from reshape_q")
  878. return None
  879. # The shape is a tensor like [?, ?, num_heads, head_size]
  880. num_head_constant_node = self.model.get_constant_value(concat.input[2])
  881. head_size_constant_node = self.model.get_constant_value(concat.input[3])
  882. if num_head_constant_node is None or head_size_constant_node is None:
  883. logger.debug("fuse_rotary_attention: failed to get constant nodes of num_heads or head_size")
  884. return None
  885. num_head_value = num_head_constant_node[0]
  886. head_size_value = head_size_constant_node[0]
  887. hidden_size = num_head_value * head_size_value
  888. hidden_size_initilizer = self.model.create_node_name("Initializer", name_prefix="hidden_size")
  889. if self.model.get_initializer(hidden_size_initilizer) is None:
  890. self.add_initializer(
  891. name=hidden_size_initilizer,
  892. data_type=TensorProto.INT64,
  893. dims=[1],
  894. vals=[hidden_size],
  895. raw=False,
  896. )
  897. hidden_size_reshape_node_name = self.model.create_node_name("Concat", name_prefix="hidden_size_concat")
  898. hidden_size_concat_node = helper.make_node(
  899. "Concat",
  900. inputs=[
  901. concat.input[0],
  902. concat.input[1],
  903. hidden_size_initilizer,
  904. ],
  905. outputs=[hidden_size_reshape_node_name + "output_0"],
  906. name=hidden_size_reshape_node_name,
  907. )
  908. hidden_size_concat_node.attribute.extend([helper.make_attribute("axis", 0)])
  909. return hidden_size_concat_node
  910. # Add Tranpose and Reshape nodes for patial rotary embedding applied in phi-2 before passing into MHA
  911. if concat_q_half and concat_k_half:
  912. # Transpose the key output of rotary Embedding
  913. k_transpose_node_name = self.model.create_node_name("Transpose")
  914. k_tranpose_output_name = k_transpose_node_name + "_output_0"
  915. k_transpose_node = helper.make_node(
  916. "Transpose",
  917. inputs=[concat_k_half.output[0]],
  918. outputs=[k_tranpose_output_name],
  919. name=k_transpose_node_name,
  920. )
  921. k_transpose_node.attribute.extend([helper.make_attribute("perm", [0, 2, 1, 3])])
  922. # Transpose the query output of rotary Embedding
  923. q_transpose_node_name = self.model.create_node_name("Transpose")
  924. q_tranpose_output_name = q_transpose_node_name + "_output_0"
  925. q_transpose_node = helper.make_node(
  926. "Transpose",
  927. inputs=[concat_q_half.output[0]],
  928. outputs=[q_tranpose_output_name],
  929. name=q_transpose_node_name,
  930. )
  931. q_transpose_node.attribute.extend([helper.make_attribute("perm", [0, 2, 1, 3])])
  932. hidden_size_concat_node = create_hidden_size_concat_node(reshape_k)
  933. if hidden_size_concat_node is None:
  934. logger.debug("fuse_rotary_attention: failed to create hidden_size_concat_node")
  935. return
  936. # Reshape the Rotary Embedding output for key for 4D to 3D
  937. concat_k_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="concat_k_half")
  938. concat_k_reshape_node = helper.make_node(
  939. "Reshape",
  940. inputs=[k_transpose_node.output[0], hidden_size_concat_node.output[0]],
  941. outputs=[concat_k_reshape_node_name + "_output_0"],
  942. name=concat_k_reshape_node_name,
  943. )
  944. # Reshape the Rotary Embedding output for query from 4D to 3D
  945. concat_q_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="concat_q_half")
  946. concat_q_reshape_node = helper.make_node(
  947. "Reshape",
  948. inputs=[q_transpose_node.output[0], hidden_size_concat_node.output[0]],
  949. outputs=[concat_q_reshape_node_name + "_output_0"],
  950. name=concat_q_reshape_node_name,
  951. )
  952. rotary_k = concat_k_reshape_node
  953. rotary_q = concat_q_reshape_node
  954. self.nodes_to_add.append(hidden_size_concat_node)
  955. self.nodes_to_add.append(k_transpose_node)
  956. self.nodes_to_add.append(q_transpose_node)
  957. self.nodes_to_add.append(concat_k_reshape_node)
  958. self.nodes_to_add.append(concat_q_reshape_node)
  959. self.node_name_to_graph_name[hidden_size_concat_node.name] = self.this_graph_name
  960. self.node_name_to_graph_name[k_transpose_node.name] = self.this_graph_name
  961. self.node_name_to_graph_name[q_transpose_node.name] = self.this_graph_name
  962. self.node_name_to_graph_name[concat_k_reshape_node.name] = self.this_graph_name
  963. self.node_name_to_graph_name[concat_q_reshape_node.name] = self.this_graph_name
  964. new_node = self.create_mha_node(
  965. matmul_q.input[0],
  966. root_output,
  967. rotary_q,
  968. rotary_k,
  969. matmul_v,
  970. attn_mask,
  971. add_qk_str,
  972. past_k,
  973. past_v,
  974. present_k,
  975. present_v,
  976. )
  977. if new_node is None:
  978. logger.debug("fuse_rotary_attention: failed to create multi-head attention with rotary embeddings")
  979. return
  980. self.nodes_to_add.append(new_node)
  981. self.node_name_to_graph_name[new_node.name] = self.this_graph_name
  982. self.nodes_to_remove.extend(qkv_nodes[1:])
  983. if v_nodes != v_nodes_4:
  984. self.nodes_to_remove.extend(v_nodes[:-1] if add_v is None else v_nodes[:-2])
  985. else:
  986. nodes_to_keep = [v_nodes[0][-1]]
  987. for temp_path in v_nodes:
  988. self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep)
  989. self.nodes_to_remove.extend(qk_nodes)
  990. if k_nodes == k_nodes_1:
  991. self.nodes_to_remove.extend(k_nodes[:-2])
  992. elif k_nodes == k_nodes_2:
  993. self.nodes_to_remove.append(k_nodes[0])
  994. self.nodes_to_remove.append(k_nodes[2])
  995. self.nodes_to_remove.append(k_nodes[3])
  996. elif k_nodes == k_nodes_3:
  997. self.nodes_to_remove.append(k_nodes[0])
  998. self.nodes_to_remove.append(k_nodes[1])
  999. self.nodes_to_remove.append(k_nodes[3])
  1000. self.nodes_to_remove.append(k_nodes[4])
  1001. elif k_nodes == k_nodes_5:
  1002. self.nodes_to_remove.append(k_nodes[0])
  1003. self.nodes_to_remove.append(k_nodes[1])
  1004. elif k_nodes == k_nodes_4:
  1005. nodes_to_keep = [k_nodes[0][-1], k_nodes[0][-4]]
  1006. for temp_path in k_nodes:
  1007. self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep)
  1008. if q_nodes == q_nodes_1:
  1009. self.nodes_to_remove.extend(q_nodes[:-2])
  1010. elif q_nodes == q_nodes_2:
  1011. self.nodes_to_remove.append(q_nodes[1])
  1012. self.nodes_to_remove.append(q_nodes[2])
  1013. self.prune_graph = True
  1014. class FusionRotaryEmbeddings(Fusion):
  def __init__(self, model: OnnxModel):
      # Search for three op types:
      # - calls to the exported RotaryEmbedding function,
      # - its ".1" variant (presumably a second copy of the exported function; confirm),
      # - the Add that closes an unfused `(x * cos) + (rotate_half(x) * sin)` subgraph (see fuse()).
      fused_op_type = "RotaryEmbedding"
      self.base_name = fused_op_type
      search_op_types = [fused_op_type, fused_op_type + ".1", "Add"]
      super().__init__(model, fused_op_type, search_op_types)
  def reassign_extra_outputs(self, rot_emb_node: NodeProto, function: FunctionProto):
      """Replace constant-valued extra outputs of a RotaryEmbedding function call with initializers.

      The RotaryEmbedding function can have multiple extraneous constant outputs even though the
      function is supposed to produce only one output. This is a byproduct of a potential CSE bug
      when using `export_modules_as_functions` in the TorchScript exporter. As a workaround, every
      input-less `Constant` node in the function body whose output is also a function output gets
      its tensor registered as a model initializer (under a freshly generated name). Every graph
      node that reads the corresponding output of `rot_emb_node` is then rewired to read that
      initializer instead.

      Args:
          rot_emb_node: node in the main graph that calls `function`. Function output i is
              matched with `rot_emb_node.output[i]`.
          function: FunctionProto invoked by `rot_emb_node`.

      Returns:
          list[str]: names of the `rot_emb_node` outputs that were replaced, in `function.node` order.
      """
      function_outputs = list(function.output)

      # Phase 1: pair every qualifying Constant with the call-site output at the same position.
      constants_and_outputs = []
      for fn_node in function.node:
          if fn_node.op_type != "Constant" or len(fn_node.input) != 0:
              continue
          if fn_node.output[0] not in function.output:
              continue
          position = function_outputs.index(fn_node.output[0])
          constants_and_outputs.append((fn_node, rot_emb_node.output[position]))

      # Phase 2: register each constant's tensor as an initializer under a unique name.
      rewires = []
      for constant_node, call_output in constants_and_outputs:
          tensor = constant_node.attribute[0].t
          tensor.name = self.model.create_node_name("Constant")
          self.model.add_initializer(tensor)
          rewires.append((call_output, tensor.name))

      # Phase 3: point every consumer of an extraneous output at its initializer.
      for call_output, initializer_name in rewires:
          consumers = [graph_node for graph_node in self.model.model.graph.node if call_output in graph_node.input]
          for consumer in consumers:
              OnnxModel.replace_node_input(consumer, call_output, initializer_name)

      return [call_output for _, call_output in constants_and_outputs]
  def create_rotary_embeddings_from_function(self, node: NodeProto):
      """Build a com.microsoft RotaryEmbedding node (interleaved=1) replacing a RotaryEmbedding function call.

      Input 0 of `node` must be produced by `MatMul --> Reshape`. The Reshape is bypassed, so the new
      node reads the MatMul output directly. `node.input[1]` is passed through as position ids.
      `node.input[2]` / `node.input[3]` are treated as the cos / sin caches:
      - If both are produced by `Constant` nodes and no `cos_cache`/`sin_cache` initializers exist
        yet, the constants are squeezed and stored as FLOAT initializers under those names. The
        Constant nodes are then queued for removal.
      - Otherwise, tensors already named `cos_cache`/`sin_cache` (initializers or graph inputs)
        are reused.

      Args:
          node: the RotaryEmbedding function call node.

      Returns:
          NodeProto | None: the new node, or None when:
          - the MatMul/Reshape pattern does not match,
          - the function definition cannot be resolved,
          - no usable cos/sin caches exist, or
          - more than one real output remains after extraneous constant outputs are reassigned.
          Every rejection happens before the cache Constants are converted or queued for removal.
      """
      matmul_path = self.model.match_parent_path(node, ["Reshape", "MatMul"], [0, 0])
      if matmul_path is None:
          logger.debug("fuse_rotary_embeddings: failed to match MatMul")
          return None
      reshape_node, matmul_node = matmul_path

      # More than one output means the exporter emitted extraneous constant outputs
      # (see reassign_extra_outputs). Resolve the function body before mutating anything.
      function = None
      if len(node.output) > 1:
          functions = [fn for fn in self.model.model.functions if fn.name == node.op_type]
          if len(functions) != 1:
              logger.debug("fuse_rotary_embeddings: failed to find a unique function definition for %s", node.op_type)
              return None
          function = functions[0]

      cos_cache_name, sin_cache_name = "cos_cache", "sin_cache"
      graph_nodes = self.model.model.graph.node
      # Only Constant producers can be converted: the cache value is read from attribute[0].t below.
      cos_cache_node = [n for n in graph_nodes if n.op_type == "Constant" and n.output and n.output[0] == node.input[2]]
      sin_cache_node = [n for n in graph_nodes if n.op_type == "Constant" and n.output and n.output[0] == node.input[3]]
      convert_caches = (
          len(cos_cache_node) == 1
          and len(sin_cache_node) == 1
          and self.model.get_initializer(cos_cache_name) is None
          and self.model.get_initializer(sin_cache_name) is None
      )
      # Without conversion the new node references `cos_cache`/`sin_cache` by name, so both must already exist.
      if not convert_caches and not all(
          self.model.get_initializer(name) is not None or self.model.find_graph_input(name) is not None
          for name in (cos_cache_name, sin_cache_name)
      ):
          logger.debug("fuse_rotary_embeddings: cos/sin caches are neither convertible Constant nodes nor existing tensors")
          return None

      rotary_emb_outputs = list(node.output)
      if function is not None:
          # Re-assign extraneous constant outputs in RotaryEmbedding functions as initializers.
          # Consumers of those outputs now read initializers holding the same constant tensors,
          # so the graph stays valid even if the check below rejects the fusion.
          extra_outputs = self.reassign_extra_outputs(node, function)
          rotary_emb_outputs = [name for name in rotary_emb_outputs if name not in extra_outputs]
          if len(rotary_emb_outputs) != 1:
              logger.debug(
                  "fuse_rotary_embeddings: expected exactly one non-constant output, found %d", len(rotary_emb_outputs)
              )
              return None

      if convert_caches:
          # Convert cos_cache and sin_cache from Constant node attributes to model initializers
          for cache_name, cache_node in ((cos_cache_name, cos_cache_node[0]), (sin_cache_name, sin_cache_node[0])):
              cache = numpy_helper.to_array(cache_node.attribute[0].t).squeeze()
              cache_tensor = helper.make_tensor(
                  name=cache_name,
                  data_type=TensorProto.FLOAT,
                  dims=list(cache.shape),
                  vals=cache.flatten().tolist(),
              )
              self.model.add_initializer(cache_tensor, self.this_graph_name)
          self.nodes_to_remove.extend([cos_cache_node[0], sin_cache_node[0]])

      rotary_emb_node = helper.make_node(
          self.base_name,
          inputs=[
              matmul_node.output[0],  # x is of shape (B,S,D) instead of (B,S,N,H)
              node.input[1],  # position_ids
              cos_cache_name,
              sin_cache_name,
          ],
          outputs=rotary_emb_outputs,
          name=self.model.create_node_name(self.base_name),
          interleaved=1,
      )
      rotary_emb_node.domain = "com.microsoft"
      self.nodes_to_remove.append(reshape_node)
      return rotary_emb_node
  def create_rotary_embeddings_from_nodes(
      self,
      root_input: str,
      position_ids: str,
      cos_slice: str,
      sin_slice: str,
      output: str,
  ):
      """Build a com.microsoft RotaryEmbedding node (interleaved=0) for an unfused rotate_half subgraph.

      Args:
          root_input: name of the tensor the rotary embedding is applied to.
          position_ids: name of the position ids tensor.
          cos_slice: name of the full cos cache tensor.
          sin_slice: name of the full sin cache tensor.
          output: output name given to the new node.

      Cache handling:
      - If both caches are produced by `Constant` nodes and no `cos_cache`/`sin_cache` initializers
        exist yet, the constants are squeezed and cut to the first half of dimension 1
        ((M, H) -> (M, H/2)). They are then stored as FLOAT initializers under those names, and the
        Constant nodes are queued for removal.
      - Otherwise, tensors already named `cos_cache`/`sin_cache` (initializers or graph inputs)
        are reused.

      Returns:
          NodeProto | None: the new node, or None if no usable cos/sin caches exist.
      """
      cos_cache_name, sin_cache_name = "cos_cache", "sin_cache"
      graph_nodes = self.model.model.graph.node
      # Only Constant producers can be converted: the cache value is read from attribute[0].t below.
      cos_cache_node = [n for n in graph_nodes if n.op_type == "Constant" and n.output and n.output[0] == cos_slice]
      sin_cache_node = [n for n in graph_nodes if n.op_type == "Constant" and n.output and n.output[0] == sin_slice]
      convert_caches = (
          len(cos_cache_node) == 1
          and len(sin_cache_node) == 1
          and self.model.get_initializer(cos_cache_name) is None
          and self.model.get_initializer(sin_cache_name) is None
      )
      # Without conversion the new node references `cos_cache`/`sin_cache` by name, so both must already exist.
      if not convert_caches and not all(
          self.model.get_initializer(name) is not None or self.model.find_graph_input(name) is not None
          for name in (cos_cache_name, sin_cache_name)
      ):
          logger.debug("fuse_rotary_embeddings: cos/sin caches are neither convertible Constant nodes nor existing tensors")
          return None

      if convert_caches:
          # Convert cos_cache and sin_cache from Constant node attributes to model initializers
          cos_cache = numpy_helper.to_array(cos_cache_node[0].attribute[0].t).squeeze()
          sin_cache = numpy_helper.to_array(sin_cache_node[0].attribute[0].t).squeeze()
          # Keep the first half of dimension 1: (M, H) -> (M, H/2). This presumably relies on the
          # two halves of the exported cache being identical (cat((freqs, freqs))); confirm.
          half_head_size = cos_cache.shape[1] // 2
          for cache_name, cache in (
              (cos_cache_name, cos_cache[:, :half_head_size]),
              (sin_cache_name, sin_cache[:, :half_head_size]),
          ):
              cache_tensor = helper.make_tensor(
                  name=cache_name,
                  data_type=TensorProto.FLOAT,
                  dims=list(cache.shape),
                  vals=cache.flatten().tolist(),
              )
              self.model.add_initializer(cache_tensor, self.this_graph_name)
          self.nodes_to_remove.extend([cos_cache_node[0], sin_cache_node[0]])

      rotary_emb_node = helper.make_node(
          self.base_name,
          inputs=[root_input, position_ids, cos_cache_name, sin_cache_name],
          outputs=[output],
          name=self.model.create_node_name(self.base_name),
          interleaved=0,
      )
      rotary_emb_node.domain = "com.microsoft"
      return rotary_emb_node
  def fuse(self, node, input_name_to_nodes, output_name_to_node):
      """Fuse the rotary embedding ending at `node` into a com.microsoft RotaryEmbedding node.

      `node` is one of two things:
      - a call to an exported RotaryEmbedding function (its op_type contains "RotaryEmbedding"), or
      - the final `Add` of an unfused `(x * cos) + (rotate_half(x) * sin)` subgraph.

      On success:
      - the new node is queued in `self.nodes_to_add`,
      - the replaced nodes are queued for removal,
      - the fusion counter is increased, and
      - graph pruning is requested.

      When a check fails, a debug message is logged and no node is added.

      Args:
          node: candidate node of one of the searched op types.
          input_name_to_nodes: not used by this fusion.
          output_name_to_node: not used by this fusion.
      """
      # Node is either RotaryEmbedding function or Add
      if self.base_name not in node.op_type and node.op_type != "Add":
          return

      if node.op_type != "Add":
          rotary_emb_node = self._fuse_rotary_embedding_function(node)
      else:
          rotary_emb_node = self._fuse_rotary_embedding_subgraph(node)
      if rotary_emb_node is None:
          return

      self.increase_counter(self.base_name)
      self.node_name_to_graph_name[rotary_emb_node.name] = self.this_graph_name
      self.nodes_to_add.append(rotary_emb_node)
      self.prune_graph = True

  def _fuse_rotary_embedding_function(self, node):
      """Replace an exported RotaryEmbedding function call `node`; return the new node or None."""
      # Check if node is "RotaryEmbedding nn.Module" exported as a function
      # (e.g. export_modules_as_functions={RotaryEmbedding} in torch.onnx.export)
      # Verify that function has the correct inputs
      if len(node.input) not in {4, 5} or node.input[1] not in {
          "pos",
          "pos_id",
          "position_id",
          "pos_ids",
          "position_ids",
      }:
          logger.debug("fuse_rotary_embeddings: failed to verify inputs for RotaryEmbedding function")
          return None

      rotary_emb_node = self.create_rotary_embeddings_from_function(node)
      if rotary_emb_node is None:
          logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node")
          return None

      # Remove RotaryEmbedding function
      self.nodes_to_remove.append(node)

      # Remove the function output's shape inference stored in value_info; the new shape will be
      # calculated during symbolic shape inference. A missing or duplicated entry is tolerated
      # (logged) instead of aborting the whole optimization.
      output_name = rotary_emb_node.output[0]
      value_info = self.model.model.graph.value_info
      stale_entries = [info for info in value_info if info.name == output_name]
      if len(stale_entries) != 1:
          logger.debug(
              "fuse_rotary_embeddings: expected one value_info entry for %s, found %d", output_name, len(stale_entries)
          )
      for info in stale_entries:
          value_info.remove(info)
      return rotary_emb_node

  def _fuse_rotary_embedding_subgraph(self, node):
      """Match an unfused rotary embedding ending at Add `node`; return the new node or None.

      On success the matched subgraph nodes are queued for removal.
      """
      # Rotary embeddings are defined using the below functions:
      #
      # def rotate_half(x):
      #     """Rotates half the hidden dims of the input."""
      #     x1 = x[..., : x.shape[-1] // 2]
      #     x2 = x[..., x.shape[-1] // 2 :]
      #     return torch.cat((-x2, x1), dim=-1)
      #
      # def apply_rope(x, cos, sin, position_ids):
      #     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
      #     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
      #     cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
      #     sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
      #     x_embed = (x * cos) + (rotate_half(x) * sin)
      #     return x_embed
      match = self.model.match_parent_path

      def match_first(candidates):
          # Return the first (op_types, input_indices) parent path from `node` that matches, else None.
          for op_types, input_indices in candidates:
              path = match(node, op_types, input_indices)
              if path is not None:
                  return path
          return None

      # Check paths for rotate_half(x); x is produced by either a Transpose or a Slice.
      rotate_half_x2_path_1 = match_first(
          [
              (["Mul", "Concat", "Neg", "Slice", "Transpose"], [1, 0, 0, 0, 0]),
              (["Mul", "Concat", "Neg", "Slice", "Slice"], [1, 0, 0, 0, 0]),
          ]
      )
      rotate_half_x2_path_2 = match_first(
          [
              (
                  ["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"],
                  [1, 0, 0, 0, 1, 0, 0, 0, 0],
              ),
              (
                  ["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Slice"],
                  [1, 0, 0, 0, 1, 0, 0, 0, 0],
              ),
          ]
      )
      if rotate_half_x2_path_1 is None or rotate_half_x2_path_2 is None:
          logger.debug("fuse_rotary_embeddings: failed to match x2 in rotate_half")
          return None

      rotate_half_x1_path_1 = match_first(
          [
              (["Mul", "Concat", "Slice", "Transpose"], [1, 0, 1, 0]),
              (["Mul", "Concat", "Slice", "Slice"], [1, 0, 1, 0]),
          ]
      )
      rotate_half_x1_path_2 = match_first(
          [
              (
                  ["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"],
                  [1, 0, 1, 2, 0, 0, 0, 0],
              ),
              (
                  ["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Slice"],
                  [1, 0, 1, 2, 0, 0, 0, 0],
              ),
          ]
      )
      if rotate_half_x1_path_1 is None or rotate_half_x1_path_2 is None:
          logger.debug("fuse_rotary_embeddings: failed to match x1 in rotate_half")
          return None

      if (
          rotate_half_x1_path_1[-1].name != rotate_half_x1_path_2[-1].name
          or rotate_half_x2_path_1[-1].name != rotate_half_x2_path_2[-1].name
          or rotate_half_x1_path_1[-1].name != rotate_half_x2_path_1[-1].name
          or rotate_half_x1_path_2[-1].name != rotate_half_x2_path_2[-1].name
      ):
          logger.debug("fuse_rotary_embeddings: failed to match common input in rotate_half")
          return None

      # Check path for x
      x_path = match_first([(["Mul", "Transpose"], [0, 0]), (["Mul", "Slice"], [0, 0])])
      if x_path is None:
          logger.debug("fuse_rotary_embeddings: failed to match x in rotate_half")
          return None

      # cos/sin branch variants, tried in order. Each entry is:
      #   (op_types, input indices after the Add input index, position of the Slice whose input 0 is the cache)
      # - Variants 1/2 squeeze the cache before the Gather; variants 3/4 do not.
      # - The Slice's input 2 comes from Unsqueeze <- Gather <- Shape (variants 1/3)
      #   or from Unsqueeze <- Add (variants 2/4).
      cache_variants = (
          (
              ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Gather", "Shape"],
              [1, 0, 0, 0, 0, 2, 0, 0],
              5,
          ),
          (["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Add"], [1, 0, 0, 0, 0, 2, 0], 5),
          (["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Gather", "Shape"], [1, 0, 0, 2, 0, 0], 3),
          (["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Add"], [1, 0, 0, 2, 0], 3),
      )

      def match_cache_path(add_input_index, cache_kind):
          # Match the cos (Add input 0) or sin (Add input 1) branch.
          # Returns (variant, path, cache_name, position_ids) or None. position_ids is "" unless the
          # Gather reads them directly (variants 3/4).
          for variant, (op_types, tail_indices, slice_index) in enumerate(cache_variants, start=1):
              path = match(node, op_types, [add_input_index, *tail_indices])
              if path is not None:
                  position_ids = path[2].input[1] if variant in (3, 4) else ""
                  return variant, path, path[slice_index].input[0], position_ids
          logger.debug("fuse_rotary_embeddings: failed to match %s path in apply_rope", cache_kind)
          return None

      # Check path for sin, then path for cos
      sin_match = match_cache_path(1, "sin")
      if sin_match is None:
          return None
      cos_match = match_cache_path(0, "cos")
      if cos_match is None:
          return None
      sin_variant, sin_path, sin_cache, sin_position_ids = sin_match
      cos_variant, cos_path, cos_cache, cos_position_ids = cos_match
      # When both branches read position ids directly, the cos branch's value is used.
      position_ids = cos_position_ids if cos_variant in (3, 4) else sin_position_ids

      # Check path for position ids
      if position_ids == "":
          position_ids_from_sin_path = match(sin_path[2], ["Reshape"], [1])
          position_ids_from_cos_path = match(cos_path[2], ["Reshape"], [1])
          if (
              position_ids_from_sin_path is None
              or position_ids_from_cos_path is None
              or position_ids_from_sin_path[0].name != position_ids_from_cos_path[0].name
          ):
              logger.debug("fuse_rotary_embeddings: failed to match position ids path in apply_rope")
              return None
          position_ids = position_ids_from_cos_path[0].input[0]
      else:
          position_ids_from_sin_path = []
          position_ids_from_cos_path = []

      past_seq_len_path, curr_seq_len_path = None, None
      if sin_variant != cos_variant:
          # The sharing checks below (and the removal of whole sin/cos paths) assume both
          # branches have the same structure, so mixed variants are rejected.
          logger.debug("fuse_rotary_embeddings: failed to match common cache paths")
          return None
      if sin_variant in (1, 3):
          if sin_path[-2].name != cos_path[-2].name or sin_path[-1].name != cos_path[-1].name:
              logger.debug(
                  "fuse_rotary_embeddings: failed to match common Gather node and Shape node in sin cache and cos cache"
              )
              return None
      else:
          if sin_path[-1].name != cos_path[-1].name:
              logger.debug("fuse_rotary_embeddings: failed to match common Add node in sin cache and cos cache")
              return None
          # Match past sequence length path: past_key --> Shape --> Gather --> Add
          past_seq_len_path = match(sin_path[-1], ["Gather", "Shape"], [1, 0])
          # Match current sequence length path: transpose_k --> Shape --> Gather --> Add
          curr_seq_len_path = match(sin_path[-1], ["Gather", "Shape", "Transpose"], [0, 0, 0])
          if (
              past_seq_len_path is None
              or curr_seq_len_path is None
              or self.model.find_graph_input(past_seq_len_path[-1].input[0]) is None
              or curr_seq_len_path[-1].op_type != "Transpose"
          ):
              logger.debug("fuse_rotary_embeddings: failed to match past_seq_len and curr_seq_len paths")
              return None

      rotary_emb_node = self.create_rotary_embeddings_from_nodes(
          rotate_half_x1_path_1[-1].output[0],
          position_ids,
          cos_cache,
          sin_cache,
          node.output[0],
      )
      if rotary_emb_node is None:
          logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node")
          return None

      # Remove rotary embedding nodes
      self.add_nodes_to_remove([node])
      self.add_nodes_to_remove(rotate_half_x1_path_1[:-1])
      self.add_nodes_to_remove(rotate_half_x1_path_2[:-1])
      self.add_nodes_to_remove(rotate_half_x2_path_1[:-1])
      self.add_nodes_to_remove(rotate_half_x2_path_2[:-1])
      self.add_nodes_to_remove(x_path[:-1])
      self.add_nodes_to_remove(sin_path)
      self.add_nodes_to_remove(cos_path)
      self.add_nodes_to_remove(position_ids_from_sin_path[:-1])
      self.add_nodes_to_remove(position_ids_from_cos_path[:-1])

      if past_seq_len_path is not None and len(self.model.get_children(past_seq_len_path[0])) == 1:
          # In merged HF model, output of Gather in past_seq_len_path is used twice
          # for past_key_values.0.key and once for other past_key_values
          self.add_nodes_to_remove(past_seq_len_path)
      if curr_seq_len_path is not None:
          self.add_nodes_to_remove(curr_seq_len_path[:-1])
      return rotary_emb_node