# fusion_gelu.py
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
  5. from logging import getLogger
  6. from fusion_base import Fusion
  7. from onnx import helper
  8. from onnx_model import OnnxModel
# Module-level logger named after this module (not referenced by the code visible in this file).
logger = getLogger(__name__)
  10. class FusionGelu(Fusion):
  11. def __init__(self, model: OnnxModel):
  12. super().__init__(model, "Gelu", "Erf")
  13. def fuse(self, erf_node, input_name_to_nodes: dict, output_name_to_node: dict):
  14. if self.fuse_1(erf_node, input_name_to_nodes, output_name_to_node):
  15. return
  16. if self.fuse_2(erf_node, input_name_to_nodes, output_name_to_node):
  17. return
  18. self.fuse_3(erf_node, input_name_to_nodes, output_name_to_node)
  19. def fuse_1(self, erf_node, input_name_to_nodes: dict, output_name_to_node: dict) -> bool | None:
  20. """
  21. This pattern is from PyTorch model
  22. Fuse Gelu with Erf into one node:
  23. Pattern 1:
  24. +-------Mul(0.5)---------------------+
  25. | |
  26. | v
  27. [root] --> Div -----> Erf --> Add --> Mul -->
  28. (B=1.4142...) (1)
  29. Pattern 2:
  30. +------------------------------------+
  31. | |
  32. | v
  33. [root] --> Div -----> Erf --> Add --> Mul -->Mul -->
  34. (B=1.4142...) (1) (0.5)
  35. Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
  36. """
  37. if erf_node.output[0] not in input_name_to_nodes:
  38. return
  39. children = input_name_to_nodes[erf_node.output[0]]
  40. if len(children) != 1 or children[0].op_type != "Add":
  41. return
  42. add_after_erf = children[0]
  43. if not self.model.has_constant_input(add_after_erf, 1):
  44. return
  45. if add_after_erf.output[0] not in input_name_to_nodes:
  46. return
  47. children = input_name_to_nodes[add_after_erf.output[0]]
  48. if len(children) != 1 or children[0].op_type != "Mul":
  49. return
  50. mul_after_erf = children[0]
  51. div = self.model.match_parent(erf_node, "Div", 0, output_name_to_node)
  52. if div is None:
  53. return
  54. if self.model.find_constant_input(div, 1.4142, delta=0.001) != 1:
  55. return
  56. subgraph_input = div.input[0]
  57. another = 1 if mul_after_erf.input[0] == add_after_erf.output[0] else 0
  58. if subgraph_input == mul_after_erf.input[another]: # pattern 2
  59. children = input_name_to_nodes[mul_after_erf.output[0]]
  60. if len(children) != 1 or children[0].op_type != "Mul":
  61. return
  62. mul_half = children[0]
  63. if not self.model.has_constant_input(mul_half, 0.5):
  64. return
  65. subgraph_output = mul_half.output[0]
  66. else: # pattern 1
  67. mul_half = self.model.match_parent(mul_after_erf, "Mul", another, output_name_to_node)
  68. if mul_half is None:
  69. return
  70. if not self.model.has_constant_input(mul_half, 0.5):
  71. return
  72. if subgraph_input not in mul_half.input:
  73. return
  74. subgraph_output = mul_after_erf.output[0]
  75. subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul_half]
  76. if not self.model.is_safe_to_fuse_nodes(
  77. subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node
  78. ):
  79. return
  80. self.nodes_to_remove.extend(subgraph_nodes)
  81. fused_node = helper.make_node(
  82. "Gelu", inputs=[subgraph_input], outputs=[subgraph_output], name=self.model.create_node_name("Gelu")
  83. )
  84. fused_node.domain = "com.microsoft"
  85. self.nodes_to_add.append(fused_node)
  86. self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
  87. self.increase_counter("Gelu")
  88. return True
  89. def fuse_2(self, erf_node, input_name_to_nodes: dict, output_name_to_node: dict) -> bool | None:
  90. """
  91. This pattern is from Keras model
  92. Fuse Gelu with Erf into one node:
  93. +------------------------------------------+
  94. | |
  95. | v
  96. [root] --> Div -----> Erf --> Add --> Mul -->Mul
  97. (B=1.4142...) (A=1) (A=0.5)
  98. Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
  99. """
  100. if erf_node.output[0] not in input_name_to_nodes:
  101. return
  102. children = input_name_to_nodes[erf_node.output[0]]
  103. if len(children) != 1 or children[0].op_type != "Add":
  104. return
  105. add_after_erf = children[0]
  106. if not self.model.has_constant_input(add_after_erf, 1):
  107. return
  108. if add_after_erf.output[0] not in input_name_to_nodes:
  109. return
  110. children = input_name_to_nodes[add_after_erf.output[0]]
  111. if len(children) != 1 or children[0].op_type != "Mul":
  112. return
  113. mul_after_erf = children[0]
  114. if not self.model.has_constant_input(mul_after_erf, 0.5):
  115. return
  116. if mul_after_erf.output[0] not in input_name_to_nodes:
  117. return
  118. children = input_name_to_nodes[mul_after_erf.output[0]]
  119. if len(children) != 1 or children[0].op_type != "Mul":
  120. return
  121. mul = children[0]
  122. div = self.model.match_parent(erf_node, "Div", 0, output_name_to_node)
  123. if div is None:
  124. return
  125. sqrt_node = None
  126. if self.model.find_constant_input(div, 1.4142, delta=0.001) != 1:
  127. sqrt_node = self.model.match_parent(div, "Sqrt", 1, output_name_to_node)
  128. if sqrt_node is None:
  129. return
  130. if not self.model.has_constant_input(sqrt_node, 2.0):
  131. return
  132. root_node = self.model.get_parent(div, 0, output_name_to_node)
  133. if root_node is None:
  134. return
  135. if root_node.output[0] not in mul.input:
  136. return
  137. subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul]
  138. if sqrt_node:
  139. subgraph_nodes.append(sqrt_node)
  140. if not self.model.is_safe_to_fuse_nodes(
  141. subgraph_nodes, [mul.output[0]], input_name_to_nodes, output_name_to_node
  142. ):
  143. return
  144. self.nodes_to_remove.extend(subgraph_nodes)
  145. fused_node = helper.make_node(
  146. "Gelu", inputs=[root_node.output[0]], outputs=[mul.output[0]], name=self.model.create_node_name("Gelu")
  147. )
  148. fused_node.domain = "com.microsoft"
  149. self.nodes_to_add.append(fused_node)
  150. self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
  151. self.increase_counter("Gelu")
  152. return True
  153. def fuse_3(self, erf_node, input_name_to_nodes: dict, output_name_to_node: dict) -> bool | None:
  154. """
  155. This pattern is from TensorFlow model
  156. Fuse Gelu with Erf into one node:
  157. +----------------------------------------------+
  158. | |
  159. | v
  160. [root] --> Mul -----> Erf --> Add --> Mul -->Mul
  161. (A=0.7071067690849304) (B=1) (B=0.5)
  162. Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
  163. """
  164. if erf_node.output[0] not in input_name_to_nodes:
  165. return
  166. children = input_name_to_nodes[erf_node.output[0]]
  167. if len(children) != 1 or children[0].op_type != "Add":
  168. return
  169. add_after_erf = children[0]
  170. if not self.model.has_constant_input(add_after_erf, 1):
  171. return
  172. if add_after_erf.output[0] not in input_name_to_nodes:
  173. return
  174. children = input_name_to_nodes[add_after_erf.output[0]]
  175. if len(children) != 1 or children[0].op_type != "Mul":
  176. return
  177. mul_half = children[0]
  178. if not self.model.has_constant_input(mul_half, 0.5):
  179. return
  180. first_mul = self.model.match_parent(erf_node, "Mul", 0, output_name_to_node)
  181. if first_mul is None:
  182. return
  183. i = self.model.find_constant_input(first_mul, 0.7071067690849304, delta=0.001)
  184. if i < 0:
  185. return
  186. root_node = self.model.get_parent(first_mul, 0 if i == 1 else 1, output_name_to_node)
  187. if root_node is None:
  188. return
  189. if mul_half.output[0] not in input_name_to_nodes:
  190. return
  191. children = input_name_to_nodes[mul_half.output[0]]
  192. if len(children) != 1 or children[0].op_type != "Mul":
  193. return
  194. last_mul = children[0]
  195. if not (last_mul.input[0] == root_node.output[0] or last_mul.input[1] == root_node.output[0]):
  196. return
  197. subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul]
  198. if not self.model.is_safe_to_fuse_nodes(
  199. subgraph_nodes,
  200. [last_mul.output[0]],
  201. input_name_to_nodes,
  202. output_name_to_node,
  203. ):
  204. return
  205. self.nodes_to_remove.extend(subgraph_nodes)
  206. fused_node = helper.make_node(
  207. "Gelu", inputs=[root_node.output[0]], outputs=[last_mul.output[0]], name=self.model.create_node_name("Gelu")
  208. )
  209. fused_node.domain = "com.microsoft"
  210. self.nodes_to_add.append(fused_node)
  211. self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
  212. self.increase_counter("Gelu")
  213. return True