fusion_fastgelu.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. from fusion_base import Fusion
  7. from onnx import helper
  8. from onnx_model import OnnxModel
# Module-level logger for this fusion module.
# NOTE(review): FusionFastGelu below does not reference it; presumably kept for parity
# with sibling fusion modules — confirm before removing.
logger = getLogger(__name__)
  10. class FusionFastGelu(Fusion):
  11. def __init__(self, model: OnnxModel):
  12. super().__init__(model, "FastGelu", "Tanh")
  13. def fuse(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict):
  14. if self.fuse_1(tanh_node, input_name_to_nodes, output_name_to_node):
  15. return
  16. if self.fuse_2(tanh_node, input_name_to_nodes, output_name_to_node):
  17. return
  18. if self.fuse_3(tanh_node, input_name_to_nodes, output_name_to_node):
  19. return
  20. if self.fuse_4(tanh_node, input_name_to_nodes, output_name_to_node):
  21. return
  22. def fuse_1(self, tanh_node, input_name_to_nodes, output_name_to_node) -> bool | None:
  23. """
  24. Fuse Gelu with tanh into one node:
  25. +---------------------------+
  26. | |
  27. | v
  28. [root] --> Pow --> Mul -----> Add --> Mul --> Tanh --> Add --> Mul
  29. | (Y=3) (B=0.0447...) (B=0.7978...) (B=1) ^
  30. | |
  31. +------> Mul(B=0.5)--------------------------------------------+
  32. Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
  33. """
  34. if tanh_node.output[0] not in input_name_to_nodes:
  35. return
  36. children = input_name_to_nodes[tanh_node.output[0]]
  37. if len(children) != 1 or children[0].op_type != "Add":
  38. return
  39. add_after_tanh = children[0]
  40. if not self.model.has_constant_input(add_after_tanh, 1.0):
  41. return
  42. if add_after_tanh.output[0] not in input_name_to_nodes:
  43. return
  44. children = input_name_to_nodes[add_after_tanh.output[0]]
  45. if len(children) != 1 or children[0].op_type != "Mul":
  46. return
  47. mul_after_tanh = children[0]
  48. mul_half = self.model.match_parent(mul_after_tanh, "Mul", None, output_name_to_node)
  49. if mul_half is None:
  50. return
  51. i = self.model.find_constant_input(mul_half, 0.5)
  52. if i < 0:
  53. return
  54. root_input = mul_half.input[0 if i == 1 else 1]
  55. # root_node could be None when root_input is graph input
  56. root_node = self.model.get_parent(mul_half, 0 if i == 1 else 1, output_name_to_node)
  57. mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
  58. if mul_before_tanh is None:
  59. return
  60. i = self.model.find_constant_input(mul_before_tanh, 0.7978, delta=0.0001)
  61. if i < 0:
  62. return
  63. add_before_tanh = self.model.match_parent(mul_before_tanh, "Add", 0 if i == 1 else 1, output_name_to_node)
  64. if add_before_tanh is None:
  65. return
  66. mul_after_pow = self.model.match_parent(
  67. add_before_tanh,
  68. "Mul",
  69. None,
  70. output_name_to_node,
  71. exclude=[root_node] if root_node else [],
  72. )
  73. if mul_after_pow is None:
  74. return
  75. i = self.model.find_constant_input(mul_after_pow, 0.0447, delta=0.0001)
  76. if i < 0:
  77. return
  78. pow = self.model.match_parent(mul_after_pow, "Pow", 0 if i == 1 else 1, output_name_to_node)
  79. if pow is None:
  80. return
  81. if not self.model.has_constant_input(pow, 3.0):
  82. return
  83. if pow.input[0] != root_input:
  84. return
  85. subgraph_nodes = [
  86. mul_after_tanh,
  87. mul_half,
  88. add_after_tanh,
  89. tanh_node,
  90. mul_before_tanh,
  91. add_before_tanh,
  92. mul_after_pow,
  93. pow,
  94. ]
  95. if not self.model.is_safe_to_fuse_nodes(
  96. subgraph_nodes,
  97. [mul_after_tanh.output[0]],
  98. input_name_to_nodes,
  99. output_name_to_node,
  100. ):
  101. return
  102. self.nodes_to_remove.extend(subgraph_nodes)
  103. fused_node = helper.make_node(
  104. "FastGelu",
  105. inputs=[root_input],
  106. outputs=mul_after_tanh.output,
  107. name=self.model.create_node_name("FastGelu"),
  108. )
  109. fused_node.domain = "com.microsoft"
  110. self.nodes_to_add.append(fused_node)
  111. self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
  112. return True
  113. def fuse_2(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict) -> bool | None:
  114. """
  115. This pattern is from Tensorflow model.
  116. Fuse Gelu with tanh into one node:
  117. +---------------------------+
  118. | |
  119. | v
  120. [root] --> Pow --> Mul -----> Add --> Mul --> Tanh --> Add --> Mul(B=0.5)-->Mul-->
  121. | (Y=3) (B=0.0447...) (B=0.7978...) (B=1) ^
  122. | |
  123. +---------------------------------------------------------------------------+
  124. Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
  125. """
  126. if tanh_node.output[0] not in input_name_to_nodes:
  127. return
  128. children = input_name_to_nodes[tanh_node.output[0]]
  129. if len(children) != 1 or children[0].op_type != "Add":
  130. return
  131. add_after_tanh = children[0]
  132. if not self.model.has_constant_input(add_after_tanh, 1.0):
  133. return
  134. if add_after_tanh.output[0] not in input_name_to_nodes:
  135. return
  136. children = input_name_to_nodes[add_after_tanh.output[0]]
  137. if len(children) != 1 or children[0].op_type != "Mul":
  138. return
  139. mul_half = children[0]
  140. i = self.model.find_constant_input(mul_half, 0.5)
  141. if i < 0:
  142. return
  143. if mul_half.output[0] not in input_name_to_nodes:
  144. return
  145. children = input_name_to_nodes[mul_half.output[0]]
  146. if len(children) != 1 or children[0].op_type != "Mul":
  147. return
  148. mul_after_mul_half = children[0]
  149. # root_node could be None when root_input is graph input
  150. root_node = self.model.get_parent(
  151. mul_after_mul_half,
  152. 0 if mul_after_mul_half.input[1] == mul_half.output[0] else 1,
  153. output_name_to_node,
  154. )
  155. mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
  156. if mul_before_tanh is None:
  157. return
  158. i = self.model.find_constant_input(mul_before_tanh, 0.7978, delta=0.0001)
  159. if i < 0:
  160. return
  161. add_before_tanh = self.model.match_parent(mul_before_tanh, "Add", 0 if i == 1 else 1, output_name_to_node)
  162. if add_before_tanh is None:
  163. return
  164. mul_after_pow = self.model.match_parent(
  165. add_before_tanh,
  166. "Mul",
  167. None,
  168. output_name_to_node,
  169. exclude=[root_node] if root_node else [],
  170. )
  171. if mul_after_pow is None:
  172. return
  173. i = self.model.find_constant_input(mul_after_pow, 0.0447, delta=0.0001)
  174. if i < 0:
  175. return
  176. pow = self.model.match_parent(mul_after_pow, "Pow", 0 if i == 1 else 1, output_name_to_node)
  177. if pow is None:
  178. return
  179. if not self.model.has_constant_input(pow, 3.0):
  180. return
  181. root_input = mul_after_mul_half.input[0 if mul_after_mul_half.input[1] == mul_half.output[0] else 1]
  182. if pow.input[0] != root_input:
  183. return
  184. subgraph_nodes = [
  185. mul_after_mul_half,
  186. mul_half,
  187. add_after_tanh,
  188. tanh_node,
  189. mul_before_tanh,
  190. add_before_tanh,
  191. mul_after_pow,
  192. pow,
  193. ]
  194. if not self.model.is_safe_to_fuse_nodes(
  195. subgraph_nodes,
  196. [mul_after_mul_half.output[0]],
  197. input_name_to_nodes,
  198. output_name_to_node,
  199. ):
  200. return
  201. self.nodes_to_remove.extend(subgraph_nodes)
  202. fused_node = helper.make_node(
  203. "FastGelu",
  204. inputs=[root_input],
  205. outputs=mul_after_mul_half.output,
  206. name=self.model.create_node_name("FastGelu"),
  207. )
  208. fused_node.domain = "com.microsoft"
  209. self.nodes_to_add.append(fused_node)
  210. self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
  211. return True
  212. def fuse_3(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict) -> bool | None:
  213. """
  214. OpenAI's gelu implementation, also used in Megatron:
  215. Gelu(x) = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))
  216. Fuse subgraph into a FastGelu node:
  217. +------------ Mul (B=0.79788456) -------------------+
  218. | |
  219. +-------------------------------+ |
  220. | | |
  221. | v v
  222. [root] --> Mul (B=0.044715) --> Mul --> Add(B=1) --> Mul --> Tanh --> Add(B=1) --> Mul-->
  223. | ^
  224. | |
  225. +-----------> Mul (B=0.5) --------------------------------------------------------+
  226. """
  227. if tanh_node.output[0] not in input_name_to_nodes:
  228. return
  229. children = input_name_to_nodes[tanh_node.output[0]]
  230. if len(children) != 1 or children[0].op_type != "Add":
  231. return
  232. add_after_tanh = children[0]
  233. if not self.model.has_constant_input(add_after_tanh, 1.0):
  234. return
  235. if add_after_tanh.output[0] not in input_name_to_nodes:
  236. return
  237. children = input_name_to_nodes[add_after_tanh.output[0]]
  238. if len(children) != 1 or children[0].op_type != "Mul":
  239. return
  240. mul_last = children[0]
  241. mul_half = self.model.match_parent(mul_last, "Mul", None, output_name_to_node)
  242. if mul_half is None:
  243. return
  244. i = self.model.find_constant_input(mul_half, 0.5)
  245. if i < 0:
  246. return
  247. root_input = mul_half.input[0 if i == 1 else 1]
  248. mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
  249. if mul_before_tanh is None:
  250. return
  251. add_1 = self.model.match_parent(mul_before_tanh, "Add", None, output_name_to_node)
  252. if add_1 is None:
  253. return
  254. j = self.model.find_constant_input(add_1, 1.0)
  255. if j < 0:
  256. return
  257. mul_7978 = self.model.match_parent(mul_before_tanh, "Mul", None, output_name_to_node)
  258. if mul_7978 is None:
  259. return
  260. k = self.model.find_constant_input(mul_7978, 0.7978, delta=0.0001)
  261. if k < 0:
  262. return
  263. if mul_7978.input[0 if k == 1 else 1] != root_input:
  264. return
  265. mul_before_add_1 = self.model.match_parent(add_1, "Mul", 0 if j == 1 else 1, output_name_to_node)
  266. if mul_before_add_1 is None:
  267. return
  268. if mul_before_add_1.input[0] == root_input:
  269. another = 1
  270. elif mul_before_add_1.input[1] == root_input:
  271. another = 0
  272. else:
  273. return
  274. mul_0447 = self.model.match_parent(mul_before_add_1, "Mul", another, output_name_to_node)
  275. if mul_0447 is None:
  276. return
  277. m = self.model.find_constant_input(mul_0447, 0.0447, delta=0.0001)
  278. if m < 0:
  279. return
  280. if mul_0447.input[0 if m == 1 else 1] != root_input:
  281. return
  282. subgraph_nodes = [
  283. mul_0447,
  284. mul_before_add_1,
  285. add_1,
  286. mul_before_tanh,
  287. tanh_node,
  288. add_after_tanh,
  289. mul_7978,
  290. mul_half,
  291. mul_last,
  292. ]
  293. if not self.model.is_safe_to_fuse_nodes(
  294. subgraph_nodes,
  295. [mul_last.output[0]],
  296. input_name_to_nodes,
  297. output_name_to_node,
  298. ):
  299. return
  300. self.nodes_to_remove.extend(subgraph_nodes)
  301. fused_node = helper.make_node(
  302. "FastGelu",
  303. inputs=[root_input],
  304. outputs=mul_last.output,
  305. name=self.model.create_node_name("FastGelu"),
  306. )
  307. fused_node.domain = "com.microsoft"
  308. self.nodes_to_add.append(fused_node)
  309. self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
  310. return True
  311. def fuse_4(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict) -> bool | None:
  312. """
  313. PyTorch's gelu implementation with tanh approximation:
  314. Gelu(x) = 0.5 * x * (1 + torch.tanh(0.7978845834732056 * (x + 0.044714998453855515 * x * x * x)))
  315. Fuse Gelu with tanh into one node:
  316. +-----------------+------------------+
  317. | | |
  318. | v v
  319. [root] ==> Mul --> Mul --> Mul -----> Add --> Mul --> Tanh --> Add -----> Mul --> Mul -->
  320. | (A=0.0447) (A=0.7978) (A=1) ^ (A=0.5)
  321. | |
  322. +-------------------------------------------------------------------------+
  323. Note that constant input for Add and Mul could be first or second input.
  324. """
  325. if tanh_node.output[0] not in input_name_to_nodes:
  326. return
  327. children = input_name_to_nodes[tanh_node.output[0]]
  328. if len(children) != 1 or children[0].op_type != "Add":
  329. return
  330. add_after_tanh = children[0]
  331. if not self.model.has_constant_input(add_after_tanh, 1.0):
  332. return
  333. if add_after_tanh.output[0] not in input_name_to_nodes:
  334. return
  335. children = input_name_to_nodes[add_after_tanh.output[0]]
  336. if len(children) != 1 or children[0].op_type != "Mul":
  337. return
  338. mul_after_tanh = children[0]
  339. if mul_after_tanh.output[0] not in input_name_to_nodes:
  340. return
  341. children = input_name_to_nodes[mul_after_tanh.output[0]]
  342. if len(children) != 1 or children[0].op_type != "Mul":
  343. return
  344. mul_half = children[0]
  345. if not self.model.has_constant_input(mul_half, 0.5):
  346. return
  347. root_input = mul_after_tanh.input[0 if mul_after_tanh.input[1] == add_after_tanh.output[0] else 1]
  348. mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
  349. if mul_before_tanh is None:
  350. return
  351. k = self.model.find_constant_input(mul_before_tanh, 0.7978, delta=0.01)
  352. if k < 0:
  353. return
  354. add_before_tanh = self.model.match_parent(mul_before_tanh, "Add", 0 if k == 1 else 1, output_name_to_node)
  355. if add_before_tanh is None:
  356. return
  357. if add_before_tanh.input[0] == root_input:
  358. another = 1
  359. elif add_before_tanh.input[1] == root_input:
  360. another = 0
  361. else:
  362. return
  363. mul_after_pow = self.model.match_parent(add_before_tanh, "Mul", another, output_name_to_node)
  364. if mul_after_pow is None:
  365. return
  366. m = self.model.find_constant_input(mul_after_pow, 0.0447, delta=0.01)
  367. if m < 0:
  368. return
  369. mul_cubed = self.model.match_parent(mul_after_pow, "Mul", 0 if m == 1 else 1, output_name_to_node)
  370. if mul_cubed is None:
  371. return
  372. if mul_cubed.input[0] == root_input:
  373. another = 1
  374. elif mul_cubed.input[1] == root_input:
  375. another = 0
  376. else:
  377. return
  378. mul_squared = self.model.match_parent(mul_cubed, "Mul", another, output_name_to_node)
  379. if mul_squared is None:
  380. return
  381. if mul_squared.input[0] != root_input or mul_squared.input[1] != root_input:
  382. return
  383. subgraph_nodes = [
  384. mul_squared,
  385. mul_cubed,
  386. mul_after_pow,
  387. add_before_tanh,
  388. mul_before_tanh,
  389. tanh_node,
  390. add_after_tanh,
  391. mul_after_tanh,
  392. mul_half,
  393. ]
  394. if not self.model.is_safe_to_fuse_nodes(
  395. subgraph_nodes,
  396. [mul_half.output[0]],
  397. input_name_to_nodes,
  398. output_name_to_node,
  399. ):
  400. return
  401. self.nodes_to_remove.extend(subgraph_nodes)
  402. fused_node = helper.make_node(
  403. "FastGelu",
  404. inputs=[root_input],
  405. outputs=mul_half.output,
  406. name=self.model.create_node_name("FastGelu"),
  407. )
  408. fused_node.domain = "com.microsoft"
  409. self.nodes_to_add.append(fused_node)
  410. self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
  411. self.increase_counter("FastGelu")
  412. return True