onnx_model_bert.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. from convert_to_packing_mode import PackingMode
  7. from fusion_attention import AttentionMask, FusionAttention
  8. from fusion_bart_attention import FusionBartAttention
  9. from fusion_biasgelu import FusionBiasGelu
  10. from fusion_constant_fold import FusionConstantFold
  11. from fusion_embedlayer import FusionEmbedLayerNormalization
  12. from fusion_fastgelu import FusionFastGelu
  13. from fusion_gelu import FusionGelu
  14. from fusion_gelu_approximation import FusionGeluApproximation
  15. from fusion_gemmfastgelu import FusionGemmFastGelu
  16. from fusion_layernorm import FusionLayerNormalization, FusionLayerNormalizationTF
  17. from fusion_options import AttentionMaskFormat, FusionOptions
  18. from fusion_qordered_attention import FusionQOrderedAttention
  19. from fusion_qordered_gelu import FusionQOrderedGelu
  20. from fusion_qordered_layernorm import FusionQOrderedLayerNormalization
  21. from fusion_qordered_matmul import FusionQOrderedMatMul
  22. from fusion_quickgelu import FusionQuickGelu
  23. from fusion_reshape import FusionReshape
  24. from fusion_rotary_attention import FusionRotaryEmbeddings
  25. from fusion_shape import FusionShape
  26. from fusion_simplified_layernorm import FusionSimplifiedLayerNormalization, FusionSkipSimplifiedLayerNormalization
  27. from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization
  28. from fusion_utils import FusionUtils
  29. from onnx import ModelProto, TensorProto, helper
  30. from onnx_model import OnnxModel
  31. logger = getLogger(__name__)
  32. class BertOnnxModel(OnnxModel):
  33. def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
  34. """Initialize BERT ONNX Model.
  35. Args:
  36. model (ModelProto): the ONNX model
  37. num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically).
  38. hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically).
  39. """
  40. assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)
  41. super().__init__(model)
  42. self.num_heads = num_heads
  43. self.hidden_size = hidden_size
  44. self.attention_mask = AttentionMask(self)
  45. self.attention_fusion = FusionAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
  46. self.qordered_attention_fusion = FusionQOrderedAttention(
  47. self, self.hidden_size, self.num_heads, self.attention_mask
  48. )
  49. self.utils = FusionUtils(self)
  50. def fuse_constant_fold(self):
  51. fusion = FusionConstantFold(self)
  52. fusion.apply()
  53. def fuse_attention(self):
  54. self.attention_fusion.apply()
  55. # Only relevant in models with Q-DQ nodes
  56. self.qordered_attention_fusion.apply()
  57. def fuse_gelu(self):
  58. fusion = FusionGelu(self)
  59. fusion.apply()
  60. fusion = FusionFastGelu(self)
  61. fusion.apply()
  62. fusion = FusionQuickGelu(self)
  63. fusion.apply()
  64. # Only relevant in models with Q-DQ nodes
  65. fusion = FusionQOrderedGelu(self)
  66. fusion.apply()
  67. def fuse_bias_gelu(self, is_fastgelu):
  68. fusion = FusionBiasGelu(self, is_fastgelu)
  69. fusion.apply()
  70. def gelu_approximation(self):
  71. fusion = FusionGeluApproximation(self)
  72. fusion.apply()
  73. def fuse_gemm_fast_gelu(self):
  74. fusion = FusionGemmFastGelu(self)
  75. fusion.apply()
  76. def fuse_add_bias_skip_layer_norm(self):
  77. fusion = FusionBiasSkipLayerNormalization(self)
  78. fusion.apply()
  79. def fuse_reshape(self):
  80. fusion = FusionReshape(self)
  81. fusion.apply()
  82. def fuse_shape(self):
  83. fusion = FusionShape(self)
  84. fusion.apply()
  85. def fuse_embed_layer(self, use_mask_index):
  86. fusion = FusionEmbedLayerNormalization(self, use_mask_index)
  87. fusion.apply()
  88. def fuse_layer_norm(self):
  89. fusion = FusionLayerNormalization(self)
  90. fusion.apply()
  91. fusion = FusionLayerNormalizationTF(self)
  92. fusion.apply()
  93. # Only relevant in models with Q-DQ nodes
  94. fusion = FusionQOrderedLayerNormalization(self)
  95. fusion.apply()
  96. def fuse_simplified_layer_norm(self):
  97. fusion = FusionSimplifiedLayerNormalization(self)
  98. fusion.apply()
  99. def fuse_skip_layer_norm(self, shape_infer=True):
  100. fusion = FusionSkipLayerNormalization(self, shape_infer=shape_infer)
  101. fusion.apply()
  102. def fuse_skip_simplified_layer_norm(self):
  103. fusion = FusionSkipSimplifiedLayerNormalization(self)
  104. fusion.apply()
  105. def fuse_rotary_embeddings(self):
  106. fusion = FusionRotaryEmbeddings(self)
  107. fusion.apply()
  108. # Remove non-MS domain functions
  109. rot_emb_nodes = list(
  110. filter(
  111. lambda node: node.op_type == "RotaryEmbedding" and node.domain != "com.microsoft",
  112. self.model.graph.node,
  113. )
  114. )
  115. non_ms_domains_to_keep = {node.domain for node in rot_emb_nodes}
  116. i = 0
  117. while i < len(self.model.functions):
  118. fn = self.model.functions[i]
  119. if "RotaryEmbedding" in fn.name and fn.domain not in non_ms_domains_to_keep:
  120. self.model.functions.remove(fn)
  121. else:
  122. i += 1
  123. # Only relevant in models with Q-DQ nodes
  124. def fuse_qordered_mamtul(self):
  125. fusion = FusionQOrderedMatMul(self)
  126. fusion.apply()
  127. def get_graph_inputs_from_node_type(self, op_type: str, input_indices: list[int], casted: bool):
  128. """
  129. Get graph inputs that feed into node type (like EmbedLayerNormalization or Attention).
  130. Returns a list of the graph input names based on the filter whether it is casted or not.
  131. """
  132. graph_inputs = []
  133. output_name_to_node = self.output_name_to_node()
  134. nodes = self.get_nodes_by_op_type(op_type)
  135. for node in nodes:
  136. bert_inputs = [node.input[i] for i in input_indices if i < len(node.input)]
  137. for bert_input in bert_inputs:
  138. if self.find_graph_input(bert_input):
  139. if not casted:
  140. graph_inputs.append(bert_input)
  141. elif bert_input in output_name_to_node:
  142. parent = output_name_to_node[bert_input]
  143. if parent.op_type == "Cast" and self.find_graph_input(parent.input[0]) is not None:
  144. if casted:
  145. graph_inputs.append(parent.input[0])
  146. return graph_inputs
  147. def get_graph_inputs_from_fused_nodes(self, casted: bool):
  148. inputs = self.get_graph_inputs_from_node_type("EmbedLayerNormalization", [0, 1, 7], casted)
  149. inputs += self.get_graph_inputs_from_node_type("Attention", [3], casted)
  150. return inputs
  151. def change_graph_inputs_to_int32(self):
  152. """Change data type of all graph inputs to int32 type, and add Cast node if needed."""
  153. graph = self.graph()
  154. add_cast_count = 0
  155. remove_cast_count = 0
  156. for graph_input in graph.input:
  157. new_node, removed_nodes = self.change_graph_input_type(graph_input, TensorProto.INT32)
  158. if new_node:
  159. add_cast_count += 1
  160. remove_cast_count += len(removed_nodes)
  161. logger.info(
  162. f"Graph inputs are changed to int32. Added {add_cast_count} Cast nodes, and removed {remove_cast_count} Cast nodes."
  163. )
  164. def use_dynamic_axes(self, dynamic_batch_dim="batch_size", dynamic_seq_len="max_seq_len"):
  165. """
  166. Update input and output shape to use dynamic axes.
  167. """
  168. bert_graph_inputs = self.get_graph_inputs_from_fused_nodes(
  169. casted=True
  170. ) + self.get_graph_inputs_from_fused_nodes(casted=False)
  171. for input in self.model.graph.input:
  172. if input.name in bert_graph_inputs:
  173. dim_proto = input.type.tensor_type.shape.dim[0]
  174. dim_proto.dim_param = dynamic_batch_dim
  175. if dynamic_seq_len is not None:
  176. dim_proto = input.type.tensor_type.shape.dim[1]
  177. dim_proto.dim_param = dynamic_seq_len
  178. for output in self.model.graph.output:
  179. dim_proto = output.type.tensor_type.shape.dim[0]
  180. dim_proto.dim_param = dynamic_batch_dim
  181. def preprocess(self):
  182. self.adjust_reshape_and_expand()
  183. return
  184. def adjust_reshape_and_expand(self):
  185. nodes_to_remove = []
  186. for node in self.nodes():
  187. if node.op_type == "Reshape":
  188. # Clean up unnecessary reshape nodes.
  189. # Find reshape nodes with no actually data in "shape" attribute and remove.
  190. reshape_shape = self.get_constant_value(node.input[1])
  191. if reshape_shape is not None and reshape_shape.size == 0:
  192. nodes_to_remove.extend([node])
  193. self.replace_input_of_all_nodes(node.output[0], node.input[0])
  194. continue
  195. # Find path "Slice" -> "Reshape" -> "Expand" -> "Expand" -> current "Reshape", simplify the graph by
  196. # changing current reshape's input to output of slice.
  197. reshape_path = self.match_parent_path(
  198. node,
  199. ["Expand", "Expand", "Reshape", "Slice"],
  200. [0, 0, 0, 0],
  201. self.output_name_to_node(),
  202. )
  203. if reshape_path is not None:
  204. expand_node = reshape_path[-3]
  205. expand_shape_value = self.get_constant_value(expand_node.input[1])
  206. reshape_before_expand = reshape_path[-2]
  207. shape_value = self.get_constant_value(reshape_before_expand.input[1])
  208. slice_node = reshape_path[-1]
  209. if (
  210. expand_shape_value is not None
  211. and shape_value is not None
  212. and len(expand_shape_value) == 2
  213. and len(shape_value) == 1
  214. and expand_shape_value[1] == shape_value[0]
  215. ):
  216. node.input[0] = slice_node.output[0]
  217. if nodes_to_remove:
  218. self.remove_nodes(nodes_to_remove)
  219. logger.info(f"Removed Reshape and Expand count: {len(nodes_to_remove)}")
  220. def clean_graph(self):
  221. output_name_to_node = self.output_name_to_node()
  222. nodes_to_remove = []
  223. for node in self.nodes():
  224. # Before:
  225. # input_ids --> Shape --> Gather(indices=0) --> Unsqueeze ------+
  226. # | |
  227. # | v
  228. # +----> Shape --> Gather(indices=1) --> Unsqueeze---> Concat --> ConstantOfShape -->Cast --> EmbedLayerNormaliation/ReduceSum
  229. # After:
  230. # input_ids --> Shape --> ConstantOfShape -->Cast --> EmbedLayerNormaliation/ReduceSum
  231. # TODO: merge ConstantOfShape -->Cast to ConstantOfShape (need update the data type of value)
  232. op_input_id = {"EmbedLayerNormalization": 1, "ReduceSum": 0, "Attention": 3}
  233. if node.op_type in op_input_id:
  234. i = op_input_id[node.op_type]
  235. parent_nodes = self.match_parent_path(
  236. node,
  237. [
  238. "Cast",
  239. "ConstantOfShape",
  240. "Concat",
  241. "Unsqueeze",
  242. "Gather",
  243. "Shape",
  244. ],
  245. [i, 0, 0, 0, 0, 0],
  246. output_name_to_node,
  247. )
  248. if parent_nodes is not None:
  249. (
  250. cast,
  251. constantOfShape, # noqa: N806
  252. concat,
  253. unsqueeze,
  254. gather,
  255. shape,
  256. ) = parent_nodes
  257. if shape.input[0] == self.graph().input[0].name:
  258. constantOfShape.input[0] = shape.output[0]
  259. output_name_to_node = self.output_name_to_node()
  260. if node.op_type == "Attention":
  261. # Before:
  262. # input_ids --> Shape -->ConstantOfShape -->Cast --> ReduceSum --> Attention
  263. # After:
  264. # remove this path, and remove the optional mask_index input of Attention node.
  265. parent_nodes = self.match_parent_path(
  266. node,
  267. ["ReduceSum", "Cast", "ConstantOfShape", "Shape"],
  268. [3, 0, 0, 0],
  269. output_name_to_node,
  270. )
  271. if parent_nodes is not None:
  272. if parent_nodes[-1].input[0] == self.graph().input[0].name:
  273. attention_node = helper.make_node(
  274. "Attention",
  275. inputs=node.input[0 : len(node.input) - 1],
  276. outputs=node.output,
  277. name=node.name + "_remove_mask",
  278. )
  279. attention_node.domain = "com.microsoft"
  280. attention_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)])
  281. self.add_node(attention_node, self.get_graph_by_node(attention_node).name)
  282. nodes_to_remove.append(node)
  283. self.remove_nodes(nodes_to_remove)
  284. def postprocess(self):
  285. self.clean_graph()
  286. self.prune_graph()
  287. def optimize(self, options: FusionOptions | None = None, add_dynamic_axes: bool = False):
  288. if (options is not None) and not options.enable_shape_inference:
  289. self.disable_shape_inference()
  290. self.utils.remove_identity_nodes()
  291. # Remove cast nodes that having same data type of input and output based on symbolic shape inference.
  292. self.utils.remove_useless_cast_nodes()
  293. # Apply any missed constant-folding model optimizations (e.g. for Dynamo-exported models)
  294. self.fuse_constant_fold()
  295. if (options is None) or options.enable_layer_norm:
  296. self.fuse_layer_norm()
  297. self.fuse_simplified_layer_norm()
  298. if (options is None) or options.enable_gelu:
  299. self.fuse_gelu()
  300. self.preprocess()
  301. self.fuse_reshape()
  302. if (options is None) or options.enable_skip_layer_norm:
  303. self.fuse_skip_layer_norm(options.enable_shape_inference)
  304. self.fuse_skip_simplified_layer_norm()
  305. if (options is None) or options.enable_rotary_embeddings:
  306. self.fuse_rotary_embeddings()
  307. if options is not None:
  308. self.attention_mask.set_mask_format(options.attention_mask_format)
  309. if options.use_multi_head_attention and not isinstance(self.attention_fusion, FusionBartAttention):
  310. self.attention_fusion = FusionAttention(
  311. self,
  312. self.hidden_size,
  313. self.num_heads,
  314. self.attention_mask,
  315. options.use_multi_head_attention,
  316. )
  317. if (options is None) or options.enable_attention:
  318. self.fuse_attention()
  319. # Perform the MatMul fusion after the Attention fusion as we do not
  320. # want to fuse the MatMuls inside the Attention subgraphs
  321. if (options is None) or options.enable_qordered_matmul:
  322. self.fuse_qordered_mamtul()
  323. self.fuse_shape()
  324. if (options is None) or options.enable_embed_layer_norm:
  325. use_mask_index = options.attention_mask_format == AttentionMaskFormat.MaskIndexEnd
  326. self.fuse_embed_layer(use_mask_index)
  327. # Remove reshape nodes that having same shape of input and output based on symbolic shape inference.
  328. self.utils.remove_useless_reshape_nodes()
  329. self.postprocess()
  330. # Bias fusion is done after postprocess to avoid extra Reshape between bias and Gelu/FastGelu/SkipLayerNormalization
  331. if (options is None) or options.enable_bias_gelu:
  332. # Fuse Gelu and Add Bias before it.
  333. self.fuse_bias_gelu(is_fastgelu=True)
  334. self.fuse_bias_gelu(is_fastgelu=False)
  335. if (options is None) or options.enable_bias_skip_layer_norm:
  336. # Fuse SkipLayerNormalization and Add Bias before it.
  337. self.fuse_add_bias_skip_layer_norm()
  338. if options is not None and options.enable_gelu_approximation:
  339. self.gelu_approximation()
  340. if options is not None and options.enable_gemm_fast_gelu:
  341. self.fuse_gemm_fast_gelu()
  342. self.remove_unused_constant()
  343. # Use symbolic batch dimension in input and output.
  344. if add_dynamic_axes:
  345. self.use_dynamic_axes()
  346. logger.info(f"opset version: {self.get_opset_version()}")
  347. def get_fused_operator_statistics(self):
  348. """
  349. Returns node count of fused operators.
  350. """
  351. op_count = {}
  352. ops = [
  353. "EmbedLayerNormalization",
  354. "Attention",
  355. "MultiHeadAttention",
  356. "Gelu",
  357. "FastGelu",
  358. "BiasGelu",
  359. "GemmFastGelu",
  360. "LayerNormalization",
  361. "SimplifiedLayerNormalization",
  362. "SkipLayerNormalization",
  363. "SkipSimplifiedLayerNormalization",
  364. "RotaryEmbedding",
  365. ]
  366. q_ops = [
  367. "QOrderedAttention",
  368. "QOrderedGelu",
  369. "QOrderedLayerNormalization",
  370. "QOrderedMatMul",
  371. ]
  372. for op in ops + q_ops:
  373. nodes = self.get_nodes_by_op_type(op)
  374. op_count[op] = len(nodes)
  375. logger.info(f"Optimized operators: {op_count}")
  376. return op_count
  377. def is_fully_optimized(self, fused_op_count=None):
  378. """
  379. Returns True when the model is fully optimized.
  380. """
  381. if fused_op_count is None:
  382. fused_op_count = self.get_fused_operator_statistics()
  383. def op_count(op_name: str):
  384. return fused_op_count.get(op_name) or 0
  385. embed = op_count("EmbedLayerNormalization")
  386. attention = op_count("Attention") + op_count("MultiHeadAttention") + op_count("QOrderedAttention")
  387. gelu = op_count("Gelu") + op_count("BiasGelu") + op_count("FastGelu")
  388. layer_norm = op_count("LayerNormalization") + op_count("SkipLayerNormalization")
  389. simple_layer_norm = op_count("SimplifiedLayerNormalization") + op_count("SkipSimplifiedLayerNormalization")
  390. is_perfect = (
  391. (embed > 0)
  392. and (attention > 0)
  393. and (attention == gelu)
  394. and ((layer_norm >= 2 * attention) or (simple_layer_norm >= 2 * attention))
  395. )
  396. if layer_norm == 0:
  397. logger.debug("Layer Normalization not fused")
  398. if simple_layer_norm == 0:
  399. logger.debug("Simple Layer Normalization not fused")
  400. if gelu == 0:
  401. logger.debug("Gelu (or FastGelu) not fused")
  402. if embed == 0:
  403. logger.debug("EmbedLayerNormalization not fused")
  404. if attention == 0:
  405. logger.warning("Attention (or MultiHeadAttention) not fused")
  406. return is_perfect
  407. def convert_to_packing_mode(self, use_symbolic_shape_infer: bool = False):
  408. packing_mode = PackingMode(self)
  409. packing_mode.convert(use_symbolic_shape_infer)