fusion_layernorm.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. from fusion_base import Fusion
  7. from onnx import TensorProto, helper
  8. from onnx_model import OnnxModel
  9. logger = getLogger(__name__)
  10. class FusionLayerNormalization(Fusion):
  11. def __init__(self, model: OnnxModel, check_constant_and_dimension: bool = True, force: bool = False):
  12. super().__init__(model, "LayerNormalization", "ReduceMean")
  13. self.check_constant_and_dimension = check_constant_and_dimension
  14. self.force = force
  15. def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict):
  16. """
  17. Fuse Layer Normalization subgraph into one node LayerNormalization:
  18. +----------------------+
  19. | |
  20. | v
  21. [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
  22. (axis=2 or -1) | (Y=2) (axis=2 or -1) (B=E-6 or E-12) ^
  23. | |
  24. +-------------------------------------------------+
  25. It also handles cases of duplicated sub nodes exported from older version of PyTorch:
  26. +----------------------+
  27. | v
  28. | +-------> Sub-----------------------------------------------+
  29. | | |
  30. | | v
  31. [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
  32. | ^
  33. | |
  34. +----------------------+
  35. """
  36. subgraph_nodes = []
  37. children = self.model.get_children(node, input_name_to_nodes)
  38. if len(children) == 0 or len(children) > 2:
  39. return
  40. root_input = node.input[0]
  41. if children[0].op_type != "Sub" or children[0].input[0] != root_input:
  42. return
  43. if len(children) == 2:
  44. if children[1].op_type != "Sub" or children[1].input[0] != root_input:
  45. return
  46. div_node = None
  47. for child in children:
  48. # Check if Sub --> Div exists
  49. div_node_1 = self.model.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)
  50. if div_node_1 is not None:
  51. div_node = div_node_1
  52. break
  53. else:
  54. # Check if Sub --> Cast --> Div
  55. div_node_2 = self.model.match_child_path(child, ["Cast", "Div"])
  56. if div_node_2 is not None:
  57. div_node = div_node_2[-1]
  58. break
  59. if div_node is None:
  60. return
  61. _path_id, parent_nodes, _ = self.model.match_parent_paths(
  62. div_node,
  63. [
  64. (["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]),
  65. (["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], [1, 0, 0, 0, 0, 0]),
  66. ],
  67. output_name_to_node,
  68. )
  69. if parent_nodes is None:
  70. return
  71. sub_node = parent_nodes[-1]
  72. if sub_node not in children:
  73. return
  74. add_eps_node = parent_nodes[1]
  75. i, epsilon = self.model.get_constant_input(add_eps_node)
  76. if epsilon is None or epsilon <= 0 or epsilon > 1.0e-4:
  77. logger.debug(f"skip SkipLayerNormalization fusion since epsilon value is not expected: {epsilon}")
  78. return
  79. pow_node = parent_nodes[3]
  80. if self.model.find_constant_input(pow_node, 2.0) != 1:
  81. return
  82. if div_node.output[0] not in input_name_to_nodes:
  83. return
  84. # In MMDit model, Div might have two Mul+Add children paths.
  85. div_children = input_name_to_nodes[div_node.output[0]]
  86. for temp_node in div_children:
  87. if temp_node.op_type == "Cast":
  88. # Div --> Cast --> Mul
  89. subgraph_nodes.append(temp_node) # add Cast node to list of subgraph nodes
  90. if temp_node.output[0] not in input_name_to_nodes:
  91. continue
  92. mul_node = input_name_to_nodes[temp_node.output[0]][0]
  93. else:
  94. # Div --> Mul
  95. mul_node = temp_node
  96. if mul_node.op_type != "Mul":
  97. continue
  98. if mul_node.output[0] not in input_name_to_nodes:
  99. continue
  100. last_add_node = input_name_to_nodes[mul_node.output[0]][0]
  101. if last_add_node.op_type != "Add":
  102. continue
  103. subgraph_nodes.append(node)
  104. subgraph_nodes.extend(children)
  105. subgraph_nodes.extend(parent_nodes[:-1])
  106. subgraph_nodes.extend([last_add_node, mul_node, div_node])
  107. node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node
  108. weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)]
  109. if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension(
  110. weight_input, 1, "layernorm weight"
  111. ):
  112. continue
  113. bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)]
  114. if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension(
  115. bias_input, 1, "layernorm bias"
  116. ):
  117. continue
  118. layer_norm_output = last_add_node.output[0]
  119. if not self.model.is_safe_to_fuse_nodes(
  120. subgraph_nodes,
  121. last_add_node.output,
  122. input_name_to_nodes,
  123. output_name_to_node,
  124. ):
  125. # If it is not safe to fuse, somce computation may be duplicated if we force to fuse it.
  126. # It it unknown that force fusion might bring performance gain/loss.
  127. # User need test performance impact to see whether forcing fusion can help.
  128. if self.force:
  129. self.prune_graph = True
  130. else:
  131. logger.debug("It is not safe to fuse LayerNormalization node. Skip")
  132. continue
  133. else:
  134. self.nodes_to_remove.extend(subgraph_nodes)
  135. normalize_node = helper.make_node(
  136. "LayerNormalization",
  137. inputs=[node.input[0], weight_input, bias_input],
  138. outputs=[layer_norm_output],
  139. name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"),
  140. )
  141. normalize_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))])
  142. self.nodes_to_add.append(normalize_node)
  143. self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
  144. class FusionLayerNormalizationNCHW(Fusion):
  145. def __init__(self, model: OnnxModel):
  146. super().__init__(model, "LayerNormalization", "ReduceMean")
  147. def get_weight_or_bias(self, output_name, description):
  148. value = self.model.get_constant_value(output_name)
  149. if value is None:
  150. logger.debug(f"{description} {output_name} is not initializer.")
  151. return None
  152. if len(value.shape) != 3 or value.shape[1] != 1 or value.shape[2] != 1:
  153. logger.debug(f"{description} {output_name} shall have 3 dimensions Cx1x1. Got shape {value.shape}")
  154. return None
  155. return value.reshape([value.shape[0]])
  156. def create_transpose_node(self, input_name: str, perm: list[int], output_name=None):
  157. """Append a Transpose node after an input"""
  158. node_name = self.model.create_node_name("Transpose")
  159. if output_name is None:
  160. output_name = node_name + "_out" + "-" + input_name
  161. transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
  162. transpose_node.attribute.extend([helper.make_attribute("perm", perm)])
  163. return transpose_node
  164. def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict):
  165. """
  166. Fuse Layer Normalization subgraph into one node LayerNormalization:
  167. +----------------------+
  168. | NxCxHxW |
  169. | v (Cx1x1) (Cx1x1)
  170. [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add -->
  171. (axes=1) | (Y=2) (axes=1) (E-6) ^
  172. | |
  173. +-----------------------------------------------+
  174. Fused subgraph:
  175. (0,2,3,1) (0,3,1,2)
  176. [Root] --> Transpose --> LayerNormalization --> Transpose -->
  177. """
  178. axes = OnnxModel.get_node_attribute(node, "axes")
  179. if (not isinstance(axes, list)) or axes != [1]:
  180. return
  181. subgraph_nodes = []
  182. children = self.model.get_children(node, input_name_to_nodes)
  183. if len(children) != 1:
  184. return
  185. root_input = node.input[0]
  186. if children[0].op_type != "Sub" or children[0].input[0] != root_input:
  187. return
  188. sub = children[0]
  189. div_node = self.model.find_first_child_by_type(sub, "Div", input_name_to_nodes, recursive=False)
  190. if div_node is None:
  191. return
  192. parent_nodes = self.model.match_parent_path(
  193. div_node,
  194. ["Sqrt", "Add", "ReduceMean", "Pow", "Sub"],
  195. [1, 0, 0, 0, 0],
  196. output_name_to_node,
  197. )
  198. if parent_nodes is None:
  199. return
  200. _sqrt_node, second_add_node, reduce_mean_node, pow_node, sub_node = parent_nodes
  201. if sub != sub_node:
  202. return
  203. i, epsilon = self.model.get_constant_input(second_add_node)
  204. if epsilon is None or epsilon <= 0 or epsilon > 1.0e-4:
  205. logger.debug(f"skip SkipLayerNormalization fusion since epsilon value is not expected: {epsilon}")
  206. return
  207. axes = OnnxModel.get_node_attribute(reduce_mean_node, "axes")
  208. assert isinstance(axes, list)
  209. if axes != [1]:
  210. return
  211. if self.model.find_constant_input(pow_node, 2.0) != 1:
  212. return
  213. temp_node = input_name_to_nodes[div_node.output[0]][0]
  214. mul_node = temp_node
  215. if mul_node.op_type != "Mul":
  216. return
  217. last_add_node = input_name_to_nodes[mul_node.output[0]][0]
  218. if last_add_node.op_type != "Add":
  219. return
  220. subgraph_nodes.append(node)
  221. subgraph_nodes.extend(parent_nodes)
  222. subgraph_nodes.extend([last_add_node, mul_node, div_node])
  223. if not self.model.is_safe_to_fuse_nodes(
  224. subgraph_nodes,
  225. last_add_node.output,
  226. input_name_to_nodes,
  227. output_name_to_node,
  228. ):
  229. logger.debug("It is not safe to fuse LayerNormalization node. Skip")
  230. return
  231. node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node
  232. weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)]
  233. weight = self.get_weight_or_bias(weight_input, "layernorm weight")
  234. if weight is None:
  235. return
  236. bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)]
  237. bias = self.get_weight_or_bias(bias_input, "layernorm bias")
  238. if bias is None:
  239. return
  240. weight_nhwc = helper.make_tensor(weight_input + "_NHWC", TensorProto.FLOAT, weight.shape, weight)
  241. bias_nhwc = helper.make_tensor(bias_input + "_NHWC", TensorProto.FLOAT, weight.shape, weight)
  242. self.model.add_initializer(weight_nhwc, self.this_graph_name)
  243. self.model.add_initializer(bias_nhwc, self.this_graph_name)
  244. self.nodes_to_remove.extend(subgraph_nodes)
  245. transpose_input = self.create_transpose_node(node.input[0], [0, 2, 3, 1])
  246. layernorm_node_name = self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm")
  247. transpose_output = self.create_transpose_node(
  248. layernorm_node_name + "_out_nhwc", [0, 3, 1, 2], last_add_node.output[0]
  249. )
  250. normalize_node = helper.make_node(
  251. "LayerNormalization",
  252. inputs=[transpose_input.output[0], weight_input + "_NHWC", bias_input + "_NHWC"],
  253. outputs=[layernorm_node_name + "_out_nhwc"],
  254. name=layernorm_node_name,
  255. )
  256. normalize_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))])
  257. self.nodes_to_add.append(transpose_input)
  258. self.nodes_to_add.append(normalize_node)
  259. self.nodes_to_add.append(transpose_output)
  260. self.node_name_to_graph_name[transpose_input.name] = self.this_graph_name
  261. self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
  262. self.node_name_to_graph_name[transpose_output.name] = self.this_graph_name
  263. counter_name = "LayerNormalization(NHWC)"
  264. self.increase_counter(counter_name)
  265. class FusionLayerNormalizationTF(Fusion):
  266. def __init__(self, model: OnnxModel):
  267. super().__init__(model, "LayerNormalization", "Add", "TF")
  268. def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict):
  269. """
  270. Layer Norm from Tensorflow model(using keras2onnx or tf2onnx):
  271. +------------------------------------+
  272. | |
  273. | |
  274. (Cast_1) |
  275. | |
  276. | v (B) (B) (A)
  277. Add --> (Cast_1) --> ReduceMean --> Sub --> Mul --> ReduceMean --> (Cast_3) --> Add --> Sqrt --> Reciprocol --> Mul --> Mul --> Sub --> Add
  278. | | | ^ ^
  279. | | | | |
  280. | +--------------------------------------------------(Cast_2)-------------------------------|-------+ |
  281. | v |
  282. +---------------------------------------------------------------------------------------------------------------> Mul--------------------+
  283. """
  284. return_indice = []
  285. _, parent_nodes, return_indice = self.model.match_parent_paths(
  286. node,
  287. [
  288. (
  289. [
  290. "Sub",
  291. "Mul",
  292. "Mul",
  293. "Reciprocal",
  294. "Sqrt",
  295. "Add",
  296. "ReduceMean",
  297. "Mul",
  298. "Sub",
  299. "ReduceMean",
  300. ],
  301. [1, 1, None, 0, 0, 0, None, 0, 0, None],
  302. ),
  303. (
  304. [
  305. "Sub",
  306. "Mul",
  307. "Mul",
  308. "Reciprocal",
  309. "Sqrt",
  310. "Add",
  311. "Cast",
  312. "ReduceMean",
  313. "Mul",
  314. "Sub",
  315. "ReduceMean",
  316. ],
  317. [1, 1, None, 0, 0, 0, 0, None, 0, 0, None],
  318. ),
  319. ],
  320. output_name_to_node,
  321. )
  322. if parent_nodes is None:
  323. return
  324. assert len(return_indice) == 3
  325. if not (return_indice[0] in [0, 1] and return_indice[1] in [0, 1] and return_indice[2] in [0, 1]):
  326. logger.debug("return indice is exepected in [0, 1], but got {return_indice}")
  327. return
  328. (
  329. sub_node_0,
  330. mul_node_0,
  331. mul_node_1,
  332. reciprocol_node,
  333. sqrt_node,
  334. add_node_0,
  335. ) = parent_nodes[:6]
  336. reduce_mean_node_0, mul_node_2, sub_node_1, reduce_mean_node_1 = parent_nodes[-4:]
  337. cast_node_3 = None
  338. if len(parent_nodes) == 11:
  339. cast_node_3 = parent_nodes[6]
  340. assert cast_node_3.op_type == "Cast"
  341. mul_node_3 = self.model.match_parent(node, "Mul", 0, output_name_to_node)
  342. if mul_node_3 is None:
  343. logger.debug("mul_node_3 not found")
  344. return
  345. node_before_reduce = self.model.get_parent(reduce_mean_node_1, 0, output_name_to_node)
  346. root_node = (
  347. node_before_reduce
  348. if cast_node_3 is None
  349. else self.model.get_parent(node_before_reduce, 0, output_name_to_node)
  350. )
  351. if root_node is None:
  352. logger.debug("root node is none")
  353. return
  354. i, epsilon = self.model.get_constant_input(add_node_0)
  355. if epsilon is None or epsilon <= 0 or (epsilon > 1.0e-5 and cast_node_3 is None):
  356. logger.debug("epsilon is not matched")
  357. return
  358. if cast_node_3 is None and (
  359. reduce_mean_node_1.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input
  360. ):
  361. logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node")
  362. return
  363. if cast_node_3 is not None and (
  364. node_before_reduce.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input
  365. ):
  366. logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node")
  367. return
  368. if mul_node_2.input[0] != mul_node_2.input[1]:
  369. logger.debug("mul_node_2 shall have two same inputs")
  370. return
  371. subgraph_nodes = [
  372. node,
  373. sub_node_0,
  374. mul_node_0,
  375. mul_node_1,
  376. reciprocol_node,
  377. sqrt_node,
  378. add_node_0,
  379. reduce_mean_node_0,
  380. mul_node_2,
  381. sub_node_1,
  382. reduce_mean_node_1,
  383. mul_node_3,
  384. ]
  385. if cast_node_3 is not None:
  386. cast_node_2 = self.model.match_parent(mul_node_0, "Cast", 0, output_name_to_node)
  387. if cast_node_2 is None:
  388. logger.debug("cast_node_2 not found")
  389. return
  390. subgraph_nodes.extend([node_before_reduce, cast_node_2, cast_node_3])
  391. if not self.model.is_safe_to_fuse_nodes(
  392. subgraph_nodes,
  393. node.output,
  394. self.model.input_name_to_nodes(),
  395. self.model.output_name_to_node(),
  396. ):
  397. logger.debug("not safe to fuse layer normalization")
  398. return
  399. self.nodes_to_remove.extend(subgraph_nodes)
  400. weight_input = mul_node_1.input[1]
  401. bias_input = sub_node_0.input[0]
  402. # TODO: add epsilon attribute
  403. fused_node = helper.make_node(
  404. "LayerNormalization",
  405. inputs=[mul_node_3.input[0], weight_input, bias_input],
  406. outputs=[node.output[0]],
  407. name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"),
  408. )
  409. fused_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))])
  410. self.nodes_to_add.append(fused_node)
  411. self.node_name_to_graph_name[fused_node.name] = self.this_graph_name