onnx_model_t5.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import logging
  6. import numpy as np
  7. from fusion_attention import AttentionMask, FusionAttention
  8. from fusion_base import Fusion
  9. from fusion_simplified_layernorm import FusionSimplifiedLayerNormalization, FusionSkipSimplifiedLayerNormalization
  10. from fusion_utils import NumpyHelper
  11. from onnx import NodeProto, TensorProto, helper
  12. from onnx_model import OnnxModel
  13. from onnx_model_bert import BertOnnxModel
# Module-level logger shared by the fusion classes in this file.
logger = logging.getLogger(__name__)
  15. class FusionT5Attention(FusionAttention):
  16. """
  17. Fuse T5 Attention subgraph into one Attention node.
  18. """
  19. def __init__(
  20. self,
  21. model: OnnxModel,
  22. hidden_size: int,
  23. num_heads: int,
  24. attention_mask: AttentionMask,
  25. ):
  26. super().__init__(
  27. model,
  28. hidden_size,
  29. num_heads,
  30. attention_mask,
  31. use_multi_head_attention=False,
  32. search_op_types=["Softmax"],
  33. )
  34. self.static_kv = 1
  35. def make_attention_node(
  36. self,
  37. mask_index: str | None,
  38. q_matmul: NodeProto,
  39. k_matmul: NodeProto,
  40. v_matmul: NodeProto,
  41. num_heads: int,
  42. hidden_size: int,
  43. input: str,
  44. output: str,
  45. attn_bias: str | None,
  46. scale: float,
  47. ) -> NodeProto | None:
  48. """Create an Attention node.
  49. Args:
  50. mask_index (str): mask input
  51. q_matmul (NodeProto): MatMul node in fully connection for Q
  52. k_matmul (NodeProto): MatMul node in fully connection for K
  53. v_matmul (NodeProto): MatMul node in fully connection for V
  54. num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
  55. hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
  56. input (str): input name
  57. output (str): output name
  58. Returns:
  59. Union[NodeProto, None]: the node created or None if failed.
  60. """
  61. assert num_heads > 0
  62. if hidden_size > 0 and (hidden_size % num_heads) != 0:
  63. logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
  64. return None
  65. q_weight = self.model.get_initializer(q_matmul.input[1])
  66. k_weight = self.model.get_initializer(k_matmul.input[1])
  67. v_weight = self.model.get_initializer(v_matmul.input[1])
  68. if q_weight is None or k_weight is None or v_weight is None:
  69. matmul = q_matmul if q_weight is None else k_matmul if k_weight is None else v_matmul
  70. print(
  71. f"{matmul.input[1]} is not an initializer. "
  72. "Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion"
  73. )
  74. return None
  75. qw = NumpyHelper.to_array(q_weight)
  76. kw = NumpyHelper.to_array(k_weight)
  77. vw = NumpyHelper.to_array(v_weight)
  78. # assert q and k have same shape as expected
  79. assert qw.shape == kw.shape
  80. qw_in_size = qw.shape[0]
  81. kw_in_size = kw.shape[0]
  82. vw_in_size = vw.shape[0]
  83. assert qw_in_size == kw_in_size == vw_in_size
  84. if hidden_size > 0 and hidden_size != qw_in_size:
  85. logger.warning(
  86. f"Input hidden size ({hidden_size}) is not same as weight matrix dimension of q,k,v ({qw_in_size}). "
  87. "Please provide a correct input hidden size or pass in 0"
  88. )
  89. qw_out_size = np.prod(qw.shape[1:])
  90. qkv_weight = np.stack((qw, kw, vw), axis=1)
  91. qkv_weight_dim = 3 * qw_out_size
  92. attention_node_name = self.model.create_node_name("Attention")
  93. weight = helper.make_tensor(
  94. name=attention_node_name + "_qkv_weight",
  95. data_type=TensorProto.FLOAT,
  96. dims=[qw_in_size, qkv_weight_dim],
  97. vals=qkv_weight.tobytes(),
  98. raw=True,
  99. )
  100. self.model.add_initializer(weight, self.this_graph_name)
  101. attention_inputs = [
  102. input,
  103. attention_node_name + "_qkv_weight",
  104. "",
  105. ]
  106. if mask_index:
  107. attention_inputs.append(mask_index)
  108. else:
  109. attention_inputs.append("")
  110. if attn_bias:
  111. attention_inputs.append("") # no past
  112. attention_inputs.append(attn_bias)
  113. while attention_inputs and attention_inputs[-1] == "":
  114. attention_inputs.pop()
  115. attention_node = helper.make_node(
  116. "Attention",
  117. inputs=attention_inputs,
  118. outputs=[output],
  119. name=attention_node_name,
  120. )
  121. attention_node.domain = "com.microsoft"
  122. attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
  123. if scale is not None:
  124. attention_node.attribute.extend([helper.make_attribute("scale", scale)])
  125. if self.mask_filter_value is not None:
  126. attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
  127. return attention_node
  128. def create_mha_node(
  129. self,
  130. query: str,
  131. key: str,
  132. value: str,
  133. mask_index: str | None,
  134. attn_bias: str | None,
  135. past_key: str | None,
  136. past_value: str | None,
  137. output: str,
  138. present_key: str | None,
  139. present_value: str | None,
  140. num_heads: int,
  141. hidden_size: int,
  142. ) -> NodeProto | None:
  143. assert num_heads > 0 and hidden_size > 0 and query and key and value
  144. if (hidden_size % num_heads) != 0:
  145. logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
  146. return None
  147. attention_node_name = self.model.create_node_name("MultiHeadAttention")
  148. attention_inputs = [
  149. query,
  150. key,
  151. value,
  152. "", # bias
  153. ]
  154. if mask_index:
  155. attention_inputs.append(mask_index)
  156. else:
  157. attention_inputs.append("")
  158. if attn_bias:
  159. attention_inputs.append(attn_bias)
  160. else:
  161. attention_inputs.append("")
  162. if past_key:
  163. assert past_value
  164. attention_inputs.append(past_key)
  165. attention_inputs.append(past_value)
  166. while attention_inputs and attention_inputs[-1] == "":
  167. attention_inputs.pop()
  168. attention_outputs = [output]
  169. if present_key:
  170. assert present_value
  171. attention_outputs.append(present_key)
  172. attention_outputs.append(present_value)
  173. print(f"{attention_inputs=}, {attention_outputs=}, {attention_node_name=}")
  174. attention_node = helper.make_node(
  175. "MultiHeadAttention",
  176. inputs=attention_inputs,
  177. outputs=attention_outputs,
  178. name=attention_node_name,
  179. )
  180. attention_node.domain = "com.microsoft"
  181. attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
  182. attention_node.attribute.extend([helper.make_attribute("scale", 1.0)])
  183. if self.mask_filter_value is not None:
  184. attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
  185. self.increase_counter("MultiHeadAttention")
  186. return attention_node
  187. def fuse(self, node, input_name_to_nodes, output_name_to_node):
  188. if self.fuse_t5_encoder(node, input_name_to_nodes, output_name_to_node):
  189. return
  190. self.fuse_t5_decoder(node, input_name_to_nodes, output_name_to_node)
  191. def fuse_t5_encoder(self, softmax_node, input_name_to_nodes, output_name_to_node):
  192. assert softmax_node.op_type == "Softmax"
  193. qkv_nodes = self.model.match_child_path(
  194. softmax_node,
  195. ["MatMul", "Transpose", "Reshape"],
  196. edges=[(0, 0), (0, 0), (0, 0)],
  197. input_name_to_nodes=input_name_to_nodes,
  198. )
  199. if qkv_nodes is None:
  200. return False
  201. matmul_qkv, _, reshape_qkv = qkv_nodes
  202. qkv_shape_nodes = self.model.match_parent_path(
  203. reshape_qkv,
  204. ["Concat", "Unsqueeze", "Gather", "Shape"],
  205. [1, 0, 0, 0],
  206. output_name_to_node,
  207. )
  208. if qkv_shape_nodes is None:
  209. return False
  210. input_shape_node = qkv_shape_nodes[-1]
  211. v_nodes = self.model.match_parent_path(
  212. matmul_qkv,
  213. ["Transpose", "Reshape", "MatMul"],
  214. [1, 0, 0],
  215. output_name_to_node,
  216. )
  217. if v_nodes is None:
  218. return False
  219. _, reshape_v, matmul_v = v_nodes
  220. # todo: check reshape_v parent nodes
  221. qk_nodes = self.model.match_parent_path(
  222. matmul_qkv,
  223. ["Softmax", "Add", "MatMul"],
  224. [0, 0, 0],
  225. output_name_to_node,
  226. )
  227. if qk_nodes is None:
  228. return False
  229. _, add_qk, matmul_qk = qk_nodes
  230. mask_nodes = self.model.match_parent_path(
  231. add_qk,
  232. ["Add", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
  233. [1, 1, 0, 1, 0, 0],
  234. output_name_to_node,
  235. )
  236. is_pattern_for_one_graph_input = mask_nodes is None
  237. if mask_nodes is not None:
  238. mul_node = mask_nodes[1]
  239. else:
  240. # Pattern for SD3 and Flux.
  241. mask_nodes = self.model.match_parent_path(
  242. add_qk,
  243. ["Add", "Slice", "Mul", "Sub", "Unsqueeze", "Unsqueeze"],
  244. [1, 1, 0, 0, 1, 0],
  245. output_name_to_node,
  246. )
  247. # If the model is not optimized by ORT, there might be an additional Cast node.
  248. if mask_nodes is None:
  249. mask_nodes = self.model.match_parent_path(
  250. add_qk,
  251. ["Add", "Slice", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
  252. [1, 1, 0, 0, 1, 0, 0],
  253. output_name_to_node,
  254. )
  255. if mask_nodes is None:
  256. return False
  257. mul_node = mask_nodes[2]
  258. _, mul_val = self.model.get_constant_input(mul_node)
  259. if mul_val is None:
  260. return False
  261. if mul_val != -10000:
  262. self.mask_filter_value = float(mul_val)
  263. # If the mask is derived from shape of input_ids, it means there is no padding mask.
  264. mask_nodes_2 = self.model.match_parent_path(
  265. mask_nodes[-1],
  266. ["ConstantOfShape", "Concat", "Unsqueeze", "Gather", "Shape"],
  267. [0, 0, 0, 0, 0],
  268. output_name_to_node,
  269. )
  270. mask_nodes_3 = self.model.match_parent_path(
  271. mask_nodes[-1],
  272. ["ConstantOfShape", "Concat", "Unsqueeze", "Gather", "Shape"],
  273. [0, 0, 1, 0, 0],
  274. output_name_to_node,
  275. )
  276. if (
  277. mask_nodes_2 is not None
  278. and any(input.name == mask_nodes_2[-1].input[0] for input in self.model.graph().input)
  279. and mask_nodes_3 is not None
  280. and mask_nodes_2[-1].input[0] == mask_nodes_3[-1].input[0]
  281. and len(mask_nodes_2[1].input) == 2
  282. ):
  283. mask_index = ""
  284. else:
  285. mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
  286. res_pos_bias = None
  287. rpb_nodes = self.model.match_parent_path(
  288. add_qk,
  289. ["Add", "RelativePositionBias"],
  290. [1, 0],
  291. )
  292. if rpb_nodes is None and is_pattern_for_one_graph_input:
  293. # Pattern for SD3 and Flux.
  294. rpb_nodes = self.model.match_parent_path(
  295. add_qk,
  296. ["Add", "Slice", "RelativePositionBias"],
  297. [1, 0, 0],
  298. )
  299. if rpb_nodes is None:
  300. return False
  301. res_pos_bias = rpb_nodes[-1].output[0]
  302. k_nodes = self.model.match_parent_path(
  303. matmul_qk,
  304. ["Transpose", "Reshape", "MatMul"],
  305. [1, 0, 0],
  306. )
  307. if k_nodes is None:
  308. return False
  309. _, _, matmul_k = k_nodes
  310. # todo: check reshape_k parent nodes
  311. q_nodes = self.model.match_parent_path(
  312. matmul_qk,
  313. ["Transpose", "Reshape", "MatMul"],
  314. [0, 0, 0],
  315. )
  316. if q_nodes is None:
  317. return False
  318. _, reshape_q, matmul_q = q_nodes
  319. # todo: check reshape_q parent nodes
  320. if matmul_q.input[0] != input_shape_node.input[0]:
  321. return False
  322. q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
  323. new_node = self.make_attention_node(
  324. mask_index,
  325. matmul_q,
  326. matmul_k,
  327. matmul_v,
  328. num_heads=q_num_heads,
  329. hidden_size=q_hidden_size,
  330. input=input_shape_node.input[0],
  331. output=reshape_qkv.output[0],
  332. attn_bias=res_pos_bias,
  333. scale=1.0,
  334. )
  335. if new_node is None:
  336. return False
  337. self.nodes_to_add.append(new_node)
  338. self.node_name_to_graph_name[new_node.name] = self.this_graph_name
  339. self.nodes_to_remove.append(reshape_qkv)
  340. self.prune_graph = True
  341. return True
  342. def fuse_t5_decoder(self, softmax_node, input_name_to_nodes, output_name_to_node):
  343. assert softmax_node.op_type == "Softmax"
  344. qkv_nodes = self.model.match_child_path(
  345. softmax_node,
  346. ["MatMul", "Transpose", "Reshape"],
  347. edges=[(0, 0), (0, 0), (0, 0)],
  348. input_name_to_nodes=input_name_to_nodes,
  349. )
  350. if qkv_nodes is None:
  351. return
  352. matmul_qkv, _transpose_qkv, reshape_qkv = qkv_nodes
  353. qkv_shape_nodes = self.model.match_parent_path(
  354. reshape_qkv,
  355. ["Concat", "Unsqueeze", "Gather", "Shape"],
  356. [1, 0, 0, 0],
  357. )
  358. if qkv_shape_nodes is None:
  359. return
  360. input_shape_node = qkv_shape_nodes[-1]
  361. value = None
  362. past_value = None
  363. present_value = None
  364. v_nodes = self.model.match_parent_path(
  365. matmul_qkv,
  366. ["Concat", "Transpose", "Reshape", "MatMul"],
  367. [1, 1, 0, 0],
  368. )
  369. if v_nodes is None:
  370. v_nodes = self.model.match_parent_path(
  371. matmul_qkv,
  372. ["Transpose", "Reshape", "MatMul"],
  373. [1, 0, 0],
  374. )
  375. if v_nodes is not None:
  376. transpose_v, reshape_v, matmul_v = v_nodes
  377. value = reshape_v.input[0]
  378. present_value = transpose_v.output[0]
  379. if "present_value" not in present_value:
  380. return
  381. if matmul_v.input[0] != input_shape_node.input[0]:
  382. self.static_kv = 1
  383. else:
  384. self.static_kv = 0
  385. else:
  386. past_value = matmul_qkv.input[1]
  387. if past_value in output_name_to_node:
  388. return
  389. if "past_value_cross" not in past_value:
  390. return
  391. self.static_kv = 1
  392. else:
  393. concat_v, _, reshape_v, _ = v_nodes
  394. past_value = concat_v.input[0]
  395. if past_value in output_name_to_node:
  396. return
  397. if "past_value_self" not in past_value:
  398. return
  399. present_value = concat_v.output[0]
  400. if "present_value_self" not in present_value:
  401. return
  402. value = reshape_v.input[0]
  403. self.static_kv = 0
  404. qk_nodes = self.model.match_parent_path(
  405. matmul_qkv,
  406. ["Softmax", "Add", "MatMul"],
  407. [0, 0, 0],
  408. )
  409. if qk_nodes is None:
  410. return
  411. _, add_qk, matmul_qk = qk_nodes
  412. mask_index = None
  413. res_pos_bias = None
  414. if self.static_kv == 1:
  415. mask_nodes = self.model.match_parent_path(
  416. add_qk,
  417. ["Add", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
  418. [1, 1, 0, 1, 0, 0],
  419. )
  420. if mask_nodes is not None:
  421. mul_node = mask_nodes[1]
  422. else:
  423. mask_nodes = self.model.match_parent_path(
  424. add_qk,
  425. ["Add", "Slice", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
  426. [1, 1, 0, 0, 1, 0, 0],
  427. )
  428. if mask_nodes is None:
  429. return
  430. mul_node = mask_nodes[2]
  431. _, mul_val = self.model.get_constant_input(mul_node)
  432. if mul_val != -10000:
  433. self.mask_filter_value = mul_val
  434. mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
  435. else:
  436. matched_path_index, _, _ = self.model.match_parent_paths(
  437. add_qk,
  438. [
  439. (["Add", "Slice"], [1, 0]),
  440. (["Add", "RelativePositionBias"], [1, 0]),
  441. ],
  442. output_name_to_node,
  443. )
  444. if matched_path_index < 0:
  445. logger.debug("Skip MultiHeadAttention fusion since attention bias pattern not matched")
  446. return
  447. res_pos_bias = add_qk.input[1]
  448. key = None
  449. past_key = None
  450. present_key = None
  451. if self.static_kv == 1:
  452. k_nodes = self.model.match_parent_path(
  453. matmul_qk,
  454. ["Transpose", "Reshape", "MatMul"],
  455. [1, 0, 0],
  456. )
  457. if k_nodes is not None:
  458. transpose_k, reshape_k, _ = k_nodes
  459. key = reshape_k.input[0]
  460. present_key_transpose_nodes = input_name_to_nodes[reshape_k.output[0]]
  461. for present_key_transpose_node in present_key_transpose_nodes:
  462. present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0])
  463. if present_key_candidate is not None:
  464. present_key = present_key_candidate.name
  465. break
  466. if present_key is None:
  467. return
  468. if "present_key_cross" not in present_key:
  469. return
  470. else:
  471. k_nodes = self.model.match_parent_path(
  472. matmul_qk,
  473. ["Transpose"],
  474. [1],
  475. )
  476. if k_nodes is None:
  477. return
  478. transpose_k = k_nodes[0]
  479. past_key = transpose_k.input[0]
  480. if past_key in output_name_to_node:
  481. return
  482. if "past_key_cross" not in past_key:
  483. return
  484. else:
  485. idx, k_nodes, _ = self.model.match_parent_paths(
  486. matmul_qk,
  487. [
  488. (["Transpose", "Concat", "Reshape", "MatMul"], [1, 0, 1, 0]),
  489. (["Transpose", "Concat", "Transpose", "Reshape", "MatMul"], [1, 0, 1, 0, 0]),
  490. ],
  491. output_name_to_node,
  492. )
  493. past_key_transpose_node = None
  494. present_key_transpose_nodes = None
  495. if k_nodes is not None:
  496. concat_k, reshape_k = k_nodes[1], k_nodes[-2]
  497. key = reshape_k.input[0]
  498. if idx == 0:
  499. past_key_transpose_node = output_name_to_node[concat_k.input[0]]
  500. past_key = past_key_transpose_node.input[0]
  501. else:
  502. past_key = concat_k.input[0]
  503. if past_key in output_name_to_node:
  504. return
  505. if "past_key_self" not in past_key:
  506. return
  507. if idx == 0:
  508. present_key_transpose_nodes = input_name_to_nodes[concat_k.output[0]]
  509. for present_key_transpose_node in present_key_transpose_nodes:
  510. present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0])
  511. if present_key_candidate is not None:
  512. present_key = present_key_candidate.name
  513. break
  514. else:
  515. present_key = concat_k.output[0]
  516. if present_key is None:
  517. return
  518. if "present_key_self" not in present_key:
  519. return
  520. else:
  521. k_nodes = self.model.match_parent_path(
  522. matmul_qk,
  523. ["Transpose", "Reshape", "MatMul"],
  524. [1, 0, 0],
  525. )
  526. if k_nodes is None:
  527. return
  528. _, reshape_k, _ = k_nodes
  529. key = reshape_k.input[0]
  530. present_key_transpose_nodes = input_name_to_nodes[reshape_k.output[0]]
  531. for present_key_transpose_node in present_key_transpose_nodes:
  532. present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0])
  533. if present_key_candidate is not None:
  534. present_key = present_key_candidate.name
  535. break
  536. if present_key is None:
  537. return
  538. if "present_key_self" not in present_key:
  539. return
  540. q_nodes = self.model.match_parent_path(
  541. matmul_qk,
  542. ["Transpose", "Reshape", "MatMul"],
  543. [0, 0, 0],
  544. )
  545. if q_nodes is None:
  546. return
  547. transpose_q, reshape_q, matmul_q = q_nodes
  548. if matmul_q.input[0] != input_shape_node.input[0]:
  549. return
  550. q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
  551. if self.static_kv == 1 and past_key is not None:
  552. key = past_key
  553. value = past_value
  554. past_key = None
  555. past_value = None
  556. if not (key and value and q_num_heads > 0 and q_hidden_size > 0):
  557. return
  558. new_node = self.create_mha_node(
  559. query=matmul_q.output[0],
  560. key=key,
  561. value=value,
  562. mask_index=mask_index,
  563. attn_bias=res_pos_bias,
  564. past_key=past_key,
  565. past_value=past_value,
  566. output=reshape_qkv.output[0],
  567. present_key=present_key,
  568. present_value=present_value,
  569. num_heads=q_num_heads,
  570. hidden_size=q_hidden_size,
  571. )
  572. if new_node:
  573. self.nodes_to_add.append(new_node)
  574. self.node_name_to_graph_name[new_node.name] = self.this_graph_name
  575. # Since present_* is graph output, we need update the graph to avoid circular.
  576. if present_key or present_value:
  577. for graph_output in [present_key, present_value]:
  578. if not (graph_output and self.model.find_graph_output(graph_output)):
  579. print(f"{graph_output=} does not exist in graph output")
  580. return
  581. assert graph_output in output_name_to_node
  582. output_name_to_node[graph_output].output[0] = graph_output + "_copy"
  583. self.model.replace_input_of_all_nodes(graph_output, graph_output + "_copy")
  584. self.nodes_to_remove.append(reshape_qkv)
  585. self.prune_graph = False
  586. class FusionRelativePositionBiasBlock(Fusion):
  587. def __init__(self, model: OnnxModel):
  588. super().__init__(model, "RelativePositionBias", ["Softmax"])
  589. def fuse(self, node, input_name_to_nodes, output_name_to_node):
  590. compute_bias_nodes = self.model.match_parent_path(
  591. node,
  592. ["Add", "Add", "Slice", "Unsqueeze", "Transpose", "Gather", "Where"],
  593. [0, 1, 0, 0, 0, 0, 1],
  594. output_name_to_node,
  595. )
  596. if compute_bias_nodes is None:
  597. compute_bias_nodes = self.model.match_parent_path(
  598. node,
  599. ["Add", "Add", "Slice", "Unsqueeze", "Transpose", "Gather", "Add", "Where"],
  600. [0, 1, 0, 0, 0, 0, 1, 1],
  601. output_name_to_node,
  602. )
  603. if compute_bias_nodes is None:
  604. return
  605. gather = compute_bias_nodes[5]
  606. where = compute_bias_nodes[-1]
  607. slice = compute_bias_nodes[2]
  608. unsqueeze = compute_bias_nodes[3]
  609. # Current fusion will not remove the node until the graph is processed.
  610. # This avoids to fuse it again when it is shared by multiple layers.
  611. if unsqueeze in self.nodes_to_remove:
  612. return
  613. compute_buckets_nodes = self.model.match_parent_path(
  614. where,
  615. ["Min", "ConstantOfShape", "Shape", "Add", "Cast", "Mul", "Div", "Log", "Div"],
  616. [2, 1, 0, 0, 0, 0, 0, 0, 0],
  617. output_name_to_node,
  618. )
  619. if compute_buckets_nodes is None:
  620. return
  621. # This value is to used to compute max_distance later.
  622. log_max = self.model.get_constant_value(compute_buckets_nodes[-3].input[1])
  623. div = compute_buckets_nodes[-1]
  624. range_nodes = self.model.match_parent_path(
  625. div,
  626. ["Cast", "Neg", "Min", "ConstantOfShape", "Shape", "Sub", "Unsqueeze", "Range"],
  627. [0, 0, 0, 1, 0, 0, 0, 0],
  628. output_name_to_node,
  629. )
  630. is_bidirectional = False
  631. if range_nodes is None:
  632. range_nodes = self.model.match_parent_path(
  633. div, ["Cast", "Abs", "Sub", "Unsqueeze", "Range"], [0, 0, 0, 0, 0], output_name_to_node
  634. )
  635. is_bidirectional = True
  636. if range_nodes is None:
  637. return
  638. range_node = range_nodes[-1]
  639. # Double check that the constant relative to max_distance and relative_attention_num_buckets.
  640. # Most t5 models use max_distance=128, so we hardcode it unitl we see a model with different value.
  641. # The log_max is the value of the following formula:
  642. # math.log(max_distance / (relative_attention_num_buckets // (4 if is_bidirectional else 2)))
  643. # See https://github.com/huggingface/transformers/blob/608e163b527eaee41e650ffb9eb4c422d2679902/src/transformers/models/t5/modeling_t5.py#L397.
  644. # Here is the value based on max_distance=128 and relative_attention_num_buckets=32:
  645. max_distance = int(np.round(np.exp(log_max) * (32 // (4 if is_bidirectional else 2))))
  646. if max_distance != 128:
  647. logger.warning(
  648. f"max_distance is {max_distance}, which is different from the default value 128. "
  649. "Please double check the model configuration."
  650. )
  651. node_name = self.model.create_node_name(
  652. "RelativePositionBias", name_prefix="RelPosBias_" + ("encoder" if is_bidirectional else "decoder")
  653. )
  654. table_weight_i = self.model.get_initializer(gather.input[0])
  655. if table_weight_i is None:
  656. return
  657. table_weight = NumpyHelper.to_array(table_weight_i)
  658. table_weight_t = np.transpose(table_weight)
  659. bias_table = helper.make_tensor(
  660. name=node_name + "_bias_table_weight",
  661. data_type=TensorProto.FLOAT,
  662. dims=[np.shape(table_weight)[0], np.shape(table_weight)[1]],
  663. vals=table_weight_t.tobytes(),
  664. raw=True,
  665. )
  666. self.model.add_initializer(bias_table, self.this_graph_name)
  667. # Relative position is like the following in encoder:
  668. # seq_len
  669. # |
  670. # Range(0, *)
  671. # / \
  672. # Unsqueeze(axes=0) Unsqueeze(axes=1)
  673. # \ /
  674. # Sub
  675. # |
  676. # Abs
  677. #
  678. # Relative position is like the following in decoder:
  679. # past_seq_len seq_len
  680. # \ /
  681. # Add
  682. # / \
  683. # Range(0, *) Range(0, *)
  684. # \ /
  685. # Sub
  686. # Note that the graph will slice the attention bias to get last seq_len rows.
  687. #
  688. # In new version of transformers, the pattern of decoder is changed like the following
  689. #
  690. # total_seq_len Range(start=past_seq_len, end=total_seq_len)
  691. # | |
  692. # Range(0, *) Unsqueeze(axes=1)
  693. # | |
  694. # Unsqueeze(axes=0) Cast(to=int64)
  695. # \ /
  696. # Sub
  697. # Currently, there is still Slice to get last seq_len rows so end result is same.
  698. # But need to be careful that the shape of bias tensor is changed before Slice.
  699. #
  700. # RelativePositionBias operator requires query_length == key_length so we shall pass in total_seq_len.
  701. # Here we get the end value of the Range node as length to pass to the RelativePositionBias node.
  702. # TODO: Optimization opportunity: change RelativePositionBias op to support query_length != key_length.
  703. # only compute seq_len rows, then we can remove the Slice after the RelativePositionBias node.
  704. inputs = [bias_table.name, range_node.input[1], range_node.input[1]]
  705. # Use a new tensor name since the shape might be different as mentioned above.
  706. bias_output = node_name + "_rel_pos_bias"
  707. slice.input[0] = bias_output
  708. rpb_node = helper.make_node(
  709. "RelativePositionBias",
  710. inputs=inputs,
  711. outputs=[bias_output],
  712. name=node_name,
  713. )
  714. rpb_node.domain = "com.microsoft"
  715. rpb_node.attribute.extend([helper.make_attribute("max_distance", max_distance)])
  716. rpb_node.attribute.extend([helper.make_attribute("is_bidirectional", is_bidirectional)])
  717. self.node_name_to_graph_name[rpb_node.name] = self.this_graph_name
  718. self.nodes_to_add.append(rpb_node)
  719. self.prune_graph = True
  720. class T5OnnxModel(BertOnnxModel):
  721. def __init__(self, model, num_heads: int = 0, hidden_size: int = 0):
  722. super().__init__(model, num_heads, hidden_size)
  723. self.attention_mask = AttentionMask(self)
  724. # When the model has only one input (input_ids), there is no padding mask.
  725. if len(self.model.graph.input) == 1:
  726. from fusion_options import AttentionMaskFormat # noqa: PLC0415
  727. self.attention_mask.mask_format = AttentionMaskFormat.NoMask
  728. self.attention_fusion = FusionT5Attention(self, self.hidden_size, self.num_heads, self.attention_mask)
  729. self.layer_norm_fusion = FusionSimplifiedLayerNormalization(self)
  730. self.skip_layer_norm_fusion = FusionSkipSimplifiedLayerNormalization(self)
  731. self.rpb_fusion = FusionRelativePositionBiasBlock(self)
  732. def fuse_attention(self):
  733. self.attention_fusion.apply()
  734. def fuse_layer_norm(self):
  735. self.layer_norm_fusion.apply()
  736. def fuse_skip_layer_norm(self, shape_infer=True):
  737. self.skip_layer_norm_fusion.apply()
  738. def adjust_rel_pos_bis_length_input(self):
  739. # For T5 encoder, it uses complex logic to compute the query and key length when there is only one graph input (input_ids)
  740. # We can directly get the length from shape (the 2nd dimension) of input_ids.
  741. for node in self.nodes():
  742. if node.op_type == "RelativePositionBias":
  743. nodes = self.match_parent_path(
  744. node,
  745. [
  746. "Gather",
  747. "Shape",
  748. "Transpose",
  749. "Reshape",
  750. "Concat",
  751. "Unsqueeze",
  752. "Gather",
  753. "Shape",
  754. "SimplifiedLayerNormalization",
  755. "Gather",
  756. ],
  757. [1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
  758. )
  759. # TODO: more validation on node attributes
  760. if nodes is not None:
  761. graph_input_names = [input.name for input in self.model.graph.input]
  762. if nodes[-1].input[1] in graph_input_names:
  763. node_name = self.create_node_name("Shape", name_prefix="Added_Shape_")
  764. shape_node = helper.make_node(
  765. "Shape",
  766. inputs=[nodes[-1].input[1]],
  767. outputs=[node_name + "_Output"],
  768. name=node_name,
  769. )
  770. indices_1 = helper.make_tensor(
  771. name="Constant_Index_1",
  772. data_type=TensorProto.INT64,
  773. dims=[1], # Shape of the tensor
  774. vals=[1], # Tensor values
  775. )
  776. self.add_initializer(indices_1)
  777. gather = helper.make_node(
  778. "Gather",
  779. inputs=[node_name + "_Output", "Constant_Index_1"],
  780. outputs=[node_name + "_Output_Gather_1"],
  781. name=self.create_node_name("Gather", name_prefix="Added_Gather_"),
  782. axis=0,
  783. )
  784. self.add_node(shape_node)
  785. self.add_node(gather)
  786. node.input[1] = node_name + "_Output_Gather_1"
  787. node.input[2] = node_name + "_Output_Gather_1"
  788. break
  789. # Remove get_extended_attention_mask() since it generates all zeros.
  790. def remove_extended_mask_decoder_init(self):
  791. nodes_to_remove = []
  792. for node in self.nodes():
  793. if node.op_type == "Add":
  794. extended_mask_nodes = self.match_parent_path(
  795. node,
  796. [
  797. "Mul",
  798. "Sub",
  799. "Mul",
  800. "Unsqueeze",
  801. "Cast",
  802. "LessOrEqual",
  803. "Tile",
  804. "Concat",
  805. "Unsqueeze",
  806. "Gather",
  807. "Shape",
  808. ],
  809. [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
  810. )
  811. if extended_mask_nodes is None:
  812. continue
  813. rpb_nodes = self.match_parent_path(node, ["RelativePositionBias"], [0])
  814. if rpb_nodes is None:
  815. continue
  816. rpb_node = rpb_nodes[0]
  817. rpb_node.output[0] = node.output[0]
  818. nodes_to_remove.extend(extended_mask_nodes)
  819. nodes_to_remove.append(node)
  820. self.remove_nodes(nodes_to_remove)
  821. def remove_extended_mask_decoder(self):
  822. nodes_to_remove = []
  823. for node in self.nodes():
  824. if node.op_type == "Add":
  825. extended_mask_nodes = self.match_parent_path(
  826. node,
  827. [
  828. "Mul",
  829. "Sub",
  830. "Mul",
  831. "Unsqueeze",
  832. "Concat",
  833. "Cast",
  834. "LessOrEqual",
  835. "Tile",
  836. "Concat",
  837. "Unsqueeze",
  838. "Gather",
  839. "Shape",
  840. ],
  841. [1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0],
  842. )
  843. if extended_mask_nodes is None:
  844. continue
  845. rpb_nodes = self.match_parent_path(node, ["Slice", "RelativePositionBias"], [0, 0])
  846. if rpb_nodes is None:
  847. continue
  848. rpb_node = rpb_nodes[0]
  849. rpb_node.output[0] = node.output[0]
  850. nodes_to_remove.extend(extended_mask_nodes)
  851. nodes_to_remove.append(node)
  852. self.remove_nodes(nodes_to_remove)
  853. def preprocess(self):
  854. self.adjust_reshape_and_expand()
  855. self.rpb_fusion.apply()
  856. def postprocess(self):
  857. # remove get_extended_attention_mask() since it generates all zeros.
  858. self.remove_extended_mask_decoder_init()
  859. self.remove_extended_mask_decoder()
  860. self.adjust_rel_pos_bis_length_input()
  861. self.prune_graph()