fusion_embedlayer.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. from fusion_base import Fusion
  7. from fusion_utils import FusionUtils
  8. from onnx import NodeProto, TensorProto, helper
  9. from onnx_model import OnnxModel
  10. logger = getLogger(__name__)
  11. class FusionEmbedLayerNoMask(Fusion):
  12. """
  13. Fuse embedding layer into one node (EmbedLayerNormalization).
  14. It supports the following model types: BERT, DistilBert, ALBert.
  15. """
  def __init__(self, model: OnnxModel, description: str = "no mask"):
      """Set up the fusion that produces EmbedLayerNormalization nodes.

      Args:
          model (OnnxModel): model passed to the Fusion base class and to FusionUtils.
          description (str): label forwarded to the Fusion base class.
      """
      super().__init__(
          model,
          "EmbedLayerNormalization",
          ["LayerNormalization", "SkipLayerNormalization"],
          description,
      )
      self.utils = FusionUtils(model)

      # Runtime shape inference is run lazily (at most once) by check_embedding.
      self.shape_infer = None
      self.shape_infer_done = False

      # The following will be reset in each fuse call of FusionEmbedLayerNormalization
      self.attention = None
      # check_attention_subgraph assigns cross_attention when it finds a MultiHeadAttention
      # pattern; declare it here so the attribute always exists on an instance.
      self.cross_attention = None
      self.embed_node = None
  def match_two_gather(self, add: NodeProto) -> None | tuple[NodeProto, NodeProto]:
      """Look up the Gather parents on input 0 and input 1 of a node.

      Args:
          add (NodeProto): node whose first two inputs are inspected.

      Returns:
          None when match_parent_path finds no Gather on input 0 or on input 1;
          otherwise a tuple (Gather on input 0, Gather on input 1).
      """
      gathers = []
      for input_index in (0, 1):
          gather_path = self.model.match_parent_path(add, ["Gather"], [input_index])
          if gather_path is None:
              return None
          gathers.append(gather_path[0])
      return gathers[0], gathers[1]
  def check_attention_subgraph(
      self,
      layernorm: NodeProto,
      input_name_to_nodes: dict[str, list[NodeProto]],
      is_distil_bert: bool,
  ) -> bool:
      """Check that LayerNormalization has a child of Attention node or subgraph like Attention.

      Side effects: self.attention is overwritten with the result of the direct Attention
      lookup (possibly None) and may later be set to an Attention node found behind a
      MatMul+Add pair; self.cross_attention is set when a MultiHeadAttention path is matched.

      Args:
          layernorm (NodeProto): LayerNormalization node
          input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
          is_distil_bert (bool): whether it is DistilBert or not

      Returns:
          bool: whether there is Attention node or subgraph like Attention
      """
      direct_attention = self.model.find_first_child_by_type(
          layernorm, "Attention", input_name_to_nodes, recursive=False
      )
      self.attention = direct_attention
      if direct_attention is not None:
          return True

      layernorm_output = layernorm.output[0]
      if layernorm_output not in input_name_to_nodes:
          return False
      children = input_name_to_nodes[layernorm_output]
      children_types = sorted(child.op_type for child in children)

      # Try find MultiHeadAttention
      if children_types == ["MatMul", "MatMul", "MatMul", "SkipLayerNormalization"]:
          for child in children:
              if child.op_type != "SkipLayerNormalization":
                  continue
              mha_path = self.model.match_parent_path(
                  child,
                  ["Add", "MatMul", "MultiHeadAttention", "MatMul"],
                  [None, None, 0, 0],
              )
              # The last MatMul of the path must consume the layernorm output directly.
              if mha_path is not None and mha_path[-1].input[0] == layernorm_output:
                  self.cross_attention = mha_path[2]
                  return True

      # In case user disables attention fusion, check whether subgraph looks like Attention.
      # For Albert, there is MatMul+Add after embedding layer before attention.
      if len(children) == 1 and children[0].op_type == "MatMul" and children[0].output[0] in input_name_to_nodes:
          grandchildren = input_name_to_nodes[children[0].output[0]]
          if (
              len(grandchildren) == 1
              and grandchildren[0].op_type == "Add"
              and grandchildren[0].output[0] in input_name_to_nodes
          ):
              consumers = input_name_to_nodes[grandchildren[0].output[0]]
              for consumer in consumers:
                  if consumer.op_type == "Attention":
                      self.attention = consumer
                      return True
              # From here on, the pattern check applies to the consumers of the Add.
              children_types = sorted(consumer.op_type for consumer in consumers)

      if is_distil_bert:
          # Two Shape nodes might be merged by ORT.
          # SkipLayerNormalization might exist when model has been optimized by ORT first.
          accepted_patterns = (
              ["MatMul", "MatMul", "MatMul", "Shape", "SkipLayerNormalization"],
              ["Add", "MatMul", "MatMul", "MatMul", "Shape", "Shape"],
              ["Add", "MatMul", "MatMul", "MatMul", "Shape"],
          )
      else:
          accepted_patterns = (
              ["Add", "MatMul", "MatMul", "MatMul"],
              ["MatMul", "MatMul", "MatMul", "SkipLayerNormalization"],
          )

      if children_types not in accepted_patterns:
          logger.debug("No Attention like subgraph in children of LayerNormalization")
          return False
      return True
  def match_position_embedding_distilbert(self, position_embedding_gather, input_ids, output_name_to_node):
      r"""Match position embedding path from input_ids to Gather for DistilBert.

      Pattern is like the following:
               (input_ids)
                    |
                  Shape
                    |    \
                    |    Gather (indices=1)
                    |       |
                    |    Cast (optional)
                    |       |
                    |    Range (start=0, end=*, delta=1)
                    |       |
                    |    Unsqueeze
                    |    /
                  Expand
                    |
                  Gather

      Args:
          position_embedding_gather: Gather node whose input 1 is expected to come from Expand.
          input_ids (str): name that both Shape nodes must read from.
          output_name_to_node: map from output name to node, forwarded to match_parent_paths.

      Returns:
          bool: True only when the whole pattern above is matched.
      """
      model = self.model

      # Input 1 of the Gather must be Expand fed by Shape, either directly or through
      # Where/Reshape.
      expand_path = model.match_parent_path(position_embedding_gather, ["Expand", "Shape"], [1, 1])
      if expand_path is None:
          expand_path = model.match_parent_path(
              position_embedding_gather,
              ["Expand", "Where", "Reshape", "Shape"],
              [1, 1, 2, 0],
          )
      if expand_path is None:
          return False

      expand_node = expand_path[0]
      if expand_path[-1].input[0] != input_ids:
          return False

      _, range_path, _ = model.match_parent_paths(
          expand_node,
          [
              (["Unsqueeze", "Range", "Cast", "Gather", "Shape"], [0, 0, 1, 0, 0]),
              (["Unsqueeze", "Range", "Gather", "Shape"], [0, 0, 1, 0]),
          ],
          output_name_to_node,
      )
      if range_path is None:
          return False

      # Range must have start=0 (input 0) and delta=1 (input 2).
      range_node = range_path[1]
      if not self.utils.check_node_input_value(range_node, 0, 0):
          return False
      if not self.utils.check_node_input_value(range_node, 2, 1):
          return False

      # Gather must pick index 1 of the shape.
      if not self.utils.check_node_input_value(range_path[-2], 1, 1):
          return False

      return range_path[-1].input[0] == input_ids
  def match_position_embedding_roberta(self, position_embedding_gather, input_ids, output_name_to_node):
      """Match position embedding path from input_ids to Gather for Roberta (currently disabled).

      NOTE: the only executable statement of this method is the final ``return False``.
      The matching logic listed under "Disabled implementation" is part of this docstring
      and is never executed.

      Roberta Embedding Layer Pattern (* is optional since it might be removed by ORT, ? is the padding word id):
        (input_ids) --> Equal(B=?) -- Not -- Cast(to=6) -- CumSum(axis=1) -- Mul -- Cast(to=7) -- Add(B=1) -- Cast(to=7)* --> Gather
                                              |                              ^
                                              V                              |
                                              +------------------------------+

      Roberta new pattern from transformers v4.9:
        (input_ids) --> Equal(B=?) -- Not -- Cast(to=6) -- CumSum(axis=1) -- Add(B=0) -- Mul -- Cast(to=7) -- Add(B=1) --> Gather
                                              |                                           ^
                                              V                                           |
                                              +-------------------------------------------+

      Disabled implementation (kept for reference only):

          start_node = position_embedding_gather
          start_index = 1

          # match optional Cast node.
          parent = self.model.get_parent(start_node, start_index, output_name_to_node)
          if parent is None:
              return
          if parent.op_type == "Cast":
              if OnnxModel.get_node_attribute(parent, "to") != 7:
                  return
              start_node = parent
              start_index = 0

          i, path, return_indices = self.model.match_parent_paths(
              start_node,
              [ (['Add', 'Cast', 'Mul', 'CumSum', 'Cast', 'Not', 'Equal'], [start_index, 0, 0, 0, 0, 0, 0]),
                (['Add', 'Cast', 'Mul', 'Add', 'CumSum', 'Cast', 'Not', 'Equal'], [start_index, 0, 0, 0, 0, 0, 0, 0])],
              output_name_to_node)

          if path is not None:
              # constant input of Add shall be 1.
              i, value = self.model.get_constant_input(path[0])
              if value != 1:
                  return False

              _, self.padding_word_id = self.model.get_constant_input(path[-1])

              return input_ids == path[-1].input[0]

      Args:
          position_embedding_gather: unused while the implementation is disabled.
          input_ids: unused while the implementation is disabled.
          output_name_to_node: unused while the implementation is disabled.

      Returns:
          bool: always False.
      """
      return False
  def match_position_embedding_bert(self, position_embedding_gather, input_ids, output_name_to_node):
      r"""Match position embedding path from input_ids to Gather for BERT.

      BERT Embedding Layer Pattern:
                                  (input_ids)
                                 /         \
                                /          Shape
                               /              |
                              /              Gather (indices=1)
                             /                  |
                            /                  Add (optional, B=0)
                           /                    |
                      Gather (segment_ids) Unsqueeze (axes=0)
                         \        |           |
                          \     Gather      Slice (data[1,512], starts=0, ends=*, axes=1, steps=1)
                           \    /             |
                             Add          Gather
                                \       /
                                   Add
                                    |
                             LayerNormalization

      Args:
          position_embedding_gather: Gather node whose input 1 is expected to come from Slice.
          input_ids (str): name the Shape node at the top of the path must read from.
          output_name_to_node: map from output name to node.

      Returns:
          bool: True only when the Slice/Unsqueeze/(Add)/Gather/Shape chain is matched.
      """
      path = self.model.match_parent_path(
          position_embedding_gather,
          ["Slice", "Unsqueeze"],
          [1, 2],
          output_name_to_node,
      )
      if path is None:
          return False
      slice_node, unsqueeze_node = path

      # Slice data must be a constant of shape [1, N]; starts=[0], axes=[1], steps=[1] (if given).
      slice_weight = self.model.get_constant_value(slice_node.input[0])
      if slice_weight is None:
          return False
      if len(slice_weight.shape) != 2 or slice_weight.shape[0] != 1:
          return False
      if not self.utils.check_node_input_value(slice_node, 1, [0]):
          return False
      if not self.utils.check_node_input_value(slice_node, 3, [1]):
          return False
      if len(slice_node.input) != 4 and not self.utils.check_node_input_value(slice_node, 4, [1]):
          return False

      # Unsqueeze axes is read from an attribute before opset 13, and from input 1 otherwise.
      if self.model.get_opset_version() < 13:
          axes_is_zero = FusionUtils.check_node_attribute(unsqueeze_node, "axes", [0])
      else:
          axes_is_zero = self.utils.check_node_input_value(unsqueeze_node, 1, [0])
      if not axes_is_zero:
          return False

      parent = self.model.get_parent(unsqueeze_node, 0, output_name_to_node)
      if parent is None:
          return False

      # Skip over the optional Add, which is only accepted when its input 1 is 0.
      if parent.op_type == "Add":
          if not self.utils.check_node_input_value(parent, 1, 0):
              return False
          gather = self.model.get_parent(parent, 0, output_name_to_node)
      else:
          gather = parent

      if gather is None or gather.op_type != "Gather":
          return False
      if not self.utils.check_node_input_value(gather, 1, 1):
          return False

      shape = self.model.get_parent(gather, 0, output_name_to_node)
      if shape is None or shape.op_type != "Shape":
          return False

      return input_ids == shape.input[0]
  def match_position_embedding(self, position_embedding_gather, input_ids, output_name_to_node):
      """Return True when the position embedding subgraph matches a supported pattern.

      The BERT pattern is tried first, then the DistilBert pattern.
      """
      # TODO: Support roberta (position starts from 2 instead of 0) in EmbedLayerNormalization kernel
      # related: https://github.com/huggingface/transformers/issues/10736
      # match_position_embedding_roberta is intentionally not consulted here.
      matchers = (
          self.match_position_embedding_bert,
          self.match_position_embedding_distilbert,
      )
      return any(matcher(position_embedding_gather, input_ids, output_name_to_node) for matcher in matchers)
  def check_embedding(self, word_embedding_gather, segment_embedding_gather, position_embedding_gather):
      """Sanity check of embedding weights, and match hidden_size of weights and shape of inputs.

      Args:
          word_embedding_gather: Gather node for word embedding (input 0: table, input 1: ids).
          segment_embedding_gather: Gather node for segment embedding, or None.
          position_embedding_gather: Gather node for position embedding.

      Returns:
          bool: False when the shapes of the ids or the embedding tables rule out the fusion.
          Unexpected relative table sizes only produce warnings.
      """
      input_ids = word_embedding_gather.input[1]
      segment_ids = segment_embedding_gather.input[1] if segment_embedding_gather else None
      position_ids = position_embedding_gather.input[1]

      # Shape inference is attempted once and cached on the instance (the result may be None).
      if not self.shape_infer_done:
          self.shape_infer = self.model.infer_runtime_shape(update=True)
          self.shape_infer_done = True

      if self.shape_infer is not None:
          input_ids_shape = self.shape_infer.get_edge_shape(input_ids)
          position_ids_shape = self.shape_infer.get_edge_shape(position_ids)
          # Skip the fusion instead of asserting: an assert aborts the whole optimization when a
          # shape is unknown, and is stripped under "python -O", which would make len() fail below.
          if not (input_ids_shape and position_ids_shape):
              logger.info(
                  "Cannot fuse EmbedLayerNormalization: shape of input_ids or position_ids is unknown"
              )
              return False
          if not (
              len(input_ids_shape) == 2
              and len(position_ids_shape) == 2
              and input_ids_shape[1] == position_ids_shape[1]
          ):
              logger.info(
                  f"Cannot fuse EmbedLayerNormalization: input_ids and position_ids not matched in 2nd dimension: {input_ids_shape} vs {position_ids_shape}"
              )
              return False

          if segment_ids and not self.shape_infer.compare_shape(input_ids, segment_ids):
              logger.info(
                  f"Cannot fuse EmbedLayerNormalization: input_ids and segment_ids does not have same shape: {input_ids_shape} != {self.shape_infer.get_edge_shape(segment_ids)}"
              )
              return False

      # Every embedding table must be a 2-D constant sharing the word table's second dimension.
      word_embedding_table = self.model.get_constant_value(word_embedding_gather.input[0])
      if word_embedding_table is None or len(word_embedding_table.shape) != 2:
          logger.info("Cannot fuse EmbedLayerNormalization: word embedding table is not expected")
          return False

      position_embedding_table = self.model.get_constant_value(position_embedding_gather.input[0])
      if (
          position_embedding_table is None
          or len(position_embedding_table.shape) != 2
          or (word_embedding_table.shape[1] != position_embedding_table.shape[1])
      ):
          logger.info("Cannot fuse EmbedLayerNormalization: position embedding table is not expected")
          return False

      if segment_ids:
          segment_embedding_table = self.model.get_constant_value(segment_embedding_gather.input[0])
          if (
              segment_embedding_table is None
              or len(segment_embedding_table.shape) != 2
              or (word_embedding_table.shape[1] != segment_embedding_table.shape[1])
          ):
              logger.info("Cannot fuse EmbedLayerNormalization: segment embedding table is not expected")
              return False

      # In normal case, word embedding table is the largest, and segment embedding table is the smallest, while position embedding table is in between.
      # TODO: use other information (like initializer names) to identify different embedding weights automatically.
      if word_embedding_table.shape[0] <= position_embedding_table.shape[0]:
          logger.warning(
              f"word_embedding_table ({word_embedding_gather.input[0]}) size {word_embedding_table.shape[0]} <= position_embedding_table ({position_embedding_gather.input[0]}) size {position_embedding_table.shape[0]}"
          )

      if segment_ids:
          if word_embedding_table.shape[0] <= segment_embedding_table.shape[0]:
              logger.warning(
                  f"word_embedding_table ({word_embedding_gather.input[0]}) size {word_embedding_table.shape[0]} <= segment_embedding_table ({segment_embedding_gather.input[0]}) size {segment_embedding_table.shape[0]}"
              )

          if position_embedding_table.shape[0] <= segment_embedding_table.shape[0]:
              logger.warning(
                  f"position_embedding_table ({position_embedding_gather.input[0]}) size {position_embedding_table.shape[0]} <= segment_embedding_table ({segment_embedding_gather.input[0]}) size {segment_embedding_table.shape[0]}"
              )

      return True
  def cast_to_int32(self, input_name: str) -> tuple[str, None | NodeProto]:
      """Cast a graph input or node input to int32.

      A graph input that is already int32 is returned untouched; anything else (a graph
      input of another type, or a name that is not a graph input) goes through
      FusionUtils.cast_input_to_int32.

      Args:
          input_name (str): name of graph input or node input

      Returns:
          A tuple of casted input name and the cast node.
          int32_output (str): If input is int32, it is the input name, Otherwise it is output name of Cast node.
          input_cast_node (Union[None, NodeProto]): Cast node. It could be None if input is int32.
      """
      graph_input = self.model.find_graph_input(input_name)
      already_int32 = (
          graph_input is not None and graph_input.type.tensor_type.elem_type == TensorProto.INT32
      )
      if already_int32:
          return input_name, None

      int32_output, input_cast_node = self.utils.cast_input_to_int32(input_name)
      return int32_output, input_cast_node
  def create_fused_node(
      self,
      input_ids: str,
      layernorm: NodeProto,
      word_embedding_gather: NodeProto,
      position_embedding_gather: NodeProto,
      segment_embedding_gather: None | NodeProto,
      position_ids: str | None = None,
      embedding_sum_output=False,
      embedding_sum_name=None,
  ):
      """Create an EmbedLayerNormalization node. Note that segment embedding is optional.

      The new node is appended to self.nodes_to_add, registered in
      self.node_name_to_graph_name and stored in self.embed_node.

      Args:
          input_ids (str): input_ids for word embeddings
          layernorm (NodeProto): LayerNormalization or SkipLayerNormalization node.
          word_embedding_gather (NodeProto): the Gather node for word embedding
          position_embedding_gather (NodeProto): the Gather node for position embedding
          segment_embedding_gather (Union[None, NodeProto]): the Gather node for segment embedding, or None.
          position_ids (Optional[str]): when given, it is cast to int32 and appended as an
              extra input after an empty mask input.
          embedding_sum_output (bool): when true, a third output is added to the node.
          embedding_sum_name (Optional[str]): name of that third output; a name derived from
              the node name is used when None.

      Returns:
          NodeProto: the EmbedLayerNormalization node created.
      """
      input_ids, _ = self.cast_to_int32(input_ids)
      node_name = self.model.create_node_name("EmbedLayerNormalization")

      # gamma/beta are inputs 1 and 2 of LayerNormalization, 2 and 3 of SkipLayerNormalization.
      if layernorm.op_type == "LayerNormalization":
          gamma_index, beta_index = 1, 2
      else:
          gamma_index, beta_index = 2, 3
      gamma = layernorm.input[gamma_index]
      beta = layernorm.input[beta_index]

      # Without segment embedding, both segment slots stay empty strings.
      if segment_embedding_gather is None:
          segment_ids = ""
          segment_embedding_table = ""
      else:
          segment_ids, _ = self.cast_to_int32(segment_embedding_gather.input[1])
          segment_embedding_table = segment_embedding_gather.input[0]

      embed_node_inputs = [
          input_ids,
          segment_ids,
          word_embedding_gather.input[0],
          position_embedding_gather.input[0],
          segment_embedding_table,
          gamma,
          beta,
      ]

      if position_ids is not None:
          position_ids, _ = self.cast_to_int32(position_ids)
          # Adding an empty input for mask before position_ids
          embed_node_inputs.extend(["", position_ids])

      embed_node_outputs = [node_name + "_output", node_name + "_dummy_mask_index"]
      if embedding_sum_output:
          if embedding_sum_name is None:
              embedding_sum_name = node_name + "_embedding_sum"
          embed_node_outputs.append(embedding_sum_name)

      embed_node = helper.make_node(
          "EmbedLayerNormalization",
          embed_node_inputs,
          outputs=embed_node_outputs,
          name=node_name,
      )
      embed_node.domain = "com.microsoft"

      # Pass attribute "epsilon" from normalize node to EmbedLayerNormalization.
      epsilon_attributes = [att for att in layernorm.attribute if att.name == "epsilon"]
      embed_node.attribute.extend(epsilon_attributes)

      # Set default value to 1e-12 if no attribute is found.
      # OnnxRuntime 1.2.0 or older has no epsilon attribute. The optimized model can only work for 1.3.0 or later.
      if len(embed_node.attribute) == 0:
          embed_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])

      # Make sure new EmbedLayerNormalization node is the last one in self.nodes_to_add.
      self.node_name_to_graph_name[embed_node.name] = self.this_graph_name
      self.nodes_to_add.append(embed_node)

      self.embed_node = embed_node
      return embed_node
  def finish_fusion(self, layernorm, embed_node):
      """Redirect consumers of the layernorm output to the fused node and request pruning.

      Args:
          layernorm: the normalization node being replaced; its first output is rewired.
          embed_node: the fused node; its first output becomes the replacement.
      """
      # use prune graph to remove nodes that is not needed
      self.prune_graph = True

      replaced_name = layernorm.output[0]
      fused_name = embed_node.output[0]
      self.model.replace_input_of_all_nodes(replaced_name, fused_name)
  def is_skip_layer_norm_with_sum_output(self, node):
      """Return True when node is a SkipLayerNormalization with a non-empty fourth output name."""
      if node.op_type != "SkipLayerNormalization":
          return False
      outputs = node.output
      return len(outputs) > 3 and len(outputs[3]) > 0
  def fuse_gpt2(
      self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node, optional_segment_gather=None
  ):
      """Fuse a GPT-2 style embedding layer, where position_ids feeds a Gather directly.

      Args:
          layernorm (NodeProto): node of LayerNormalization or SkipLayerNormalization
          add_before_layernorm (NodeProto): the Add node before LayerNormalization, or the SkipLayerNormalization itself
          input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
          output_name_to_node: map from output name to node (not used by this method)
          optional_segment_gather: Gather node for segment embedding, or None

      Returns:
          bool: True when the fused node was created, False when the pattern does not match.
      """
      # graph checks
      # gpt2 has optional segment embedding, subgraph pattern is like
      #                     input_ids  position_ids
      #                        |          |
      #     token_ids        Gather     Gather
      #         |               \       /
      #   Gather (optional)        Add _ _ _ _ _
      #                  \          |           |
      #                  LayerNormalization     |
      #                         |               |
      #                     Attention           |
      #                         |               |
      #                       Matmul            |
      #                         |              /
      #                        Add            /
      #                           \          /
      #                               Add
      gathers = self.match_two_gather(add_before_layernorm)
      if gathers is None:
          return False
      word_embedding_gather, position_embedding_gather = gathers

      input_ids = word_embedding_gather.input[1]
      position_ids = position_embedding_gather.input[1]

      if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=False):
          return False

      if not self.check_embedding(word_embedding_gather, None, position_embedding_gather):
          return False

      # Locate the output that carries the embedding sum:
      # - for an Add before LayerNormalization it is output 0 of the Add;
      # - for SkipLayerNormalization it is the optional output at index 3
      #   (in that case add_before_layernorm and layernorm are the same node).
      if layernorm.op_type == "SkipLayerNormalization":
          node_with_sum_output = layernorm
          sum_output_index = 3
          need_embedding_sum_output = self.is_skip_layer_norm_with_sum_output(layernorm)
          sum_output = layernorm.output[3] if need_embedding_sum_output else None
          is_sum_graph_output = (sum_output is not None) and (self.model.find_graph_output(sum_output) is not None)
      else:  # layernorm.op_type == "LayerNormalization"
          node_with_sum_output = add_before_layernorm
          sum_output_index = 0 if add_before_layernorm.op_type == "Add" else 3
          sum_output = None
          if len(add_before_layernorm.output) > sum_output_index:
              sum_output = add_before_layernorm.output[sum_output_index]
          is_sum_graph_output = (sum_output is not None) and (self.model.find_graph_output(sum_output) is not None)
          is_sum_used_by_multiple_nodes = (
              sum_output and (sum_output in input_name_to_nodes) and len(input_name_to_nodes[sum_output]) > 1
          )
          # The sum must survive the fusion when something besides the layernorm reads it.
          need_embedding_sum_output = (sum_output is not None) and (
              add_before_layernorm.op_type != "Add" or is_sum_graph_output or is_sum_used_by_multiple_nodes
          )

      # make the fused node
      embed_node = self.create_fused_node(
          input_ids,
          layernorm,
          word_embedding_gather,
          position_embedding_gather,
          optional_segment_gather,
          position_ids,
          embedding_sum_output=need_embedding_sum_output,
          embedding_sum_name=sum_output if is_sum_graph_output else None,
      )

      if need_embedding_sum_output:
          # Rename the old producer's output so that it no longer provides the sum.
          node_with_sum_output.output[sum_output_index] = "_no_use__to_be_removed_"
          if not is_sum_graph_output:
              self.model.replace_input_of_all_nodes(sum_output, embed_node.output[2])

      self.finish_fusion(layernorm, embed_node)
      return True
  def fuse_distilbert(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
      """Fuse embedding layer for DistilBert

      Args:
          layernorm (NodeProto): node of LayerNormalization or SkipLayerNormalization
          add_before_layernorm (NodeProto): the Add node before LayerNormalization, or the SkipLayerNormalization itself
          input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
          output_name_to_node (Dict[str, List[NodeProto]]): map from output name to nodes

      Returns:
          bool: True when the fused node was created, False when the pattern does not match.
      """
      # DistilBert has no segment embedding, subgraph pattern is like
      #       input_ids
      #        |      \
      #        |     (position_embedding_subgraph)
      #        |        |
      #     Gather    Gather
      #          \   /
      #           Add
      #            |
      #    LayerNormalization
      gathers = self.match_two_gather(add_before_layernorm)
      if gathers is None:
          return False
      word_embedding_gather, position_embedding_gather = gathers
      input_ids = word_embedding_gather.input[1]

      # The three checks run in this order and stop at the first failure.
      pattern_matched = (
          self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=True)
          and self.match_position_embedding(position_embedding_gather, input_ids, output_name_to_node)
          and self.check_embedding(word_embedding_gather, None, position_embedding_gather)
      )
      if not pattern_matched:
          return False

      embed_node = self.create_fused_node(
          input_ids, layernorm, word_embedding_gather, position_embedding_gather, None
      )
      self.finish_fusion(layernorm, embed_node)
      return True
  def fuse_bert(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
      """Fuse embedding layer for Bert

      Args:
          layernorm (NodeProto): node of LayerNormalization or SkipLayerNormalization
          add_before_layernorm (NodeProto): the Add node before LayerNormalization, or the SkipLayerNormalization itself
          input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
          output_name_to_node (Dict[str, List[NodeProto]]): map from output name to nodes

      Returns:
          bool: True when the fused node was created, False when the pattern does not match.
      """
      # Input 0 must be an inner Add that sums two Gather nodes (word + one more embedding).
      inner_add_path = self.model.match_parent_path(add_before_layernorm, ["Add"], [0])
      if inner_add_path is None:
          return False

      gathers = self.match_two_gather(inner_add_path[0])
      if gathers is None:
          return False
      word_embedding_gather, segment_embedding_gather = gathers
      input_ids = word_embedding_gather.input[1]

      if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=False):
          return False

      # Input 1 must be the remaining Gather.
      outer_gather_path = self.model.match_parent_path(add_before_layernorm, ["Gather"], [1])
      if outer_gather_path is None:
          return False
      position_embedding_gather = outer_gather_path[0]

      if not self.match_position_embedding(position_embedding_gather, input_ids, output_name_to_node):
          if not self.match_position_embedding(segment_embedding_gather, input_ids, output_name_to_node):
              return False
          # position and segment are switched
          segment_embedding_gather, position_embedding_gather = (
              position_embedding_gather,
              segment_embedding_gather,
          )

      if not self.check_embedding(word_embedding_gather, segment_embedding_gather, position_embedding_gather):
          return False

      embed_node = self.create_fused_node(
          input_ids,
          layernorm,
          word_embedding_gather,
          position_embedding_gather,
          segment_embedding_gather,
      )
      self.finish_fusion(layernorm, embed_node)
      return True
  def fuse(self, node, input_name_to_nodes, output_name_to_node):
      """Try to fuse the embedding layer that ends at a (Skip)LayerNormalization node.

      The GPT-2, DistilBert and BERT patterns are tried in that order; the first one that
      succeeds wins.

      Args:
          node (NodeProto): node of LayerNormalization or SkipLayerNormalization
          input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
          output_name_to_node: map from output name to node
      """
      first_add_path = self.model.match_parent_path(node, ["Add"], [0])
      optional_segment_gather = None

      if node.op_type == "LayerNormalization":
          if first_add_path is None:
              return
          add_before_layernorm = first_add_path[0]
      else:  # SkipLayerNormalization
          gather_0_path = self.model.match_parent_path(node, ["Gather"], [0])
          gather_1_path = self.model.match_parent_path(node, ["Gather"], [1])
          has_gather_0 = gather_0_path is not None
          has_gather_1 = gather_1_path is not None

          if has_gather_1 and not has_gather_0:
              # Gather only on input 1: the Add is expected on input 0.
              if first_add_path is None:
                  return
              add_before_layernorm = first_add_path[0]
              optional_segment_gather = gather_1_path[0]
          elif has_gather_0 and not has_gather_1:
              # Gather only on input 0: the Add is expected on input 1.
              first_add_path = self.model.match_parent_path(node, ["Add"], [1])
              if first_add_path is None:
                  return
              add_before_layernorm = first_add_path[0]
              optional_segment_gather = gather_0_path[0]
          else:
              add_before_layernorm = node  # Add is fused into SkipLayerNormalization

      if self.fuse_gpt2(
          node, add_before_layernorm, input_name_to_nodes, output_name_to_node, optional_segment_gather
      ):
          return

      if self.fuse_distilbert(node, add_before_layernorm, input_name_to_nodes, output_name_to_node):
          return

      self.fuse_bert(node, add_before_layernorm, input_name_to_nodes, output_name_to_node)
  627. class FusionEmbedLayerNormalization(FusionEmbedLayerNoMask):
  def __init__(self, model: OnnxModel, use_mask_index=False):
      """Create the embedding fusion that can also wire a mask into the fused node.

      Args:
          model (OnnxModel): model forwarded to FusionEmbedLayerNoMask.
          use_mask_index (bool): when false, fuse() leaves the fused node without a mask.
      """
      super().__init__(model, description="with mask")
      self.use_mask_index = use_mask_index
  def replace_mask(self, mask_int32, attention_nodes):
      """Attach the mask to self.embed_node and point attention nodes at its mask index output.

      Args:
          mask_int32 (str): name of the mask to set as input 7 of the fused node.
          attention_nodes: Attention / MultiHeadAttention nodes whose mask input is rewired
              to output 1 of the fused node.
      """
      # Inputs of EmbedLayerNorm: input_ids, segment_ids (optional), word_embedding, position_embedding,
      #           segment_embedding (optional), gamma, beta, mask (optional), position_ids (optional)
      embed_node = self.embed_node
      input_count = len(embed_node.input)

      if input_count == 7:
          embed_node.input.append(mask_int32)
          logger.debug("append mask to %s", embed_node.name)
      elif input_count > 7 and not embed_node.input[7]:
          embed_node.input[7] = mask_int32
          logger.debug("replace mask in %s", embed_node.name)
      else:
          # The mask slot is already occupied: leave the node and the attention nodes alone.
          logger.debug("skip mask in %s", embed_node.name)
          return

      # Index of the mask input for each supported attention operator.
      mask_input_index = {"Attention": 3, "MultiHeadAttention": 4}
      for attention_node in attention_nodes:
          logger.debug("update mask_index in %s", attention_node.name)
          slot = mask_input_index.get(attention_node.op_type)
          if slot is not None:
              attention_node.input[slot] = embed_node.output[1]
  def fuse(self, node, input_name_to_nodes, output_name_to_node):
      """Fuse the embedding layer and, when possible, move the attention mask into it.

      After the base-class fusion succeeds, exactly one of the counters
      "EmbedLayerNormalization(no mask)" / "EmbedLayerNormalization(with mask)" is increased.

      Args:
          node (NodeProto): node of LayerNormalization or SkipLayerNormalization
          input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
          output_name_to_node: map from output name to node
      """
      # Reset attention and embed_node so that we know fusion is successful when they are not None.
      self.attention = None
      self.cross_attention = None
      self.embed_node = None
      super().fuse(node, input_name_to_nodes, output_name_to_node)

      if self.embed_node is None:
          return

      if not self.use_mask_index:
          logger.debug("--use_mask_index is not set: EmbedLayerNormalization will not have mask")
          self.increase_counter("EmbedLayerNormalization(no mask)")
          return

      if self.attention is None and self.cross_attention is None:
          logger.debug("EmbedLayerNormalization will not have mask since attention node is not found")
          self.increase_counter("EmbedLayerNormalization(no mask)")
          return

      # The mask is input 3 of Attention and input 4 of MultiHeadAttention.
      if self.attention:
          mask_owner, mask_input_index = self.attention, 3
      else:
          mask_owner, mask_input_index = self.cross_attention, 4

      # The mask input may be absent or an empty name; indexing it blindly would raise.
      if len(mask_owner.input) <= mask_input_index or not mask_owner.input[mask_input_index]:
          logger.debug("EmbedLayerNormalization will not have mask since %s has no mask input", mask_owner.name)
          self.increase_counter("EmbedLayerNormalization(no mask)")
          return
      mask_int32 = mask_owner.input[mask_input_index]

      children_nodes = input_name_to_nodes[mask_int32]
      attention_nodes = [child for child in children_nodes if child.op_type in ["Attention", "MultiHeadAttention"]]

      if self.model.find_graph_input(mask_int32):
          self.replace_mask(mask_int32, attention_nodes)
          self.increase_counter("EmbedLayerNormalization(with mask)")
          return

      if mask_int32 not in output_name_to_node:
          logger.debug("EmbedLayerNormalization will not have mask since %s is not a node output", mask_int32)
          self.increase_counter("EmbedLayerNormalization(no mask)")
          return

      mask_producer = output_name_to_node[mask_int32]
      if mask_producer.op_type not in ["ReduceSum", "Cast"]:
          # Previously this path returned without counting the fused node at all.
          logger.debug(
              "EmbedLayerNormalization will not have mask since %s is produced by %s",
              mask_int32,
              mask_producer.op_type,
          )
          self.increase_counter("EmbedLayerNormalization(no mask)")
          return

      if mask_producer.op_type == "ReduceSum":
          # Feed the input of ReduceSum to the fused node instead of its output.
          mask_int32 = mask_producer.input[0]
          # ReduceSum can be removed only when attention nodes are its sole consumers.
          if len(children_nodes) == len(attention_nodes):
              self.nodes_to_remove.append(mask_producer)

      self.replace_mask(mask_int32, attention_nodes)
      self.increase_counter("EmbedLayerNormalization(with mask)")