fusion_bart_attention.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import logging
  6. import numpy as np
  7. from fusion_attention import AttentionMask, FusionAttention
  8. from onnx import helper
  9. from onnx_model import OnnxModel
  10. logger = logging.getLogger(__name__)
  11. class FusionBartAttention(FusionAttention):
  12. """
  13. Fuse Bart Attention subgraph into one Attention node.
  14. """
  15. def __init__(
  16. self,
  17. model: OnnxModel,
  18. hidden_size: int,
  19. num_heads: int,
  20. attention_mask: AttentionMask,
  21. ):
  22. super().__init__(model, hidden_size, num_heads, attention_mask)
  23. def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
  24. # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
  25. qkv_nodes = self.model.match_parent_path(
  26. normalize_node,
  27. ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
  28. [1, 1, 0, 0, 0],
  29. )
  30. if qkv_nodes is not None:
  31. (
  32. add_out,
  33. matmul_out,
  34. reshape_qkv,
  35. transpose_qkv,
  36. matmul_qkv,
  37. ) = qkv_nodes
  38. else:
  39. logger.debug("fuse_attention: failed to match qkv path")
  40. return
  41. other_inputs = []
  42. for input_ in normalize_node.input:
  43. if input_ not in output_name_to_node:
  44. continue
  45. if input_ == qkv_nodes[0].output[0]:
  46. continue
  47. other_inputs.append(input_)
  48. if len(other_inputs) != 1:
  49. return
  50. root_input = other_inputs[0]
  51. # Sometimes the input name to the attention MatMul nodes does not match the input name to the end
  52. # SkipLayerNormalization node (name saved in root_input). We find the true input name to the MatMul
  53. # nodes by getting the initial SkipLayerNormalization node and checking how many MatMul nodes are
  54. # children nodes for each of its output names.
  55. """
  56. root_input
  57. +---------------------------------------------------+
  58. | |
  59. | |
  60. SkipLayerNormalization --> Attention --> MatMul --> SkipLayerNormalization
  61. """
  62. skip_layernorm = output_name_to_node[root_input]
  63. # For some attention blocks, the end SkipLayerNormalization node may point to another node whose
  64. # child is the LayerNormalization node.
  65. if skip_layernorm.op_type in {"Add", "Clip"}:
  66. skip_layernorm = self.model.get_children(skip_layernorm)[0]
  67. for output in skip_layernorm.output:
  68. if not output:
  69. continue
  70. children = input_name_to_nodes[output]
  71. children_types = [child.op_type for child in children]
  72. if children_types.count("MatMul") >= 1:
  73. root_input = output
  74. break
  75. graph_input_names = {node.name for node in self.model.graph().input}
  76. graph_output_names = {node.name for node in self.model.graph().output}
  77. v_nodes_past_or_present = self.model.match_parent_path(
  78. matmul_qkv,
  79. ["Transpose", "Reshape", "Add", "MatMul"],
  80. [1, 0, 0, None],
  81. )
  82. v_nodes_with_past = self.model.match_parent_path(
  83. matmul_qkv,
  84. ["Concat", "Transpose", "Reshape", "Add", "MatMul"],
  85. [1, 1, 0, 0, None],
  86. )
  87. v_nodes_past_only_oai = self.model.match_parent_path(
  88. matmul_qkv,
  89. ["Transpose", "Reshape", "Reshape", "Transpose"],
  90. [1, 0, 0, 0],
  91. )
  92. past_v, present_v = "", ""
  93. v_nodes, add_v, matmul_v = [], None, None
  94. if v_nodes_past_or_present is not None:
  95. v_nodes = v_nodes_past_or_present
  96. (transpose_v, reshape_v, add_v, matmul_v) = v_nodes
  97. # Find past_v input name
  98. start_child_nodes = input_name_to_nodes[add_v.output[0]]
  99. for start_child_node in start_child_nodes:
  100. if start_child_node.op_type == "Concat":
  101. concat_v_nodes = self.model.match_parent_path(
  102. start_child_node,
  103. ["Reshape", "Transpose"],
  104. [0, 0],
  105. )
  106. if concat_v_nodes is not None:
  107. past_v = concat_v_nodes[-1].input[0]
  108. start_child_nodes = input_name_to_nodes[start_child_node.output[0]]
  109. break
  110. # Find present_v output name
  111. for start_child_node in start_child_nodes:
  112. start_grandchild_nodes = input_name_to_nodes[start_child_node.output[0]]
  113. for start_grandchild_node in start_grandchild_nodes:
  114. if start_grandchild_node.output[0] in graph_output_names:
  115. present_v = start_grandchild_node.output[0]
  116. break
  117. if present_v != "":
  118. break
  119. elif v_nodes_with_past is not None:
  120. v_nodes = v_nodes_with_past
  121. (concat_v, transpose_v, reshape_v, add_v, matmul_v) = v_nodes
  122. past_v = concat_v.input[0]
  123. present_v = concat_v.output[0]
  124. elif matmul_qkv.input[1] in graph_input_names:
  125. # Hugging Face's cross-attention where past_v is used directly as value
  126. past_v = matmul_qkv.input[1]
  127. elif v_nodes_past_only_oai is not None:
  128. # OpenAI's cross-attention where past_v is used directly as value
  129. v_nodes = v_nodes_past_only_oai
  130. past_v = v_nodes[-1].input[0]
  131. else:
  132. logger.debug("fuse_attention: failed to match v path")
  133. return
  134. past_v = past_v if past_v in graph_input_names else ""
  135. present_v = present_v if present_v in graph_output_names else ""
  136. qk_nodes_no_mask = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
  137. qk_nodes_with_mask = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
  138. qk_nodes, add_qk = [], None
  139. if qk_nodes_no_mask is not None:
  140. _, matmul_qk = qk_nodes_no_mask
  141. qk_nodes = qk_nodes_no_mask
  142. elif qk_nodes_with_mask is not None:
  143. _, add_qk, matmul_qk = qk_nodes_with_mask
  144. qk_nodes = qk_nodes_with_mask
  145. else:
  146. logger.debug("fuse_attention: failed to match qk path")
  147. return
  148. q_nodes_hf = self.model.match_parent_path(
  149. matmul_qk,
  150. ["Transpose", "Reshape", "Mul", "Add", "MatMul"],
  151. [0, 0, 0, 0, 1],
  152. )
  153. q_nodes_oai = self.model.match_parent_path(
  154. matmul_qk,
  155. ["Mul", "Transpose", "Reshape", "Add", "MatMul"],
  156. [0, 0, 0, 0, 1],
  157. )
  158. q_nodes = []
  159. if q_nodes_hf is not None:
  160. q_nodes = q_nodes_hf
  161. (transpose_q, reshape_q, mul_q, add_q, matmul_q) = q_nodes
  162. elif q_nodes_oai is not None:
  163. q_nodes = q_nodes_oai
  164. (mul_q, transpose_q, reshape_q, add_q, matmul_q) = q_nodes
  165. else:
  166. logger.debug("fuse_attention: failed to match q path")
  167. return
  168. k_nodes_no_past_hf = self.model.match_parent_path(
  169. matmul_qk,
  170. ["Transpose", "Reshape", "MatMul"],
  171. [1, 0, 0],
  172. )
  173. k_nodes_with_past_hf = self.model.match_parent_path(
  174. matmul_qk,
  175. ["Transpose", "Concat", "Transpose", "Reshape", "MatMul"],
  176. [1, 0, 1, 0, 0],
  177. )
  178. k_nodes_past_or_present_oai = self.model.match_parent_path(
  179. matmul_qk,
  180. ["Mul", "Transpose", "Reshape", "MatMul"],
  181. [1, 0, 0, 0],
  182. )
  183. k_nodes_past_only_oai = self.model.match_parent_path(
  184. matmul_qk,
  185. ["Mul", "Transpose", "Reshape", "Reshape", "Transpose"],
  186. [1, 0, 0, 0, 0],
  187. )
  188. past_k, present_k = "", ""
  189. k_nodes, add_k, matmul_k = [], None, None
  190. if k_nodes_no_past_hf is not None:
  191. k_nodes = k_nodes_no_past_hf
  192. (transpose_k, reshape_k, matmul_k) = k_nodes
  193. # Find present_k output name
  194. transpose_k_nodes = input_name_to_nodes[reshape_k.output[0]]
  195. for transpose_k_node in transpose_k_nodes:
  196. if transpose_k_node.output[0] in graph_output_names:
  197. present_k = transpose_k_node.output[0]
  198. break
  199. elif k_nodes_with_past_hf is not None:
  200. k_nodes = k_nodes_with_past_hf
  201. (_, concat_k, transpose_k, reshape_k, matmul_k) = k_nodes
  202. past_k = concat_k.input[0]
  203. present_k = concat_k.output[0]
  204. elif output_name_to_node[matmul_qk.input[1]].input[0] in graph_input_names:
  205. # Hugging Face's cross-attention where past_k is used directly as key
  206. k_nodes = [output_name_to_node[matmul_qk.input[1]]]
  207. past_k = k_nodes[0].input[0]
  208. elif k_nodes_past_or_present_oai is not None:
  209. k_nodes = k_nodes_past_or_present_oai
  210. (_, transpose_k, reshape_k, matmul_k) = k_nodes
  211. # Find past_k input name
  212. start_child_nodes = input_name_to_nodes[matmul_k.output[0]]
  213. for start_child_node in start_child_nodes:
  214. if start_child_node.op_type == "Concat":
  215. concat_k_nodes = self.model.match_parent_path(
  216. start_child_node,
  217. ["Reshape", "Transpose"],
  218. [0, 0],
  219. )
  220. if concat_k_nodes is not None:
  221. past_k = concat_k_nodes[-1].input[0]
  222. start_child_nodes = input_name_to_nodes[start_child_node.output[0]]
  223. break
  224. # Find present_k output name
  225. for start_child_node in start_child_nodes:
  226. start_grandchild_nodes = input_name_to_nodes[start_child_node.output[0]]
  227. for start_grandchild_node in start_grandchild_nodes:
  228. if start_grandchild_node.output[0] in graph_output_names:
  229. present_k = start_grandchild_node.output[0]
  230. break
  231. if present_k != "":
  232. break
  233. elif k_nodes_past_only_oai is not None:
  234. # OpenAI's cross-attention where past_k is used directly as key
  235. k_nodes = k_nodes_past_only_oai
  236. past_k = k_nodes[-1].input[0]
  237. else:
  238. logger.debug("fuse_attention: failed to match k path")
  239. return
  240. past_k = past_k if past_k in graph_input_names else ""
  241. present_k = present_k if present_k in graph_output_names else ""
  242. if matmul_k is not None and add_k is None:
  243. # Create empty Add node for attention graph
  244. add_v_tensor = self.model.get_initializer(add_v.input[0])
  245. bias_dim = add_v_tensor.dims[0]
  246. dtype = add_v_tensor.data_type
  247. empty_bias_name = "empty_bias"
  248. empty_tensor = self.model.get_initializer(empty_bias_name)
  249. if empty_tensor is None:
  250. self.add_initializer(
  251. empty_bias_name,
  252. dtype,
  253. dims=[bias_dim],
  254. vals=np.array([0.0] * bias_dim, dtype=helper.tensor_dtype_to_np_dtype(dtype)),
  255. )
  256. add_name = self.model.create_node_name("Add")
  257. add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k.name], add_name)
  258. three_root_inputs = bool(past_k) and bool(past_v) and matmul_k is None and matmul_v is None
  259. one_root_input = (
  260. not three_root_inputs
  261. and matmul_q.input[0] == root_input
  262. and matmul_k.input[0] == root_input
  263. and matmul_v.input[0] == root_input
  264. )
  265. two_root_inputs = (
  266. not three_root_inputs
  267. and matmul_q.input[0] == root_input
  268. and matmul_k.input[0] == matmul_v.input[0]
  269. and matmul_k.input[0] != matmul_q.input[0]
  270. )
  271. # There are 5 types of attention:
  272. # 1) Encoder attention with one_root_input=True and qk_nodes=qk_nodes_no_mask
  273. # 2) Decoder self attention with one_root_input=True and qk_nodes=qk_nodes_with_mask
  274. # 3) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_no_mask
  275. # 4) Decoder self attention with past with one_root_input=True and qk_nodes=qk_nodes_with_mask and past_k=past_decoder_key and past_v=past_decoder_value
  276. # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_no_mask
  277. encoder_attention = one_root_input and qk_nodes == qk_nodes_no_mask
  278. decoder_self_attention = one_root_input and qk_nodes == qk_nodes_with_mask
  279. decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_no_mask
  280. decoder_self_attention_with_past = decoder_self_attention and bool(past_k) and bool(past_v)
  281. decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_no_mask
  282. # For decoder self-attentions, the attention mask needs to be included in the attention node
  283. causal_mask = qk_nodes == qk_nodes_with_mask
  284. mask_nodes = []
  285. if causal_mask:
  286. mask_nodes_bart = self.model.match_parent_path(
  287. add_qk,
  288. ["Where"],
  289. [1],
  290. )
  291. mask_nodes_whisper_hf = self.model.match_parent_path(
  292. add_qk,
  293. ["Slice", "Expand", "Where"],
  294. [1, 0, 1],
  295. )
  296. mask_nodes_whisper_oai = self.model.match_parent_path(
  297. add_qk,
  298. ["Slice", "Unsqueeze", "Gather", "Shape", "Add"],
  299. [1, 2, 0, 0, 0],
  300. )
  301. mask_nodes_whisper_oai_unit_test = self.model.match_parent_path(
  302. add_qk,
  303. ["Slice", "Slice"],
  304. [1, 0],
  305. )
  306. if mask_nodes_whisper_hf is not None:
  307. mask_nodes = mask_nodes_whisper_hf
  308. elif mask_nodes_whisper_oai is not None:
  309. mask_nodes = mask_nodes_whisper_oai
  310. elif mask_nodes_whisper_oai_unit_test is not None:
  311. mask_nodes = mask_nodes_whisper_oai_unit_test
  312. elif mask_nodes_bart is not None:
  313. mask_nodes = mask_nodes_bart
  314. else:
  315. logger.debug("fuse_attention: failed to match mask nodes")
  316. return
  317. assert len(mask_nodes) > 0
  318. if (
  319. encoder_attention
  320. or decoder_self_attention
  321. or decoder_cross_attention
  322. or decoder_self_attention_with_past
  323. or decoder_cross_attention_with_past
  324. ):
  325. attention_last_node = reshape_qkv
  326. num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
  327. if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
  328. logger.debug("fuse_attention: failed to detect num_heads or hidden_size")
  329. return
  330. new_node = None
  331. if decoder_self_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past:
  332. # Note: Decoder attention with past key and past value is fused as multi-head attention
  333. # rather than attention because multi-head attention supports separate past key and past
  334. # value whereas attention supports concatenated past key and past value.
  335. new_node = (
  336. self.create_multihead_attention_node(
  337. q_matmul=matmul_q,
  338. k_matmul=matmul_k if decoder_cross_attention or decoder_self_attention_with_past else past_k,
  339. v_matmul=matmul_v if decoder_cross_attention or decoder_self_attention_with_past else past_v,
  340. q_add=add_q,
  341. k_add=add_k if decoder_cross_attention or decoder_self_attention_with_past else None,
  342. v_add=add_v if decoder_cross_attention or decoder_self_attention_with_past else None,
  343. num_heads=num_heads,
  344. hidden_size=hidden_size,
  345. output=attention_last_node.output[0],
  346. unidirectional=causal_mask,
  347. past_k=past_k if decoder_self_attention_with_past else "",
  348. past_v=past_v if decoder_self_attention_with_past else "",
  349. present_k=present_k,
  350. present_v=present_v,
  351. )
  352. if self.use_multi_head_attention
  353. else None
  354. )
  355. else:
  356. # Temporarily set multi-head attention flag to false
  357. use_multi_head_attention_ground_truth = self.use_multi_head_attention
  358. self.use_multi_head_attention = False
  359. new_node = self.create_attention_node(
  360. mask_index=None,
  361. q_matmul=matmul_q,
  362. k_matmul=matmul_k,
  363. v_matmul=matmul_v,
  364. q_add=add_q,
  365. k_add=add_k,
  366. v_add=add_v,
  367. num_heads=num_heads,
  368. hidden_size=hidden_size,
  369. first_input=root_input,
  370. output=attention_last_node.output[0],
  371. causal=causal_mask,
  372. past_k=past_k,
  373. past_v=past_v,
  374. present_k=present_k,
  375. present_v=present_v,
  376. )
  377. self.use_multi_head_attention = use_multi_head_attention_ground_truth
  378. if new_node is None:
  379. logger.debug("fuse_attention: failed to create fused node")
  380. return
  381. self.nodes_to_add.append(new_node)
  382. self.node_name_to_graph_name[new_node.name] = self.this_graph_name
  383. self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
  384. self.nodes_to_remove.extend(qk_nodes)
  385. # When using multi-head attention, keep MatMul nodes in original graph
  386. if decoder_self_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past:
  387. if len(q_nodes) > 0 and q_nodes[-1].op_type == "MatMul":
  388. q_nodes.pop()
  389. if len(k_nodes) > 0 and k_nodes[-1].op_type == "MatMul":
  390. k_nodes.pop()
  391. if len(v_nodes) > 0 and v_nodes[-1].op_type == "MatMul":
  392. v_nodes.pop()
  393. if self.disable_multi_head_attention_bias:
  394. if len(q_nodes) > 0 and q_nodes[-1].op_type == "Add":
  395. q_nodes.pop()
  396. if len(k_nodes) > 0 and k_nodes[-1].op_type == "Add":
  397. k_nodes.pop()
  398. if len(v_nodes) > 0 and v_nodes[-1].op_type == "Add":
  399. v_nodes.pop()
  400. self.nodes_to_remove.extend(q_nodes)
  401. self.nodes_to_remove.extend(k_nodes)
  402. self.nodes_to_remove.extend(v_nodes)
  403. # Use prune graph to remove mask nodes since they are shared by all attention nodes.
  404. self.prune_graph = True