fusion_attention.py 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. import numpy as np
  7. from fusion_base import Fusion
  8. from fusion_options import AttentionMaskFormat
  9. from fusion_utils import FusionUtils, NumpyHelper
  10. from onnx import NodeProto, TensorProto, helper, numpy_helper
  11. from onnx_model import OnnxModel
  12. logger = getLogger(__name__)
  13. class AttentionMask:
  14. """
  15. Fuse Attention subgraph into one Attention node.
  16. """
  17. def __init__(self, model: OnnxModel):
  18. self.model = model
  19. # A lookup table with mask input as key, and mask index output as value
  20. self.mask_indice = {}
  21. # A lookup table with mask input as key, and cast (to int32) output as value
  22. self.mask_casted = {}
  23. self.utils = FusionUtils(model)
  24. self.mask_format = AttentionMaskFormat.MaskIndexEnd
  25. self.opset_version = model.get_opset_version()
  26. def set_mask_format(self, mask_format: AttentionMaskFormat):
  27. self.mask_format = mask_format
  28. def set_mask_indice(self, mask, mask_index):
  29. if mask in self.mask_indice:
  30. assert mask_index == self.mask_indice[mask]
  31. self.mask_indice[mask] = mask_index
  32. def get_first_mask(self):
  33. assert len(self.mask_indice) > 0
  34. return next(iter(self.mask_indice))
  35. def process_mask(self, mask_2d: str) -> str | None:
  36. if self.mask_format == AttentionMaskFormat.NoMask:
  37. return None
  38. if mask_2d in self.mask_indice:
  39. return self.mask_indice[mask_2d]
  40. # Add cast to convert int64 to int32
  41. if self.model.find_graph_input(mask_2d):
  42. casted, input_name = self.utils.cast_graph_input_to_int32(mask_2d)
  43. else:
  44. input_name, _cast_node = self.utils.cast_input_to_int32(mask_2d)
  45. casted = True
  46. if casted:
  47. self.mask_casted[mask_2d] = input_name
  48. # Attention supports int32 attention mask (2D) since 1.4.0
  49. if self.mask_format == AttentionMaskFormat.AttentionMask:
  50. self.mask_indice[mask_2d] = input_name
  51. return input_name
  52. # Add a mask processing node to convert attention mask to mask index (1D)
  53. output_name = self.model.create_node_name("mask_index")
  54. if self.opset_version < 13:
  55. mask_index_node = helper.make_node(
  56. "ReduceSum",
  57. inputs=[input_name],
  58. outputs=[output_name],
  59. name=self.model.create_node_name("ReduceSum", "MaskReduceSum"),
  60. )
  61. mask_index_node.attribute.extend([helper.make_attribute("axes", [1]), helper.make_attribute("keepdims", 0)])
  62. else:
  63. # ReduceSum-13: axes is moved from attribute to input
  64. axes_name = "ort_const_1_reduce_sum_axes"
  65. if self.model.get_initializer(axes_name) is None:
  66. self.model.add_initializer(
  67. helper.make_tensor(
  68. name=axes_name,
  69. data_type=TensorProto.INT64,
  70. dims=[1],
  71. vals=[1],
  72. raw=False,
  73. )
  74. )
  75. mask_index_node = helper.make_node(
  76. "ReduceSum",
  77. inputs=[input_name, axes_name],
  78. outputs=[output_name],
  79. name=self.model.create_node_name("ReduceSum", "MaskReduceSum"),
  80. )
  81. mask_index_node.attribute.extend([helper.make_attribute("keepdims", 0)])
  82. self.model.add_node(mask_index_node)
  83. self.mask_indice[mask_2d] = output_name
  84. return output_name
  85. class FusionAttention(Fusion):
  86. """
  87. Fuse Attention subgraph into one Attention node.
  88. """
  89. def __init__(
  90. self,
  91. model: OnnxModel,
  92. hidden_size: int,
  93. num_heads: int,
  94. attention_mask: AttentionMask | None = None,
  95. use_multi_head_attention: bool = False,
  96. disable_multi_head_attention_bias: bool = False,
  97. search_op_types: list[str] = ["SkipLayerNormalization", "LayerNormalization"], # noqa: B006
  98. ):
  99. attention_op_name = "MultiHeadAttention" if use_multi_head_attention else "Attention"
  100. super().__init__(model, attention_op_name, search_op_types)
  101. self.hidden_size = hidden_size
  102. self.num_heads = num_heads
  103. self.attention_mask = attention_mask if attention_mask else AttentionMask(model)
  104. self.use_multi_head_attention = use_multi_head_attention
  105. self.disable_multi_head_attention_bias = disable_multi_head_attention_bias
  106. self.mask_filter_value = None
  107. # Flags to show warning only once
  108. self.num_heads_warning = True
  109. self.hidden_size_warning = True
  110. self.shape_infer = None
  111. self.shape_infer_done = True
  112. def get_num_heads_and_hidden_size_from_concat(self, concat: NodeProto) -> tuple[int, int]:
  113. """
  114. Detect num_heads and hidden_size from Concat node in the following subgraph:
  115. SkipLayerNormalization or EmbedLayerNormalization
  116. / |
  117. MatMul Shape
  118. | |
  119. Add Gather(indices=0)
  120. | |
  121. | Unsqueeze
  122. | |
  123. | Concat (*, -1, 12, 64)
  124. | /
  125. Reshape
  126. |
  127. Transpose
  128. """
  129. if len(concat.input) == 4:
  130. num_heads = self.model.get_constant_value(concat.input[2])
  131. head_size = self.model.get_constant_value(concat.input[3])
  132. if (
  133. isinstance(num_heads, np.ndarray)
  134. and num_heads.size == 1
  135. and isinstance(head_size, np.ndarray)
  136. and head_size.size == 1
  137. ):
  138. return num_heads[0], num_heads[0] * head_size[0]
  139. return self.num_heads, self.hidden_size
  140. def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> tuple[int, int]:
  141. """Detect num_heads and hidden_size from a reshape node.
  142. Args:
  143. reshape_q (NodeProto): reshape node for Q
  144. Returns:
  145. Tuple[int, int]: num_heads and hidden_size
  146. """
  147. # we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size]
  148. q_shape_value = self.model.get_constant_value(reshape_q.input[1])
  149. if q_shape_value is None:
  150. concat = self.model.get_parent(reshape_q, 1)
  151. if concat is not None and concat.op_type == "Concat":
  152. return self.get_num_heads_and_hidden_size_from_concat(concat)
  153. logger.debug("%s is not initializer.", reshape_q.input[1])
  154. return self.num_heads, self.hidden_size # Fall back to user specified value
  155. if (
  156. (not isinstance(q_shape_value, np.ndarray))
  157. or len(q_shape_value) != 4
  158. or (q_shape_value[2] <= 0 or q_shape_value[3] <= 0)
  159. ):
  160. logger.debug("q_shape_value=%s. Expected value are like [0, 0, num_heads, head_size].", q_shape_value)
  161. return self.num_heads, self.hidden_size # Fall back to user specified value
  162. num_heads = q_shape_value[2]
  163. head_size = q_shape_value[3]
  164. hidden_size = num_heads * head_size
  165. if self.num_heads > 0 and num_heads != self.num_heads:
  166. if self.num_heads_warning:
  167. logger.warning(
  168. "--num_heads is %d. Detected value is %d. Using detected value.", self.num_heads, num_heads
  169. )
  170. self.num_heads_warning = False # Do not show the warning more than once
  171. if self.hidden_size > 0 and hidden_size != self.hidden_size:
  172. if self.hidden_size_warning:
  173. logger.warning(
  174. "--hidden_size is %d. Detected value is %d. Using detected value.", self.hidden_size, hidden_size
  175. )
  176. self.hidden_size_warning = False # Do not show the warning more than once
  177. return num_heads, hidden_size
  178. def get_add_qk_str(self, add_qk: NodeProto):
  179. if not self.shape_infer_done:
  180. self.shape_infer = self.model.infer_runtime_shape(update=True)
  181. self.shape_infer_done = True
  182. if self.shape_infer is None:
  183. return None
  184. input_0_shape = self.shape_infer.get_edge_shape(add_qk.input[0])
  185. input_1_shape = self.shape_infer.get_edge_shape(add_qk.input[1])
  186. if input_0_shape is None or input_1_shape is None:
  187. logger.debug("one of the inputs of %s is None", add_qk)
  188. return None
  189. if input_0_shape != input_1_shape:
  190. logger.debug("the shape of two inputs of %s is not same", add_qk)
  191. return None
  192. return add_qk.input[1]
  193. def reshape_add_qk(self, add_qk: str):
  194. # Convert 4D mask from (B,1,S,T) to (B,N,S,T)
  195. # B = batch size, N = num heads, S = source sequence length, T = target sequence length
  196. mask_output_name = add_qk + "_mask"
  197. # Check if concat node for (B,1,S,T) --> (B,N,S,T) already exists
  198. concat_node = list(filter(lambda node: node.output[0] == mask_output_name, self.nodes_to_add))
  199. if len(concat_node) == 1:
  200. return mask_output_name
  201. assert len(concat_node) == 0
  202. concat_node_name = self.model.create_node_name("Concat")
  203. concat_add_qk_fp32 = helper.make_node(
  204. "Concat",
  205. inputs=[add_qk for _ in range(self.num_heads)],
  206. outputs=[mask_output_name],
  207. name=concat_node_name,
  208. axis=1,
  209. )
  210. # Add new node to graph
  211. self.nodes_to_add.append(concat_add_qk_fp32)
  212. self.node_name_to_graph_name[concat_node_name] = self.this_graph_name
  213. return mask_output_name
  214. def concat_kv(self, past_k: str, past_v: str) -> str:
  215. """Concatenate past_k and past_v inputs to create past_kv input.
  216. Args:
  217. past_k (str): name of past K value
  218. past_v (str): name of past V value
  219. Returns:
  220. kv_output_name (str): name of past KV value
  221. """
  222. # Unsqueeze K and V nodes from (B,N,P,H) to (1,B,N,P,H)
  223. # B = batch size, N = num heads, P = past sequence length, H = head size
  224. unsqueeze_k_name = self.model.create_node_name("Unsqueeze")
  225. unsqueeze_v_name = self.model.create_node_name("Unsqueeze")
  226. k_5d_name = (past_k + "_5d").replace(".", "_")
  227. v_5d_name = (past_v + "_5d").replace(".", "_")
  228. k_5d = helper.make_node(
  229. "Unsqueeze",
  230. inputs=[past_k],
  231. outputs=[k_5d_name],
  232. name=unsqueeze_k_name,
  233. axes=[0],
  234. )
  235. v_5d = helper.make_node(
  236. "Unsqueeze",
  237. inputs=[past_v],
  238. outputs=[v_5d_name],
  239. name=unsqueeze_v_name,
  240. axes=[0],
  241. )
  242. # Add unsqueeze nodes to graph
  243. self.nodes_to_add.append(k_5d)
  244. self.nodes_to_add.append(v_5d)
  245. self.node_name_to_graph_name[unsqueeze_k_name] = self.this_graph_name
  246. self.node_name_to_graph_name[unsqueeze_v_name] = self.this_graph_name
  247. # Concat K and V to get one node of size (2,B,N,P,H)
  248. concat_node_name = self.model.create_node_name("Concat")
  249. kv_output_name = past_v.replace(".value", ".kv").replace(".", "_").replace("_value", "_kv")
  250. concat_kv = helper.make_node(
  251. "Concat",
  252. inputs=[k_5d_name, v_5d_name],
  253. outputs=[kv_output_name],
  254. name=concat_node_name,
  255. axis=0,
  256. )
  257. # Add concat node to graph
  258. self.nodes_to_add.append(concat_kv)
  259. self.node_name_to_graph_name[concat_node_name] = self.this_graph_name
  260. return kv_output_name
  261. def split_kv(self, present_k_name: str, present_v_name: str, kv_node: str):
  262. """Split kv_node containing present KV values into separate present K and present V values.
  263. Args:
  264. present_k_name (str): name of output to store present K value in
  265. present_v_name (str): name of output to store present V value in
  266. kv_node (str): name of present KV values
  267. """
  268. # Split kv_node into present_k and present_v nodes
  269. # Create initializers for indexing kv_node, whose shape is (2,B,N,P,H)
  270. k_index, v_index = "index_0", "index_1"
  271. k_dim = self.model.get_initializer(k_index)
  272. v_dim = self.model.get_initializer(v_index)
  273. if k_dim is None:
  274. k_dim = numpy_helper.from_array(np.array(0, dtype="int64"), name=k_index)
  275. self.model.add_initializer(k_dim, self.this_graph_name)
  276. if v_dim is None:
  277. v_dim = numpy_helper.from_array(np.array(1, dtype="int64"), name=v_index)
  278. self.model.add_initializer(v_dim, self.this_graph_name)
  279. # Create nodes to index kv_node
  280. gather_k_name = self.model.create_node_name("Gather")
  281. gather_v_name = self.model.create_node_name("Gather")
  282. present_k = helper.make_node(
  283. "Gather",
  284. inputs=[kv_node, k_index],
  285. outputs=[present_k_name],
  286. name=gather_k_name,
  287. axis=0,
  288. )
  289. present_v = helper.make_node(
  290. "Gather",
  291. inputs=[kv_node, v_index],
  292. outputs=[present_v_name],
  293. name=gather_v_name,
  294. axis=0,
  295. )
  296. # Add gather nodes to graph
  297. self.nodes_to_add.append(present_k)
  298. self.nodes_to_add.append(present_v)
  299. self.node_name_to_graph_name[gather_k_name] = self.this_graph_name
  300. self.node_name_to_graph_name[gather_v_name] = self.this_graph_name
  301. def create_combined_qkv_bias(
  302. self,
  303. q_add: NodeProto,
  304. k_add: NodeProto | None,
  305. v_add: NodeProto | None,
  306. name_prefix: str,
  307. ) -> NodeProto | None:
  308. q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
  309. qb = NumpyHelper.to_array(q_bias)
  310. kb = np.zeros_like(qb)
  311. vb = np.zeros_like(qb)
  312. if k_add is not None:
  313. k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
  314. kb = NumpyHelper.to_array(k_bias)
  315. if v_add is not None:
  316. v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
  317. vb = NumpyHelper.to_array(v_bias)
  318. qkv_bias = np.stack((qb, kb, vb), axis=0)
  319. qkv_bias_dim = 3 * np.prod(qb.shape)
  320. bias_name = name_prefix + "_qkv_bias"
  321. self.add_initializer(
  322. name=bias_name,
  323. data_type=q_bias.data_type,
  324. dims=[qkv_bias_dim],
  325. vals=qkv_bias,
  326. )
  327. return bias_name
  328. def create_packed_qkv_matmul_node(
  329. self,
  330. q_matmul: NodeProto,
  331. k_matmul: NodeProto,
  332. v_matmul: NodeProto,
  333. q_add: NodeProto,
  334. k_add: NodeProto | None,
  335. v_add: NodeProto | None,
  336. ) -> tuple[NodeProto, NodeProto, NodeProto]:
  337. """Create packed QKV MatMul node before MultiHeadAttention node.
  338. This is for the scenario where an Attention node should be created but cannot be created
  339. because past_key and past_value are separate inputs and not one concatenated input.
  340. Args:
  341. q_matmul (NodeProto): name of MatMul from Q path - (batch_size, sequence_length, hidden_size)
  342. k_matmul (NodeProto): name of MatMul from K path - (batch_size, sequence_length, hidden_size)
  343. v_matmul (NodeProto): name of MatMul from V path - (batch_size, sequence_length, hidden_size)
  344. q_add (NodeProto): name of Add from Q path
  345. k_add (NodeProto): name of Add from K path
  346. v_add (NodeProto): name of Add from V path
  347. Returns:
  348. q_output (NodeProto): Slice node for Q
  349. k_output (NodeProto): Slice node for K
  350. v_output (NodeProto): Slice node for V
  351. """
  352. matmul_node_name = self.model.create_node_name("MatMul")
  353. # Check that input for Q, K, V is the same
  354. assert q_matmul.input[0] == k_matmul.input[0] and k_matmul.input[0] == v_matmul.input[0]
  355. # Created packed QKV weight
  356. q_weight = self.model.get_initializer(q_matmul.input[1])
  357. k_weight = self.model.get_initializer(k_matmul.input[1])
  358. v_weight = self.model.get_initializer(v_matmul.input[1])
  359. qw = NumpyHelper.to_array(q_weight)
  360. kw = NumpyHelper.to_array(k_weight)
  361. vw = NumpyHelper.to_array(v_weight)
  362. assert qw.shape == kw.shape and kw.shape == vw.shape
  363. d = qw.shape[0]
  364. qkv_weight = np.stack((qw, kw, vw), axis=1).reshape((d, 3 * d))
  365. qkv_weight_name = matmul_node_name + "_qkv_weight"
  366. self.add_initializer(
  367. name=qkv_weight_name,
  368. data_type=q_weight.data_type,
  369. dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
  370. vals=qkv_weight,
  371. )
  372. # Created packed QKV MatMul with output (B, S, 3*D)
  373. # Output is of the form:
  374. #
  375. # [[[Q Q ... Q Q K K ... K K V V ... V V]]]
  376. # [Q Q ... Q Q K K ... K K V V ... V V]
  377. # .
  378. # .
  379. # .
  380. # [[Q Q ... Q Q K K ... K K V V ... V V]
  381. # [Q Q ... Q Q K K ... K K V V ... V V]]]
  382. qkv_matmul_output = matmul_node_name + "_qkv_out"
  383. qkv_matmul = helper.make_node(
  384. "MatMul",
  385. inputs=[q_matmul.input[0], qkv_weight_name],
  386. outputs=[qkv_matmul_output],
  387. name=matmul_node_name,
  388. )
  389. self.node_name_to_graph_name[matmul_node_name] = self.this_graph_name
  390. qkv_nodes = [qkv_matmul]
  391. # Create Slice nodes to access Q, K, V
  392. q_slice_name = matmul_node_name + "_q_start_index"
  393. self.add_initializer(name=q_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[0], raw=False)
  394. k_slice_name = matmul_node_name + "_k_start_index"
  395. self.add_initializer(name=k_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[d], raw=False)
  396. v_slice_name = matmul_node_name + "_v_start_index"
  397. self.add_initializer(name=v_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[2 * d], raw=False)
  398. end_of_qkv_name = matmul_node_name + "_end_of_qkv_index"
  399. self.add_initializer(name=end_of_qkv_name, data_type=TensorProto.INT64, dims=[1], vals=[3 * d], raw=False)
  400. qkv_last_axis_name = matmul_node_name + "_qkv_last_axis"
  401. self.add_initializer(name=qkv_last_axis_name, data_type=TensorProto.INT64, dims=[1], vals=[-1], raw=False)
  402. q_slice_output = matmul_node_name + "_q_out"
  403. q_slice = helper.make_node(
  404. "Slice",
  405. inputs=[qkv_matmul_output, q_slice_name, k_slice_name, qkv_last_axis_name],
  406. outputs=[q_slice_output],
  407. name=self.model.create_node_name("Slice"),
  408. )
  409. self.node_name_to_graph_name[q_slice.name] = self.this_graph_name
  410. k_slice_output = matmul_node_name + "_k_out"
  411. k_slice = helper.make_node(
  412. "Slice",
  413. inputs=[qkv_matmul_output, k_slice_name, v_slice_name, qkv_last_axis_name],
  414. outputs=[k_slice_output],
  415. name=self.model.create_node_name("Slice"),
  416. )
  417. self.node_name_to_graph_name[k_slice.name] = self.this_graph_name
  418. v_slice_output = matmul_node_name + "_v_out"
  419. v_slice = helper.make_node(
  420. "Slice",
  421. inputs=[qkv_matmul_output, v_slice_name, end_of_qkv_name, qkv_last_axis_name],
  422. outputs=[v_slice_output],
  423. name=self.model.create_node_name("Slice"),
  424. )
  425. self.node_name_to_graph_name[v_slice.name] = self.this_graph_name
  426. q_output = q_slice
  427. k_output = k_slice
  428. v_output = v_slice
  429. qkv_nodes.extend([q_slice, k_slice, v_slice])
  430. if self.disable_multi_head_attention_bias:
  431. if q_add is not None:
  432. initializer_input = 1 if self.model.get_initializer(q_add.input[1]) else 0
  433. if np.any(NumpyHelper.to_array(self.model.get_initializer(q_add.input[initializer_input]))):
  434. q_add.input[1 - initializer_input] = q_slice_output
  435. q_output = q_add
  436. qkv_nodes.append(q_add)
  437. self.node_name_to_graph_name[q_add.name] = self.this_graph_name
  438. if k_add is not None:
  439. initializer_input = 1 if self.model.get_initializer(k_add.input[1]) else 0
  440. if np.any(NumpyHelper.to_array(self.model.get_initializer(k_add.input[initializer_input]))):
  441. k_add.input[1 - initializer_input] = k_slice_output
  442. k_output = k_add
  443. qkv_nodes.append(k_add)
  444. self.node_name_to_graph_name[k_add.name] = self.this_graph_name
  445. if v_add is not None:
  446. initializer_input = 1 if self.model.get_initializer(v_add.input[1]) else 0
  447. if np.any(NumpyHelper.to_array(self.model.get_initializer(v_add.input[initializer_input]))):
  448. v_add.input[1 - initializer_input] = v_slice_output
  449. v_output = v_add
  450. qkv_nodes.append(v_add)
  451. self.node_name_to_graph_name[v_add.name] = self.this_graph_name
  452. # Add nodes to graph
  453. self.nodes_to_add.extend(qkv_nodes)
  454. return q_output, k_output, v_output
  455. # This function is used in child classes for bart or conformer model.
  456. def create_multihead_attention_node(
  457. self,
  458. q_matmul: NodeProto,
  459. k_matmul: NodeProto | str | None,
  460. v_matmul: NodeProto | str | None,
  461. q_add: NodeProto,
  462. k_add: NodeProto | None,
  463. v_add: NodeProto | None,
  464. num_heads: int,
  465. hidden_size: int,
  466. output: str,
  467. key_padding_mask: str = "",
  468. add_qk: str = "",
  469. unidirectional: bool = False,
  470. past_k: str = "",
  471. past_v: str = "",
  472. present_k: str = "",
  473. present_v: str = "",
  474. packed_qkv: bool = False,
  475. ) -> NodeProto | None:
  476. """Create a MultiHeadAttention node.
  477. Args:
  478. q_matmul (NodeProto): name of MatMul from Q path - (batch_size, sequence_length, hidden_size)
  479. k_matmul (NodeProto): name of MatMul from K path - (batch_size, sequence_length, hidden_size) or (batch_size, num_heads, past_sequence_length, head_size)
  480. v_matmul (NodeProto): name of MatMul from V path - (batch_size, sequence_length, hidden_size) or (batch_size, num_heads, past_sequence_length, head_size)
  481. q_add (NodeProto): name of Add from Q path
  482. k_add (NodeProto): name of Add from K path
  483. v_add (NodeProto): name of Add from V path
  484. num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
  485. hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
  486. output (str): output name of MHA
  487. key_padding_mask (str): name of key padding mask
  488. add_qk (str): name of add after Q x K'
  489. unidirectional (bool): whether to apply causal attention mask automatically or not
  490. past_k (str): name of past K value - (batch_size, num_heads, past_sequence_length, head_size)
  491. past_v (str): name of past V value - (batch_size, num_heads, past_sequence_length, head_size)
  492. present_k (str): name of present K value - (batch_size, num_heads, sequence_length, head_size)
  493. present_v (str): name of present V value - (batch_size, num_heads, sequence_length, head_size)
  494. packed_qkv (bool): whether to combine MatMuls from Q, K, V paths
  495. Note: This is for the scenario where an Attention node should be created but cannot be created
  496. because past_key and past_value are separate inputs and not one concatenated input.
  497. Returns:
  498. Union[NodeProto, None]: the node created or None if failed.
  499. """
  500. # B = batch size, N = num heads, P = past seq len, H = head size
  501. assert num_heads > 0
  502. if hidden_size > 0 and (hidden_size % num_heads) != 0:
  503. logger.debug("input hidden size %d is not a multiple of num of heads %d", hidden_size, num_heads)
  504. return None
  505. graph_input_names = {node.name for node in self.model.graph().input}
  506. mha_node_name = self.model.create_node_name("Attention")
  507. # Add initial Q/K/V inputs for MHA
  508. mha_inputs = []
  509. if packed_qkv:
  510. q_slice, k_slice, v_slice = self.create_packed_qkv_matmul_node(
  511. q_matmul,
  512. k_matmul,
  513. v_matmul,
  514. q_add,
  515. k_add,
  516. v_add,
  517. )
  518. mha_inputs.extend([q_slice.output[0], k_slice.output[0], v_slice.output[0]])
  519. elif isinstance(k_matmul, NodeProto) and isinstance(v_matmul, NodeProto):
  520. if self.disable_multi_head_attention_bias:
  521. mha_inputs.extend([q_add.output[0], k_matmul.output[0], v_add.output[0]])
  522. else:
  523. mha_inputs.extend([q_matmul.output[0], k_matmul.output[0], v_matmul.output[0]])
  524. elif (
  525. isinstance(k_matmul, str)
  526. and isinstance(v_matmul, str)
  527. and k_matmul in graph_input_names
  528. and v_matmul in graph_input_names
  529. ):
  530. if self.disable_multi_head_attention_bias:
  531. mha_inputs.extend([q_add.output[0], k_matmul, v_matmul])
  532. else:
  533. mha_inputs.extend([q_matmul.output[0], k_matmul, v_matmul])
  534. else:
  535. return None
  536. # Add bias to inputs for MHA
  537. # Bias for cross attention is not fully supported in DMMHA and cpu MHA kernels since they assume
  538. # bias has been added to key and value when they are in BNSH format, so only bias for query is used.
  539. # Need add checks if we found such assumption is not true.
  540. if not self.disable_multi_head_attention_bias:
  541. bias_name = self.create_combined_qkv_bias(q_add, k_add, v_add, mha_node_name)
  542. mha_inputs.append(bias_name)
  543. else:
  544. mha_inputs.append("")
  545. # Add optional inputs for MHA
  546. if past_k and past_v:
  547. mha_inputs.extend([key_padding_mask, add_qk, past_k, past_v])
  548. elif key_padding_mask or add_qk:
  549. mha_inputs.extend([key_padding_mask, add_qk])
  550. # Add outputs for MHA
  551. mha_outputs = [output]
  552. if present_k and present_v:
  553. mha_outputs.extend([present_k, present_v])
  554. mha_node = helper.make_node(
  555. "MultiHeadAttention",
  556. inputs=mha_inputs,
  557. outputs=mha_outputs,
  558. name=mha_node_name,
  559. )
  560. mha_node.domain = "com.microsoft"
  561. mha_node.attribute.append(helper.make_attribute("num_heads", num_heads))
  562. if unidirectional:
  563. mha_node.attribute.append(helper.make_attribute("unidirectional", int(unidirectional)))
  564. self.increase_counter("MultiHeadAttention")
  565. return mha_node
  566. def create_attention_node(
  567. self,
  568. mask_index: str | None,
  569. q_matmul: NodeProto,
  570. k_matmul: NodeProto,
  571. v_matmul: NodeProto,
  572. q_add: NodeProto,
  573. k_add: NodeProto,
  574. v_add: NodeProto,
  575. num_heads: int,
  576. hidden_size: int,
  577. first_input: str,
  578. output: str,
  579. add_qk_str: str = "",
  580. causal: bool = False,
  581. past_k: str = "",
  582. past_v: str = "",
  583. present_k: str = "",
  584. present_v: str = "",
  585. scale: float | None = None,
  586. ) -> NodeProto | None:
  587. """Create an Attention node.
  588. Args:
  589. mask_index (str | None): mask input
  590. q_matmul (NodeProto): MatMul node in fully connection for Q
  591. k_matmul (NodeProto): MatMul node in fully connection for K
  592. v_matmul (NodeProto): MatMul node in fully connection for V
  593. q_add (NodeProto): Add bias node in fully connection for Q
  594. k_add (NodeProto): Add bias node in fully connection for K
  595. v_add (NodeProto): Add bias node in fully connection for V
  596. num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
  597. hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
  598. first_input (str): first input name
  599. output (str): output name
  600. add_qk_str (str): name of Add node after Q x K'
  601. causal: whether it is uni-directional mask.
  602. past_k (str): name of input for past K value
  603. past_v (str): name of input for past V value
  604. present_k (str): name of output to store present K value
  605. present_v (str): name of output to store present V value
  606. scale: scale before softmax
  607. Returns:
  608. Union[NodeProto, None]: the node created or None if failed.
  609. """
  610. assert num_heads > 0
  611. if hidden_size > 0 and (hidden_size % num_heads) != 0:
  612. logger.debug("input hidden size %d is not a multiple of num of heads %d", hidden_size, num_heads)
  613. return None
  614. has_bias = True
  615. if q_add is None and k_add is None and v_add is None:
  616. has_bias = False
  617. q_weight = self.model.get_initializer(q_matmul.input[1])
  618. k_weight = self.model.get_initializer(k_matmul.input[1])
  619. v_weight = self.model.get_initializer(v_matmul.input[1])
  620. q_bias, k_bias, v_bias = None, None, None
  621. if has_bias:
  622. q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
  623. k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
  624. v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
  625. if not (k_weight and v_weight and q_bias and k_bias):
  626. return None
  627. if q_weight is None:
  628. print(
  629. f"{q_matmul.input[1]} is not an initializer. "
  630. "Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion"
  631. )
  632. return None
  633. qw = NumpyHelper.to_array(q_weight)
  634. kw = NumpyHelper.to_array(k_weight)
  635. vw = NumpyHelper.to_array(v_weight)
  636. # assert q and k have same shape as expected
  637. assert qw.shape == kw.shape
  638. qw_in_size = qw.shape[0]
  639. kw_in_size = kw.shape[0]
  640. vw_in_size = vw.shape[0]
  641. assert qw_in_size == kw_in_size == vw_in_size
  642. if hidden_size > 0 and hidden_size != qw_in_size:
  643. logger.warning(
  644. "Input hidden size (%d) is not same as weight matrix dimension of q,k,v (%d). "
  645. "Please provide a correct input hidden size or pass in 0",
  646. hidden_size,
  647. qw_in_size,
  648. )
  649. is_qkv_diff_dims = False
  650. if qw.shape != vw.shape:
  651. is_qkv_diff_dims = True
  652. # All the matrices can have the same shape or q, k matrices can have the same shape with v being different
  653. # For 2d weights, the shapes would be [in_size, out_size].
  654. # For 3d weights, shape would be [in_size, a, b] where a*b = out_size
  655. qw_out_size = np.prod(qw.shape[1:])
  656. kw_out_size = np.prod(kw.shape[1:])
  657. vw_out_size = np.prod(vw.shape[1:])
  658. qkv_weight_dim = 0
  659. if is_qkv_diff_dims:
  660. qkv_weight = np.concatenate((qw, kw, vw), axis=1)
  661. qkv_weight_dim = qw_out_size + kw_out_size + vw_out_size
  662. else:
  663. qkv_weight = np.stack((qw, kw, vw), axis=1)
  664. qkv_weight_dim = 3 * qw_out_size
  665. qkv_bias_dim = 0
  666. qkv_bias: np.ndarray | None = None
  667. if has_bias:
  668. qb = NumpyHelper.to_array(q_bias)
  669. kb = NumpyHelper.to_array(k_bias)
  670. vb = NumpyHelper.to_array(v_bias)
  671. q_bias_shape = np.prod(qb.shape)
  672. k_bias_shape = np.prod(kb.shape)
  673. v_bias_shape = np.prod(vb.shape)
  674. assert q_bias_shape == k_bias_shape == qw_out_size
  675. assert v_bias_shape == vw_out_size
  676. if is_qkv_diff_dims:
  677. qkv_bias = np.concatenate((qb, kb, vb), axis=0)
  678. qkv_bias_dim = q_bias_shape + k_bias_shape + v_bias_shape
  679. else:
  680. qkv_bias = np.stack((qb, kb, vb), axis=0)
  681. qkv_bias_dim = 3 * q_bias_shape
  682. attention_node_name = self.model.create_node_name("Attention")
  683. if not self.use_multi_head_attention:
  684. self.add_initializer(
  685. name=attention_node_name + "_qkv_weight",
  686. data_type=q_weight.data_type,
  687. dims=[qw_in_size, int(qkv_weight_dim)],
  688. vals=qkv_weight,
  689. )
  690. if has_bias:
  691. self.add_initializer(
  692. name=attention_node_name + "_qkv_bias",
  693. data_type=q_bias.data_type,
  694. dims=[int(qkv_bias_dim)],
  695. vals=qkv_bias,
  696. )
  697. # For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights.
  698. if self.use_multi_head_attention:
  699. if add_qk_str:
  700. logger.debug("MultiHeadAttention does not support relative_position_bias: cannot fuse the attention.")
  701. return None
  702. attention_inputs = [
  703. q_matmul.output[0],
  704. k_matmul.output[0],
  705. v_matmul.output[0],
  706. attention_node_name + "_qkv_bias",
  707. ]
  708. if mask_index is not None:
  709. attention_inputs.append(mask_index)
  710. attention_node = helper.make_node(
  711. "MultiHeadAttention",
  712. inputs=attention_inputs,
  713. outputs=[output],
  714. name=attention_node_name,
  715. )
  716. self.increase_counter("MultiHeadAttention")
  717. else:
  718. attention_inputs = [
  719. first_input,
  720. attention_node_name + "_qkv_weight",
  721. attention_node_name + "_qkv_bias" if has_bias else "",
  722. ]
  723. if mask_index is not None:
  724. attention_inputs.append(mask_index)
  725. else:
  726. attention_inputs.append("")
  727. past_exists = past_k and past_v
  728. if past_exists:
  729. past_kv = self.concat_kv(past_k, past_v)
  730. attention_inputs.append(past_kv)
  731. if add_qk_str:
  732. # Add additional add to attention node (input name = attention_bias)
  733. if not past_exists:
  734. attention_inputs.append("")
  735. attention_inputs.append(add_qk_str)
  736. attention_outputs = [output]
  737. if present_k and present_v:
  738. present_kv = present_k.replace(".key", "").replace("_key", "").replace(".", "_")
  739. attention_outputs.append(present_kv)
  740. self.split_kv(present_k, present_v, present_kv)
  741. attention_node = helper.make_node(
  742. "Attention",
  743. inputs=attention_inputs,
  744. outputs=attention_outputs,
  745. name=attention_node_name,
  746. )
  747. self.increase_counter("Attention")
  748. attention_node.domain = "com.microsoft"
  749. attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
  750. if causal:
  751. attention_node.attribute.extend([helper.make_attribute("unidirectional", 1)])
  752. if scale is not None:
  753. attention_node.attribute.extend([helper.make_attribute("scale", scale)])
  754. if is_qkv_diff_dims:
  755. attention_node.attribute.extend(
  756. [helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])]
  757. )
  758. if self.mask_filter_value is not None:
  759. attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
  760. return attention_node
  def fuse(self, node, input_name_to_nodes, output_name_to_node):
      """Try to replace the attention subgraph that ends at ``node`` with one fused attention node.

      Starting from ``node`` the method walks upward with ``self.model.match_parent_path`` to
      locate the output projection, the V, Q*K', Q and K branches and (unless the mask-free
      pattern matched) the mask branch. As soon as any expected pattern is missing it returns
      without recording any fusion.

      When every branch matches and the Q, K and V MatMul nodes all read the same root input:
        * the node built by ``self.create_attention_node`` is appended to ``self.nodes_to_add``
          and registered in ``self.node_name_to_graph_name``;
        * the matched nodes are queued in ``self.nodes_to_remove``;
        * for the Einsum (Albert) variant a Reshape node is added to ``self.model`` and input 0
          of the Einsum is rewired to its output;
        * ``self.prune_graph`` is set to True.
      ``self.mask_filter_value`` may additionally be updated when the matched mask constant is
      a negative scalar other than -10000.

      Args:
          node: normalization node closing the attention block. For op type
              "LayerNormalization" the "Add" parent at input 0 is used as the starting point;
              any other node is used as the starting point directly.
          input_name_to_nodes: dict mapping a tensor name to the list of nodes consuming it.
          output_name_to_node: dict mapping a tensor name to the node producing it.

      Returns:
          None. All results are recorded on ``self``.
      """
      # Sometimes we can not fuse SkipLayerNormalization since the Add before LayerNormalization has an
      # output that is used by nodes outside of the SkipLayerNormalization pattern.
      # Conceptually we treat the Add before LayerNormalization as the SkipLayerNormalization node since
      # they share the same pattern.
      normalize_node = node
      start_node = normalize_node
      if normalize_node.op_type == "LayerNormalization":
          add_before_layernorm = self.model.match_parent(normalize_node, "Add", 0)
          if add_before_layernorm is not None:
              start_node = add_before_layernorm
          else:
              return

      # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
      qkv_nodes = self.model.match_parent_path(
          start_node,
          ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
          [None, None, 0, 0, 0],
      )
      einsum_node = None
      if qkv_nodes is not None:
          (_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
      else:
          # Match Albert
          qkv_nodes = self.model.match_parent_path(
              start_node, ["Add", "Einsum", "Transpose", "MatMul"], [1, None, 0, 0]
          )
          if qkv_nodes is not None:
              (_, einsum_node, transpose_qkv, matmul_qkv) = qkv_nodes
          else:
              return

      # Collect the inputs of the start node that are produced by some node and are not the
      # attention branch itself; exactly one such input (the residual/root input) is expected.
      other_inputs = []
      for node_input in start_node.input:
          if node_input not in output_name_to_node:
              continue

          if node_input == qkv_nodes[0].output[0]:
              continue
          other_inputs.append(node_input)
      if len(other_inputs) != 1:
          return

      root_input = other_inputs[0]

      # Match flaubert                     Mask
      #                                     |
      # Mul --> LayerNormalization -->  Attention --> MatMul --> Add
      #  |                                                        |
      #  |                                                        |
      #  +---------------------------------------------------------
      mul_before_layernorm = self.model.match_parent(start_node, "Mul", 0)
      if mul_before_layernorm is not None:
          # start_node consumes this tensor, so the lookup below always has an entry.
          mul_children = input_name_to_nodes[mul_before_layernorm.output[0]]
          if mul_children is not None and len(mul_children) == 2:
              layernorm_node = mul_children[1]
              if layernorm_node.op_type == "LayerNormalization":
                  root_input = layernorm_node.output[0]
              else:
                  return
          elif mul_children is not None and len(mul_children) == 5:
              root_input = mul_before_layernorm.output[0]
          else:
              return
      elif normalize_node.op_type == "LayerNormalization":
          children = input_name_to_nodes[root_input]
          for child in children:
              if child.op_type == "LayerNormalization":
                  root_input = child.output[0]

      # When Add before the LayerNormalization produces an output
      # that is consumed by some other nodes other than the LayerNormalization itself,
      # fused SkipLayerNormalization will have several outputs.
      # In this case we need to pick the one used in Attention
      # For example, this is the case for ViT
      # SkipLayerNormalization --> Attention --> MatMul --> Add --> SkipLayerNormalization
      #  |                                                                     |
      #  |                                                                     |
      #  +---------------------------------------------------------------------+
      parent_node = output_name_to_node[root_input]
      if parent_node.op_type == "SkipLayerNormalization" and len(parent_node.output) == 4:
          root_input = parent_node.output[0]

      # root_input may just have been switched to an output that nothing consumes; treat that
      # as "pattern not matched" (no MatMul children) instead of raising KeyError.
      children = input_name_to_nodes.get(root_input, [])
      children_types = [child.op_type for child in children]
      if children_types.count("MatMul") != 3:
          return

      v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None])
      if v_nodes is None:
          logger.debug("fuse_attention: failed to match v path")
          return
      (_, _, add_v, matmul_v) = v_nodes

      is_distill = False
      is_distill_add = False
      is_no_mask_attention = False
      is_sdpa = False
      # Candidate Q*K' patterns, tried in insertion order; the first match wins.
      qk_paths = {
          "path1": (["Softmax", "Add", "Div", "MatMul"], [0, 0, None, 0]),
          "path2": (["Softmax", "Add", "Mul", "MatMul"], [0, 0, None, 0]),
          "path3": (["Softmax", "Where", "MatMul", "Div"], [0, 0, 2, 0]),
          "path4": (["Softmax", "Add", "Where", "MatMul"], [0, 0, 0, 2]),
          "path5": (["Softmax", "Div", "MatMul"], [0, 0, 0]),
          "sdpa": (["Softmax", "Add", "MatMul", "Mul", "Sqrt"], [0, 0, None, 0, 1]),
      }

      qk_nodes = None
      for path_name, (op_types, input_indices) in qk_paths.items():
          qk_nodes = self.model.match_parent_path(matmul_qkv, op_types, input_indices)
          if qk_nodes is None:
              continue
          if path_name == "path3":
              is_distill = True
          elif path_name == "path4":
              is_distill_add = True
          elif path_name == "path5":
              is_no_mask_attention = True
          elif path_name == "sdpa":
              is_sdpa = True
          break

      if qk_nodes is None:
          logger.debug("fuse_attention: failed to match qk path")
          return

      add_qk = None
      matmul_qk = None
      where_qk = None
      after_q = None
      if is_distill:
          (_, where_qk, matmul_qk, _) = qk_nodes
      elif is_distill_add:
          (_, add_qk, where_qk, matmul_qk) = qk_nodes
      elif is_no_mask_attention:
          (_, _, matmul_qk) = qk_nodes
      elif is_sdpa:
          (_, add_qk, matmul_qk, after_q, _) = qk_nodes
      else:
          (_, add_qk, _, matmul_qk) = qk_nodes

      # Only the sdpa pattern has an extra node between the Q branch and the Q*K' MatMul.
      after_q = after_q or matmul_qk
      q_nodes = self.model.match_parent_path(after_q, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, None])
      if q_nodes is None:
          q_nodes = self.model.match_parent_path(
              after_q,
              ["Div", "Transpose", "Reshape", "Add", "MatMul"],
              [0, 0, 0, 0, None],
          )
          if q_nodes is None:
              logger.debug("fuse_attention: failed to match q path")
              return
      # Index from the end because the two accepted Q patterns differ in length.
      reshape_q = q_nodes[-3]
      add_q = q_nodes[-2]
      matmul_q = q_nodes[-1]

      after_k = matmul_qk
      if is_sdpa:
          # Input 1 of the Q*K' MatMul is the K side of the sdpa pattern.
          mul_k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Sqrt"], [1, None])
          if mul_k_nodes is None:
              logger.debug("fuse_attention: failed to match mul sqrt k path")
              return
          (after_k, _) = mul_k_nodes

      k_nodes = self.model.match_parent_path(
          after_k, ["Transpose", "Reshape", "Add", "MatMul"], [0 if is_sdpa else 1, 0, 0, None]
      )
      if k_nodes is None:
          k_nodes = self.model.match_parent_path(
              matmul_qk,
              ["Transpose", "Transpose", "Reshape", "Add", "MatMul"],
              [1, 0, 0, 0, None],
          )
          if k_nodes is None:
              logger.debug("fuse_attention: failed to match k path")
              return
      add_k = k_nodes[-2]
      matmul_k = k_nodes[-1]

      # Note that Cast might be removed by OnnxRuntime so we match two patterns here.
      mask_nodes = None
      add_qk_str = ""
      if is_distill:
          _, mask_nodes, _ = self.model.match_parent_paths(
              where_qk,
              [
                  (["Expand", "Reshape", "Equal"], [0, 0, 0]),
                  (["Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0]),
                  (["Cast", "Expand", "Reshape", "Equal"], [0, 0, 0, 0]),
              ],
              output_name_to_node,
          )
      elif is_distill_add:
          _, mask_nodes, _ = self.model.match_parent_paths(
              where_qk,
              [
                  (["Cast", "Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0, 0]),
                  (["Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0]),
              ],
              output_name_to_node,
          )
          if add_qk is not None:
              add_qk_str = self.get_add_qk_str(add_qk)
              if add_qk_str is None:
                  logger.debug("fuse_attention: failed to verify shape inference of %s", add_qk)
                  return
      elif is_no_mask_attention:
          pass
      else:
          _, mask_nodes, _ = self.model.match_parent_paths(
              add_qk,
              [
                  (["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0, 0]),
                  (["Mul", "Sub", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0]),
                  # The following two patterns are for SDPA.
                  (["Where", "Cast", "Sub", "Expand", "Unsqueeze", "Unsqueeze"], [None, 0, 0, 1, 0, 0]),
                  (["Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], [None, 0, 0, 1, 0, 0, 0]),
              ],
              output_name_to_node,
          )
      if not is_no_mask_attention and mask_nodes is None:
          logger.debug("fuse_attention: failed to match mask path")
          return

      if not is_no_mask_attention and len(mask_nodes) > 1:
          _, mul_val = self.model.get_constant_input(mask_nodes[0])
          # The mask value shall be a float scalar (usually is the lowest float value).
          if (
              (mul_val is None)
              or not (isinstance(mul_val, np.ndarray) and mul_val.size == 1)
              or (mul_val.item() >= 0)
          ):
              return
          if mul_val.item() != -10000:
              self.mask_filter_value = mul_val.item()

      if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input:
          mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) if not is_no_mask_attention else None

          # reshape_qkv is only bound for the non-Einsum pattern; the conditional expression
          # never evaluates it when the Einsum pattern matched.
          attention_last_node = reshape_qkv if einsum_node is None else transpose_qkv

          q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
          if q_num_heads <= 0 or q_hidden_size <= 0:
              logger.warning(
                  "Failed to detect num_heads and hidden_size for Attention fusion. "
                  "Please specify those parameters in argument."
              )
              return

          # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads
          # the input_hidden_size represents the input hidden size, this is used as needed but hidden sizes
          # for Q, K are extracted appropriately
          new_node = self.create_attention_node(
              mask_index=mask_index,
              q_matmul=matmul_q,
              k_matmul=matmul_k,
              v_matmul=matmul_v,
              q_add=add_q,
              k_add=add_k,
              v_add=add_v,
              num_heads=q_num_heads,
              hidden_size=q_hidden_size,
              first_input=root_input,
              output=attention_last_node.output[0],
              add_qk_str=add_qk_str,
          )
          if new_node is None:
              return

          self.nodes_to_add.append(new_node)
          self.node_name_to_graph_name[new_node.name] = self.this_graph_name

          if einsum_node is not None:
              # Insert a Reshape between the fused node's output and the Einsum, using the
              # shape [0, 0, num_heads, hidden_size // num_heads].
              unique_index = einsum_node.input[0]
              new_edge = "edge_modified_" + unique_index

              shape_tensor = self.add_initializer(
                  name="shape_modified_tensor" + unique_index,
                  data_type=TensorProto.INT64,
                  dims=[4],
                  # Integer floor division avoids the round trip through float division.
                  vals=[0, 0, q_num_heads, int(q_hidden_size // q_num_heads)],
                  raw=False,
              )

              self.model.add_node(
                  helper.make_node(
                      "Reshape",
                      [attention_last_node.output[0], shape_tensor.name],
                      [new_edge],
                      "reshape_modified_" + unique_index,
                  ),
                  self.this_graph_name,
              )
              einsum_node.input[0] = new_edge

          self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
          self.nodes_to_remove.extend(qk_nodes)

          # For MultiHeadAttention operator, MatMul nodes for Q/K/V projection shall not be fused.
          self.nodes_to_remove.extend(q_nodes if not self.use_multi_head_attention else q_nodes[:-1])
          self.nodes_to_remove.extend(k_nodes if not self.use_multi_head_attention else k_nodes[:-1])
          self.nodes_to_remove.extend(v_nodes if not self.use_multi_head_attention else v_nodes[:-1])

          # Use prune graph to remove mask nodes since they are shared by all attention nodes.
          self.prune_graph = True