fusion_mha_mmdit.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. import numpy as np
  7. from fusion_base import Fusion
  8. from fusion_utils import FusionUtils
  9. from onnx import NodeProto, TensorProto, helper, numpy_helper
  10. from onnx_model import OnnxModel
  11. logger = getLogger(__name__)
  12. class FusionMultiHeadAttentionMMDit(Fusion):
  13. """
  14. Fuse MultiHeadAttention for Multimodal Diffusion Transformer (MMDiT).
  15. """
  16. def __init__(self, model: OnnxModel):
  17. super().__init__(model, fused_op_type="MultiHeadAttention", search_op_types=["Softmax"])
  18. self.unsqueeze_update_map = {}
  19. def get_num_heads(self, start_node: NodeProto, output_name_to_node, input_index=0) -> int:
  20. """
  21. Detect num_heads from Reshape & Transpose of q/k/v for both Stable Diffusion 3.x and Flux 1.x:
  22. MatMul .. [-1] [24] ..
  23. | | | / /
  24. Add Concat(axis=0)
  25. | /
  26. Reshape
  27. |
  28. Transpose(perm=0,1,3,2)
  29. |
  30. (start_node)
  31. """
  32. nodes = self.model.match_parent_path(
  33. start_node, ["Transpose", "Reshape", "Concat"], [input_index, 0, 1], output_name_to_node=output_name_to_node
  34. )
  35. if nodes is None:
  36. return 0
  37. concat_shape = nodes[-1]
  38. if len(concat_shape.input) != 4:
  39. return 0
  40. value = self.model.get_constant_value(concat_shape.input[2])
  41. if value is None:
  42. return 0
  43. if len(value.shape) != 1:
  44. return 0
  45. return int(value[0])
  46. def get_num_heads_from_k(self, transpose_k: NodeProto, output_name_to_node, concat_before_transpose: bool) -> int:
  47. """
  48. Detect num_heads from subgraph like the following (num_heads=24 in this example):
  49. MatMu .. [-1] [24] ..
  50. | | | / /
  51. Add Concat
  52. | /
  53. Reshape
  54. |
  55. Transpose(perm=0,2,1,3)
  56. |
  57. SimplifiedLayerNormalization
  58. |
  59. Transpose(perm=0,1,3,2)
  60. Another variant is to an extra Concat node to join two symmetrical subgraphs:
  61. | |
  62. MatMul MatMul .. [-1] [24] ..
  63. | | | | / /
  64. Add Concat Add Concat
  65. | / | /
  66. Reshape Reshape
  67. | |
  68. Transpose Transpose(perm=0,2,1,3)
  69. | |
  70. SimplifiedLayerNormalization SimplifiedLayerNormalization
  71. | /
  72. Concat
  73. |
  74. Transpose(perm=0,1,3,2)
  75. Both patterns are used in stable diffusion 3.5 model.
  76. """
  77. if concat_before_transpose:
  78. nodes = self.model.match_parent_path(
  79. transpose_k, ["Concat", "SimplifiedLayerNormalization"], [0, 1], output_name_to_node=output_name_to_node
  80. )
  81. if nodes:
  82. return self.get_num_heads(nodes[1], output_name_to_node)
  83. else:
  84. nodes = self.model.match_parent_path(
  85. transpose_k, ["SimplifiedLayerNormalization"], [0], output_name_to_node=output_name_to_node
  86. )
  87. if nodes:
  88. return self.get_num_heads(nodes[0], output_name_to_node)
  89. return 0
  90. def reshape_to_3d(self, input_name: str, output_name: str) -> str:
  91. """Add a Reshape node to convert 4D BxSxNxH to 3D BxSxD.
  92. Args:
  93. input_name (str): input name for the 4D tensor of shape BxSxNxH.
  94. output_name (str): output name for the 3D tensor of shape BxSxD, where D = N * H.
  95. Returns:
  96. str: the output name
  97. """
  98. new_dims_name = "bsnh_to_bsd_reshape_dims"
  99. new_dims = self.model.get_initializer(new_dims_name)
  100. if new_dims is None:
  101. new_dims = numpy_helper.from_array(np.array([0, 0, -1], dtype="int64"), name=new_dims_name)
  102. self.model.add_initializer(new_dims, self.this_graph_name)
  103. reshape_q = helper.make_node(
  104. "Reshape",
  105. inputs=[input_name, new_dims_name],
  106. outputs=[output_name],
  107. name=self.model.create_node_name("Reshape"),
  108. )
  109. self.nodes_to_add.append(reshape_q)
  110. self.node_name_to_graph_name[reshape_q.name] = self.this_graph_name
  111. return reshape_q.output[0]
  112. def adjust_query_from_bnsh_to_bsd_no_concat(self, mul_q: NodeProto, output_name_to_node) -> str | None:
  113. """
  114. MultiHeadAttenion requires query in BSD format. This function adjusts query from BNSH to BSD format.
  115. Before:
  116. MatMul
  117. |
  118. Add Concat
  119. | /
  120. Reshape
  121. |
  122. Transpose(perm=0,2,1,3)
  123. |
  124. SimplifiedLayerNorm
  125. |
  126. Mul
  127. After:
  128. MatMul
  129. |
  130. Add Concat
  131. | /
  132. Reshape
  133. |
  134. SimplifiedLayerNorm
  135. |
  136. Reshape (shape=[0, 0, -1])
  137. """
  138. path = self.model.match_parent_path(
  139. mul_q,
  140. ["SimplifiedLayerNormalization", "Transpose"],
  141. [0, 0],
  142. )
  143. if path is None:
  144. return None
  145. sln_a, transpose_a = path
  146. if not FusionUtils.check_node_attribute(transpose_a, "perm", [0, 2, 1, 3]):
  147. return None
  148. # Update the graph
  149. sln_a.input[0] = transpose_a.input[0]
  150. sln_output = sln_a.output[0]
  151. sln_a.output[0] = sln_output + "_BSNH"
  152. return self.reshape_to_3d(sln_a.output[0], sln_output + "_BSD")
  153. def adjust_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> str | None:
  154. """
  155. MultiHeadAttenion requires query in BSD format. This function adjusts query from BNSH to BSD format.
  156. Before:
  157. MatMul MatMul
  158. | |
  159. Add Concat Add Concat
  160. | / | /
  161. Reshape Reshape
  162. | |
  163. Transpose(perm=0,2,1,3) Transpose(perm=0,2,1,3)
  164. | |
  165. SimplifiedLayerNorm SimplifiedLayerNorm
  166. | /
  167. Concat(axis=2)
  168. |
  169. Mul
  170. After:
  171. MatMul MatMul
  172. | |
  173. Add Concat Add Concat
  174. | / | /
  175. Reshape Reshape
  176. | |
  177. SimplifiedLayerNorm SimplifiedLayerNorm
  178. | /
  179. Concat(axis=1)
  180. |
  181. Reshape (shape=[0, 0, -1])
  182. """
  183. path = self.model.match_parent_path(
  184. mul_q,
  185. ["Concat", "SimplifiedLayerNormalization", "Transpose"],
  186. [0, 0, 0],
  187. )
  188. if path is None:
  189. return None
  190. concat, sln_a, transpose_a = path
  191. if len(concat.input) != 2:
  192. return None
  193. path = self.model.match_parent_path(
  194. concat,
  195. ["SimplifiedLayerNormalization", "Transpose"],
  196. [1, 0],
  197. )
  198. if path is None:
  199. return None
  200. sln_b, transpose_b = path
  201. if not FusionUtils.check_node_attribute(transpose_a, "perm", [0, 2, 1, 3]):
  202. return None
  203. if not FusionUtils.check_node_attribute(transpose_b, "perm", [0, 2, 1, 3]):
  204. return None
  205. if not FusionUtils.check_node_attribute(concat, "axis", 2):
  206. return None
  207. # Update the graph
  208. sln_a.input[0] = transpose_a.input[0]
  209. sln_b.input[0] = transpose_b.input[0]
  210. new_concat_node = helper.make_node(
  211. "Concat",
  212. inputs=[sln_a.output[0], sln_b.output[0]],
  213. outputs=[concat.output[0] + "_BSNH"],
  214. name=self.model.create_node_name("Concat"),
  215. axis=1,
  216. )
  217. self.nodes_to_add.append(new_concat_node)
  218. self.node_name_to_graph_name[new_concat_node.name] = self.this_graph_name
  219. return self.reshape_to_3d(new_concat_node.output[0], concat.output[0] + "_BSD")
  220. def update_unsqueeze_axes_1_to_2(self, unsqueeze: NodeProto) -> str:
  221. updated_unsqueeze_output = self.unsqueeze_update_map.get(unsqueeze.name)
  222. if updated_unsqueeze_output is None:
  223. if len(unsqueeze.input) == 1:
  224. new_node = helper.make_node(
  225. "Unsqueeze",
  226. inputs=unsqueeze.input,
  227. outputs=[unsqueeze.output[0] + "_BSNH"],
  228. name=self.model.create_node_name("Unsqueeze"),
  229. axes=[2],
  230. )
  231. else:
  232. initializer_name = "unsqueeze_axes_2"
  233. if self.model.get_initializer(initializer_name) is None:
  234. unsqueeze_axes_2 = helper.make_tensor(
  235. name=initializer_name,
  236. data_type=TensorProto.INT64,
  237. dims=[1], # Shape of the tensor
  238. vals=[2], # Tensor values
  239. )
  240. self.model.add_initializer(unsqueeze_axes_2, self.this_graph_name)
  241. new_node = helper.make_node(
  242. "Unsqueeze",
  243. inputs=[unsqueeze.input[0], initializer_name],
  244. outputs=[unsqueeze.output[0] + "_BSNH"],
  245. name=self.model.create_node_name("Unsqueeze"),
  246. )
  247. self.nodes_to_add.append(new_node)
  248. self.node_name_to_graph_name[new_node.name] = self.this_graph_name
  249. updated_unsqueeze_output = new_node.output[0]
  250. self.unsqueeze_update_map[unsqueeze.name] = updated_unsqueeze_output
  251. return updated_unsqueeze_output
  252. def update_unsqueeze_axes(self, add: NodeProto, output_name_to_node: dict[str, NodeProto]) -> bool:
  253. """
  254. Update axes of Unsqueeze from [1] to [2] in the following pattern:
  255. Unsqueeze Unsqueeze
  256. (axes=[0]) (axes=[0])
  257. | |
  258. Unsqueeze Unsqueeze
  259. ... (axes=[1]) ... (axes=[1])
  260. | / | /
  261. Mul Mul
  262. | /
  263. Add
  264. Args:
  265. add (NodeProto): the Add node
  266. output_name_to_node (Dict[str, NodeProto]): mapping from output name to node
  267. Returns:
  268. bool: True if the pattern is matched and updated successfully, False otherwise.
  269. """
  270. if len(add.input) != 2:
  271. return False
  272. # Check axes of Unsqueeze nodes are [0] and [1], and change to [0] and [2] respectively.
  273. nodes_b = self.model.match_parent_path(add, ["Mul", "Unsqueeze", "Unsqueeze"], [1, 1, 0], output_name_to_node)
  274. if nodes_b is None:
  275. return False
  276. fusion_utils = FusionUtils(self.model)
  277. axes_1 = fusion_utils.get_squeeze_or_unsqueeze_axes(nodes_b[1])
  278. if axes_1 is None or axes_1 != [1]:
  279. return False
  280. axes_0 = fusion_utils.get_squeeze_or_unsqueeze_axes(nodes_b[2])
  281. if axes_0 is None or axes_0 != [0]:
  282. return False
  283. # Check axes of Unsqueeze nodes are [0] and [1], and change to [0] and [2] respectively.
  284. nodes_a = self.model.match_parent_path(add, ["Mul", "Unsqueeze", "Unsqueeze"], [0, 1, 0], output_name_to_node)
  285. if nodes_a is None:
  286. return False
  287. axes_1 = fusion_utils.get_squeeze_or_unsqueeze_axes(nodes_a[1])
  288. if axes_1 is None or axes_1 != [1]:
  289. return False
  290. axes_0 = fusion_utils.get_squeeze_or_unsqueeze_axes(nodes_a[2])
  291. if axes_0 is None or axes_0 != [0]:
  292. return False
  293. nodes_a[0].input[1] = self.update_unsqueeze_axes_1_to_2(nodes_a[1])
  294. nodes_b[0].input[1] = self.update_unsqueeze_axes_1_to_2(nodes_b[1])
  295. return True
  296. def adjust_flux_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> str | None:
  297. """
  298. Adjust graph to change query format from BNSH to BSD for Flux model.
  299. Note that the graph pattern is complex, and we only do a shallow match here.
  300. Before:
  301. | |
  302. Transpose(perm=0,2,1,3) Transpose(perm=0,2,1,3)
  303. | |
  304. SimplifiedLayerNorm SimplifiedLayerNorm
  305. | /
  306. Concat(axis=2)
  307. |
  308. Mul Mul
  309. | /
  310. Add
  311. |
  312. Mul
  313. After (Transpose nods are removed, and a Reshape is added):
  314. | |
  315. SimplifiedLayerNorm SimplifiedLayerNorm
  316. | /
  317. Concat(axis=1)
  318. |
  319. Mul Mul
  320. | /
  321. Add
  322. |
  323. Reshape (shape=[0, 0, -1])
  324. """
  325. path = self.model.match_parent_path(
  326. mul_q,
  327. ["Add", "Mul", "Concat", "SimplifiedLayerNormalization", "Transpose"],
  328. [0, 0, 0, 0, 0],
  329. )
  330. if path is None:
  331. return None
  332. add, _mul_a, concat, sln_a, transpose_a = path
  333. if len(concat.input) != 2:
  334. return None
  335. path = self.model.match_parent_path(
  336. concat,
  337. ["SimplifiedLayerNormalization", "Transpose"],
  338. [1, 0],
  339. )
  340. if path is None:
  341. return None
  342. sln_b, transpose_b = path
  343. if not FusionUtils.check_node_attribute(transpose_a, "perm", [0, 2, 1, 3]):
  344. return None
  345. if not FusionUtils.check_node_attribute(transpose_b, "perm", [0, 2, 1, 3]):
  346. return None
  347. if not FusionUtils.check_node_attribute(concat, "axis", 2):
  348. return None
  349. # Need adjust axes of Unsqueeze nodes from [1] to [2] so that the tensors to Mul nodes are BSNH instead of BNSH.
  350. if not self.update_unsqueeze_axes(add, output_name_to_node):
  351. return None
  352. # Update the graph
  353. sln_a.input[0] = transpose_a.input[0]
  354. sln_b.input[0] = transpose_b.input[0]
  355. new_concat_node = helper.make_node(
  356. "Concat",
  357. inputs=[sln_a.output[0], sln_b.output[0]],
  358. outputs=[concat.output[0] + "_BSNH"],
  359. name=self.model.create_node_name("Concat"),
  360. axis=1,
  361. )
  362. self.nodes_to_add.append(new_concat_node)
  363. self.node_name_to_graph_name[new_concat_node.name] = self.this_graph_name
  364. self.model.replace_input_of_all_nodes(concat.output[0], new_concat_node.output[0])
  365. return self.reshape_to_3d(add.output[0], add.output[0] + "_BSD")
  366. def adjust_flux_single_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> str | None:
  367. """
  368. Adjust graph to change query format from BNSH to BSD for Flux model.
  369. Note that the graph pattern is complex, and we only do a shallow match here.
  370. Before:
  371. |
  372. Transpose(perm=0,2,1,3)
  373. |
  374. SimplifiedLayerNorm
  375. |
  376. Mul Mul
  377. | /
  378. Add
  379. |
  380. Mul
  381. After (Transpose is removed, and a Reshape is added):
  382. |
  383. SimplifiedLayerNorm
  384. |
  385. Mul Mul
  386. | /
  387. Add
  388. |
  389. Reshape (shape=[0, 0, -1])
  390. """
  391. path = self.model.match_parent_path(
  392. mul_q,
  393. ["Add", "Mul", "SimplifiedLayerNormalization", "Transpose"],
  394. [0, 0, 0, 0],
  395. )
  396. if path is None:
  397. return None
  398. add, _mul_a, sln_a, transpose_a = path
  399. if not FusionUtils.check_node_attribute(transpose_a, "perm", [0, 2, 1, 3]):
  400. return None
  401. # Need adjust axes of Unsqueeze nodes from [1] to [2] so that the tensors to Mul nodes are BSNH instead of BNSH.
  402. if not self.update_unsqueeze_axes(add, output_name_to_node):
  403. return None
  404. # Update the graph
  405. sln_a.input[0] = transpose_a.input[0]
  406. add.output[0] = add.output[0] + "_BSNH"
  407. return self.reshape_to_3d(add.output[0], add.output[0] + "_BSD")
  408. def transpose_reshape_bnsh_to_bsd(self, q: str, output_name_to_node) -> str | None:
  409. transpose_q = helper.make_node(
  410. "Transpose",
  411. [q],
  412. [q + "_BSNH"],
  413. name=self.model.create_node_name("Transpose", name_prefix="Transpose_BNSH_to_BSNH"),
  414. perm=[0, 2, 1, 3],
  415. )
  416. self.nodes_to_add.append(transpose_q)
  417. self.node_name_to_graph_name[transpose_q.name] = self.this_graph_name
  418. return self.reshape_to_3d(q + "_BSNH", q + "_BSD")
  419. def create_multihead_attention_node(
  420. self,
  421. q: str,
  422. k: str,
  423. v: str,
  424. output: str,
  425. num_heads: int,
  426. ) -> NodeProto:
  427. """
  428. Create a MultiHeadAttention node.
  429. Args:
  430. q (str): name of q
  431. k (str): name of k
  432. v (str): name of v
  433. output (str): output name of MHA
  434. num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
  435. Returns:
  436. NodeProto: the node created.
  437. """
  438. assert num_heads > 0
  439. # Add inputs for MHA: Query, Key, Value (Proj_Bias, Mask, Attention_Bias, Past_K, Past_V are optional)
  440. mha_inputs = [q, k, v]
  441. # Add outputs for MHA (Present_K, Present_V are optional)
  442. mha_outputs = [output]
  443. mha_node = helper.make_node(
  444. "MultiHeadAttention",
  445. inputs=mha_inputs,
  446. outputs=mha_outputs,
  447. name=self.model.create_node_name("MultiHeadAttention"),
  448. )
  449. mha_node.domain = "com.microsoft"
  450. mha_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
  451. # No mask is used in MMDit model, so we need not set the optional mask_filter_value attribute.
  452. return mha_node
  453. def fuse(self, node, input_name_to_nodes, output_name_to_node):
  454. assert node.op_type == "Softmax"
  455. softmax = node
  456. # Softmax output shall not be graph output.
  457. if self.model.find_graph_output(softmax.output[0]):
  458. return
  459. nodes = self.model.match_child_path(
  460. softmax, ["MatMul", "Transpose", "Reshape"], [(0, 0), (0, 0), (0, 0)], input_name_to_nodes
  461. )
  462. if nodes is None:
  463. return
  464. matmul_s_v, transpose_out, reshape_out = nodes
  465. if not FusionUtils.check_node_attribute(transpose_out, "perm", [0, 2, 1, 3]):
  466. return
  467. q_nodes = self.model.match_parent_path(
  468. softmax,
  469. ["MatMul", "Mul", "Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape"],
  470. [0, 0, 1, 0, 1, 0, 0, 0],
  471. )
  472. if q_nodes is None:
  473. return
  474. matmul_qk, mul_q, sqrt_q_2, div_q, sqrt_q, _, _, shape_q = q_nodes
  475. q_bnsh = mul_q.input[0]
  476. if q_bnsh != shape_q.input[0]:
  477. return
  478. k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose"], [1, 0])
  479. if k_nodes is None:
  480. return
  481. mul_k, transpose_k = k_nodes
  482. k = transpose_k.input[0]
  483. if not FusionUtils.check_node_attribute(transpose_k, "perm", [0, 1, 3, 2]):
  484. return
  485. k_scale_nodes = self.model.match_parent_path(mul_k, ["Sqrt", "Div"], [1, 0])
  486. if k_scale_nodes is None:
  487. return
  488. if k_scale_nodes[0].input[0] != sqrt_q_2.input[0]:
  489. return
  490. v = matmul_s_v.input[1]
  491. # Here we sanity check the v path to make sure it is in the expected BNSH format.
  492. concat_v = self.model.match_parent(matmul_s_v, "Concat", input_index=1, output_name_to_node=output_name_to_node)
  493. if concat_v is not None:
  494. # Match v path like:
  495. # -- Transpose (perm=[0,2,1,3]) ----+
  496. # |
  497. # v
  498. # -- Transpose (perm=[0,2,1,3]) -> Concat -> (v)
  499. transpose_1 = self.model.match_parent(
  500. concat_v, "Transpose", input_index=0, output_name_to_node=output_name_to_node
  501. )
  502. if transpose_1 is None:
  503. return
  504. if not FusionUtils.check_node_attribute(transpose_1, "perm", [0, 2, 1, 3]):
  505. return
  506. transpose_2 = self.model.match_parent(
  507. concat_v, "Transpose", input_index=1, output_name_to_node=output_name_to_node
  508. )
  509. if transpose_2 is None:
  510. return
  511. if not FusionUtils.check_node_attribute(transpose_2, "perm", [0, 2, 1, 3]):
  512. return
  513. else:
  514. # Match v path like:
  515. # -- Transpose (perm=[0,2,1,3]) -> (v)
  516. transpose_1 = self.model.match_parent(
  517. matmul_s_v, "Transpose", input_index=1, output_name_to_node=output_name_to_node
  518. )
  519. if transpose_1 is None:
  520. return
  521. if not FusionUtils.check_node_attribute(transpose_1, "perm", [0, 2, 1, 3]):
  522. return
  523. # Match patterns for Flux.
  524. num_heads = (
  525. self.get_num_heads(concat_v, output_name_to_node)
  526. if concat_v
  527. else self.get_num_heads(matmul_s_v, output_name_to_node, input_index=1)
  528. )
  529. if num_heads == 0:
  530. # Match patterns for Stable Diffusion 3.5.
  531. num_heads = self.get_num_heads_from_k(transpose_k, output_name_to_node, concat_v is not None)
  532. if num_heads <= 0:
  533. return
  534. # Q is in BNSH format, we need to adjust it to BSD format due to limitation of MHA op.
  535. # TODO: MHA op support BNSH format to reduce the effort in fusion.
  536. if concat_v is not None:
  537. query = self.adjust_query_from_bnsh_to_bsd(mul_q, output_name_to_node)
  538. else:
  539. query = self.adjust_query_from_bnsh_to_bsd_no_concat(mul_q, output_name_to_node)
  540. if query is None:
  541. query = self.adjust_flux_query_from_bnsh_to_bsd(mul_q, output_name_to_node)
  542. if query is None:
  543. query = self.adjust_flux_single_query_from_bnsh_to_bsd(mul_q, output_name_to_node)
  544. if query is None:
  545. # fallback to use Transpose and Add to adjust query from BNSH to BSD
  546. # This is more general approach.
  547. # However, it might be slower if the extra Transpose node cannot be removed by ORT optimizer.
  548. query = self.transpose_reshape_bnsh_to_bsd(q_bnsh, output_name_to_node)
  549. new_node = self.create_multihead_attention_node(
  550. q=query,
  551. k=k,
  552. v=v,
  553. output=reshape_out.output[0],
  554. num_heads=num_heads,
  555. )
  556. self.nodes_to_add.append(new_node)
  557. self.node_name_to_graph_name[new_node.name] = self.this_graph_name
  558. self.nodes_to_remove.extend([matmul_s_v, transpose_out, reshape_out])
  559. # Use prune graph to remove nodes
  560. self.prune_graph = True