fusion_gelu.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  # -------------------------------------------------------------------------
  # Copyright (c) Microsoft Corporation. All rights reserved.
  # Licensed under the MIT License. See License.txt in the project root for
  # license information.
  # --------------------------------------------------------------------------
  6. from __future__ import annotations
  7. import onnx
  8. from ..onnx_model import ONNXModel
  9. from .fusion import Fusion
  class FusionGelu(Fusion):
      """
      Fusion that replaces an Erf-based Gelu subgraph with one ``Gelu`` node in the
      ``com.microsoft`` domain.

      Gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))). Three graph layouts of that formula are
      recognised; see fuse_1, fuse_2 and fuse_3.
      """

      def __init__(self, model: ONNXModel):
          # NOTE(review): "Gelu" / "Erf" are presumably the fused op type and the op type
          # used as the search anchor -- confirm against Fusion.__init__.
          super().__init__(model, "Gelu", "Erf")

      @staticmethod
      def _sole_consumer_of_type(
          tensor_name: str,
          op_type: str,
          input_name_to_nodes: dict[str, list[onnx.NodeProto]],
      ) -> onnx.NodeProto | None:
          """
          Return the node consuming `tensor_name` when there is exactly one consumer and its
          op_type equals `op_type`; otherwise return None.

          None is also returned when `tensor_name` has no entry in `input_name_to_nodes`, so
          callers never hit a KeyError for a tensor that nothing consumes.
          """
          if tensor_name not in input_name_to_nodes:
              return None

          consumers = input_name_to_nodes[tensor_name]
          if len(consumers) != 1 or consumers[0].op_type != op_type:
              return None

          return consumers[0]

      def _replace_subgraph_with_gelu(
          self,
          subgraph_nodes: list[onnx.NodeProto],
          subgraph_input: str,
          subgraph_output: str,
          input_name_to_nodes: dict[str, list[onnx.NodeProto]],
          output_name_to_node: dict[str, onnx.NodeProto],
      ) -> bool:
          """
          Schedule `subgraph_nodes` for removal and a com.microsoft Gelu node for insertion.

          Returns False, changing nothing, when is_safe_to_fuse_nodes rejects the subgraph;
          returns True after the removal/insertion has been queued.
          """
          if not self.is_safe_to_fuse_nodes(
              subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node
          ):
              return False

          self.nodes_to_remove.extend(subgraph_nodes)

          fused_node = onnx.helper.make_node(
              "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[subgraph_output]
          )
          fused_node.domain = "com.microsoft"
          self.nodes_to_add.append(fused_node)
          return True

      def fuse(
          self,
          erf_node: onnx.NodeProto,
          input_name_to_nodes: dict[str, list[onnx.NodeProto]],
          output_name_to_node: dict[str, onnx.NodeProto],
      ):
          """
          Interface function that tries to fuse a node sequence containing an Erf node into a single
          Gelu node.

          The patterns are tried in order (fuse_1, fuse_2, fuse_3) and evaluation stops at the
          first one that succeeds. On success the "com.microsoft" opset import (version 1) is
          set on the model, because the fused Gelu node lives in that domain.
          """
          if (
              self.fuse_1(erf_node, input_name_to_nodes, output_name_to_node)
              or self.fuse_2(erf_node, input_name_to_nodes, output_name_to_node)
              or self.fuse_3(erf_node, input_name_to_nodes, output_name_to_node)
          ):
              self.model.set_opset_import("com.microsoft", 1)

      def fuse_1(
          self,
          erf_node: onnx.NodeProto,
          input_name_to_nodes: dict[str, list[onnx.NodeProto]],
          output_name_to_node: dict[str, onnx.NodeProto],
      ) -> bool:
          """
          This pattern is from PyTorch model
          Fuse Gelu with Erf into one node:
          Pattern 1:
                     +-------Mul(0.5)---------------------+
                     |                                    |
                     |                                    v
                  [root] --> Div -----> Erf  --> Add --> Mul -->
                            (B=1.4142...)       (1)

          Pattern 2:
                     +------------------------------------+
                     |                                    |
                     |                                    v
                  [root] --> Div -----> Erf  --> Add --> Mul -->Mul -->
                            (B=1.4142...)       (1)            (0.5)

          Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.

          Returns True when the subgraph was matched and queued for fusion, False otherwise.
          """
          add_after_erf = self._sole_consumer_of_type(erf_node.output[0], "Add", input_name_to_nodes)
          if add_after_erf is None or not self.has_constant_input(add_after_erf, 1):
              return False

          mul_after_erf = self._sole_consumer_of_type(add_after_erf.output[0], "Mul", input_name_to_nodes)
          if mul_after_erf is None:
              return False

          div = self.match_parent(erf_node, "Div", 0, output_name_to_node)
          if div is None:
              return False

          # The divisor (second input of Div) must be the constant sqrt(2) ~= 1.4142.
          if self.find_constant_input(div, 1.4142, delta=0.001) != 1:
              return False

          subgraph_input = div.input[0]

          # Index of the Mul input that is NOT fed by the Add node.
          another = 1 if mul_after_erf.input[0] == add_after_erf.output[0] else 0

          if subgraph_input == mul_after_erf.input[another]:  # pattern 2
              # The 0.5 scaling is applied by a Mul that follows mul_after_erf. The helper
              # returns None (instead of raising KeyError) when mul_after_erf's output has
              # no consumer, e.g. when it is a graph output.
              mul_half = self._sole_consumer_of_type(mul_after_erf.output[0], "Mul", input_name_to_nodes)
              if mul_half is None or not self.has_constant_input(mul_half, 0.5):
                  return False
              subgraph_output = mul_half.output[0]
          else:  # pattern 1
              # The 0.5 scaling is applied to the root before it reaches mul_after_erf.
              mul_half = self.match_parent(mul_after_erf, "Mul", another, output_name_to_node)
              if mul_half is None:
                  return False
              if not self.has_constant_input(mul_half, 0.5):
                  return False
              if subgraph_input not in mul_half.input:
                  return False
              subgraph_output = mul_after_erf.output[0]

          return self._replace_subgraph_with_gelu(
              [div, erf_node, add_after_erf, mul_after_erf, mul_half],
              subgraph_input,
              subgraph_output,
              input_name_to_nodes,
              output_name_to_node,
          )

      def fuse_2(
          self,
          erf_node: onnx.NodeProto,
          input_name_to_nodes: dict[str, list[onnx.NodeProto]],
          output_name_to_node: dict[str, onnx.NodeProto],
      ) -> bool:
          """
          This pattern is from Keras model
          Fuse Gelu with Erf into one node:
                     +------------------------------------------+
                     |                                          |
                     |                                          v
                  [root] --> Div -----> Erf  --> Add --> Mul -->Mul
                            (B=1.4142...)       (A=1)   (A=0.5)

          Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.

          The divisor may also be produced by a Sqrt node that has a constant input of 2.0
          instead of being the literal constant 1.4142.

          Returns True when the subgraph was matched and queued for fusion, False otherwise.
          """
          add_after_erf = self._sole_consumer_of_type(erf_node.output[0], "Add", input_name_to_nodes)
          if add_after_erf is None or not self.has_constant_input(add_after_erf, 1):
              return False

          mul_after_erf = self._sole_consumer_of_type(add_after_erf.output[0], "Mul", input_name_to_nodes)
          if mul_after_erf is None or not self.has_constant_input(mul_after_erf, 0.5):
              return False

          mul = self._sole_consumer_of_type(mul_after_erf.output[0], "Mul", input_name_to_nodes)
          if mul is None:
              return False

          div = self.match_parent(erf_node, "Div", 0, output_name_to_node)
          if div is None:
              return False

          sqrt_node = None
          if self.find_constant_input(div, 1.4142, delta=0.001) != 1:
              # No literal sqrt(2) as divisor: accept Sqrt(2.0) feeding the divisor instead.
              sqrt_node = self.match_parent(div, "Sqrt", 1, output_name_to_node)
              if sqrt_node is None:
                  return False
              if not self.has_constant_input(sqrt_node, 2.0):
                  return False

          subgraph_input = div.input[0]

          # The final Mul must multiply by the same tensor that enters the Div.
          if subgraph_input not in mul.input:
              return False

          subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul]
          if sqrt_node is not None:
              subgraph_nodes.append(sqrt_node)

          return self._replace_subgraph_with_gelu(
              subgraph_nodes,
              subgraph_input,
              mul.output[0],
              input_name_to_nodes,
              output_name_to_node,
          )

      def fuse_3(
          self,
          erf_node: onnx.NodeProto,
          input_name_to_nodes: dict[str, list[onnx.NodeProto]],
          output_name_to_node: dict[str, onnx.NodeProto],
      ) -> bool:
          """
          This pattern is from TensorFlow model
          Fuse Gelu with Erf into one node:
                     +----------------------------------------------+
                     |                                              |
                     |                                              v
                  [root] --> Mul -----> Erf    -->   Add --> Mul -->Mul
                            (A=0.7071067690849304)  (B=1)  (B=0.5)

          Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.

          Returns True when the subgraph was matched and queued for fusion, False otherwise.
          """
          add_after_erf = self._sole_consumer_of_type(erf_node.output[0], "Add", input_name_to_nodes)
          if add_after_erf is None or not self.has_constant_input(add_after_erf, 1):
              return False

          mul_half = self._sole_consumer_of_type(add_after_erf.output[0], "Mul", input_name_to_nodes)
          if mul_half is None or not self.has_constant_input(mul_half, 0.5):
              return False

          first_mul = self.match_parent(erf_node, "Mul", 0, output_name_to_node)
          if first_mul is None:
              return False

          # 0.7071... ~= 1/sqrt(2); multiplying by it replaces the Div of the other patterns.
          # A negative index is treated as "constant not found".
          i = self.find_constant_input(first_mul, 0.7071067690849304, delta=0.001)
          if i < 0:
              return False

          # The root tensor is whichever Mul input is not the constant.
          root_input_index = 1 - i
          subgraph_input = first_mul.input[root_input_index]

          last_mul = self._sole_consumer_of_type(mul_half.output[0], "Mul", input_name_to_nodes)
          if last_mul is None:
              return False

          # The final Mul must multiply by the root tensor (on either input).
          if not (last_mul.input[0] == subgraph_input or last_mul.input[1] == subgraph_input):
              return False

          return self._replace_subgraph_with_gelu(
              [first_mul, erf_node, add_after_erf, mul_half, last_mul],
              subgraph_input,
              last_mul.output[0],
              input_name_to_nodes,
              output_name_to_node,
          )