fusion_layernorm.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. from __future__ import annotations
  7. import onnx
  8. from ..onnx_model import ONNXModel
  9. from .fusion import Fusion
  10. class FusionLayerNormalization(Fusion):
  11. def __init__(self, model: ONNXModel):
  12. super().__init__(model, "LayerNormalization", "ReduceMean")
  13. def fuse(
  14. self,
  15. reduce_mean_node: onnx.NodeProto,
  16. input_name_to_nodes: dict[str, list[onnx.NodeProto]],
  17. output_name_to_node: dict[str, onnx.NodeProto],
  18. ):
  19. """
  20. Interface function that tries to fuse a node sequence containing a ReduceMean node into a single
  21. LayerNormalization node.
  22. +----------------------+
  23. | |
  24. | v
  25. [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
  26. (axis=2 or -1) | (Y=2) (axis=2 or -1) (E-6 or E-12 or 0) ^
  27. | |
  28. +-------------------------------------------------+
  29. Or, using Mul instead of Pow:
  30. +----------------------+
  31. | |
  32. | v
  33. [Root] --> ReduceMean --> Sub --> Mul --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
  34. (axis=2 or -1) | (in0=in1) (axis=2 or -1) (E-6 or E-12 or 0) ^
  35. | |
  36. +-------------------------------------------------+
  37. It also handles cases of duplicated sub nodes exported from older version of PyTorch:
  38. +----------------------+
  39. | v
  40. | +-------> Sub-----------------------------------------------+
  41. | | |
  42. | | v
  43. [Root] --> ReduceMean --> Sub --> (Pow or Mul) --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
  44. | ^
  45. | |
  46. +----------------------+
  47. """
  48. children = self.model.get_children(reduce_mean_node, input_name_to_nodes)
  49. if len(children) == 0 or len(children) > 2:
  50. return
  51. root_input = reduce_mean_node.input[0]
  52. if children[0].op_type != "Sub" or children[0].input[0] != root_input:
  53. return
  54. if len(children) == 2:
  55. if children[1].op_type != "Sub" or children[1].input[0] != root_input:
  56. return
  57. div_node = None
  58. for child in children:
  59. div_node = self.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)
  60. if div_node is not None:
  61. break
  62. if div_node is None:
  63. return
  64. path_id, parent_nodes, _ = self.match_parent_paths(
  65. div_node,
  66. [
  67. (["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]),
  68. (["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], [1, 0, 0, 0, 0, 0]),
  69. (["Sqrt", "Add", "ReduceMean", "Mul", "Sub"], [1, 0, 0, 0, 0]),
  70. (["Sqrt", "Add", "ReduceMean", "Mul", "Cast", "Sub"], [1, 0, 0, 0, 0, 0]),
  71. ],
  72. output_name_to_node,
  73. )
  74. if path_id < 0:
  75. return
  76. sub_node = parent_nodes[-1]
  77. if sub_node not in children:
  78. return
  79. second_add_node = parent_nodes[1]
  80. i, add_weight = self.get_constant_input(second_add_node)
  81. if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4:
  82. # Skip fusion since epsilon value is not expected.
  83. return
  84. pow_or_mul_node = parent_nodes[3]
  85. if pow_or_mul_node.op_type == "Pow" and self.find_constant_input(pow_or_mul_node, 2.0) != 1:
  86. return
  87. elif pow_or_mul_node.op_type == "Mul" and pow_or_mul_node.input[0] != pow_or_mul_node.input[1]:
  88. return
  89. mul_node = input_name_to_nodes[div_node.output[0]][0]
  90. if mul_node.op_type != "Mul":
  91. return
  92. last_add_node = input_name_to_nodes[mul_node.output[0]][0]
  93. if last_add_node.op_type != "Add":
  94. return
  95. subgraph_nodes = [reduce_mean_node]
  96. subgraph_nodes.extend(children)
  97. subgraph_nodes.extend(parent_nodes[:-1])
  98. subgraph_nodes.extend([last_add_node, mul_node, div_node])
  99. if not self.is_safe_to_fuse_nodes(
  100. subgraph_nodes,
  101. last_add_node.output,
  102. input_name_to_nodes,
  103. output_name_to_node,
  104. ):
  105. return
  106. weight_input = mul_node.input[1 - self.input_index(div_node.output[0], mul_node)]
  107. if not self.is_constant_with_specified_rank(weight_input, 1):
  108. return
  109. bias_input = last_add_node.input[1 - self.input_index(mul_node.output[0], last_add_node)]
  110. if not self.is_constant_with_specified_rank(bias_input, 1):
  111. return
  112. self.nodes_to_remove.extend(subgraph_nodes)
  113. normalize_node = onnx.helper.make_node(
  114. "LayerNormalization",
  115. name=self.create_unique_node_name(),
  116. inputs=[reduce_mean_node.input[0], weight_input, bias_input],
  117. outputs=[last_add_node.output[0]],
  118. )
  119. normalize_node.attribute.extend([onnx.helper.make_attribute("epsilon", float(add_weight))])
  120. self.nodes_to_add.append(normalize_node)