fusion_skiplayernorm.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. from fusion_base import Fusion
  7. from fusion_utils import NumpyHelper
  8. from onnx import helper
  9. from onnx_model import OnnxModel
  10. logger = getLogger(__name__)
  11. class FusionSkipLayerNormalization(Fusion):
  12. """
  13. Fuse Add + LayerNormalization into one node: SkipLayerNormalization
  14. Note: This fusion does not check the input shape of Add and LayerNormalization.
  15. """
  16. def __init__(
  17. self,
  18. model: OnnxModel,
  19. fused_op_type: str = "SkipLayerNormalization",
  20. search_op_types: str = "LayerNormalization",
  21. shape_infer: bool = True,
  22. ):
  23. super().__init__(model, fused_op_type, search_op_types)
  24. if shape_infer:
  25. # Update shape inference is needed since other fusions might add new edge which does not have shape info yet.
  26. self.shape_infer_helper = self.model.infer_runtime_shape({"batch_size": 4, "seq_len": 7}, update=True)
  27. if self.shape_infer_helper is None:
  28. # TODO(tianleiwu): support subgraph in shape inference or add broadcasting in SkipLayerNormalization op.
  29. logger.warning("symbolic shape inference disabled or failed.")
  30. def fuse(self, node, input_name_to_nodes, output_name_to_node):
  31. add = self.model.get_parent(node, 0, output_name_to_node)
  32. # In some models there is input_ids->gather->add->LayerNorm and one of input of the
  33. # add node is initializer with fixed shape which should not be fused into SkipLayerNorm
  34. if add is None or add.op_type != "Add":
  35. return
  36. # The number of inputs of add should be 2
  37. if len(add.input) != 2:
  38. return
  39. for add_input in add.input:
  40. if self.model.get_initializer(add_input) is not None:
  41. return
  42. # To avoid an Add node have two children of LayerNormalization, we shall only fuse one SkipLayerNormalization
  43. if add in self.nodes_to_remove:
  44. return
  45. # Root Mean Square Layer Normalization
  46. simplified = node.op_type == "SimplifiedLayerNormalization"
  47. if hasattr(self, "shape_infer_helper"):
  48. if self.shape_infer_helper is not None:
  49. if (
  50. self.shape_infer_helper.get_edge_shape(add.input[0])
  51. and len(self.shape_infer_helper.get_edge_shape(add.input[0])) != 3
  52. ):
  53. logger.debug("skip SkipLayerNormalization fusion since shape of input %s is not 3D", add.input[0])
  54. return
  55. # TODO(tianleiwu): support broadcasting Skip shape (1, sequence_length, hidden_size) or (sequence_length, hidden_size)
  56. if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]):
  57. logger.debug(
  58. "skip SkipLayerNormalization fusion since shape of inputs (%s, %s) are not same",
  59. add.input[0],
  60. add.input[1],
  61. )
  62. return
  63. else:
  64. logger.debug("skip SkipLayerNormalization fusion since symbolic shape inference failed")
  65. return
  66. gather_path = self.model.match_parent_path(add, ["Gather"], [None])
  67. if gather_path is not None and self.model.find_graph_input(gather_path[0].input[1]) is None:
  68. if self.model.match_parent_path(gather_path[0], ["ConstantOfShape"], [1]) is None:
  69. return
  70. # This means that the residual Add before the LayerNormalization produces an output
  71. # that is consumed by some other nodes or graph output other than the LayerNormalization itself
  72. # We can still go ahead with the SkipLayerNormalization fusion but we need to
  73. # preserve the output of Add and that needs to be produced by SkipLayerNormalization.
  74. add_has_graph_output = self.model.find_graph_output(add.output[0]) is not None
  75. residual_add_has_multiple_consumers = (
  76. add_has_graph_output or len(self.model.get_children(add, input_name_to_nodes)) > 1
  77. )
  78. outputs_to_keep = node.output
  79. if residual_add_has_multiple_consumers:
  80. outputs_to_keep.extend([add.output[0]])
  81. outputs = [node.output[0]]
  82. # Skip the other optional outputs of SkipLayerNormalization before adding the Add's output
  83. if residual_add_has_multiple_consumers:
  84. outputs.extend(["", "", add.output[0]])
  85. if self.model.is_safe_to_fuse_nodes([add, node], outputs_to_keep, input_name_to_nodes, output_name_to_node):
  86. self.nodes_to_remove.extend([add, node])
  87. inputs = (
  88. [add.input[0], add.input[1], node.input[1], node.input[2]]
  89. if not simplified
  90. else [add.input[0], add.input[1], node.input[1]]
  91. )
  92. normalize_node = helper.make_node(
  93. self.fused_op_type,
  94. inputs=inputs,
  95. outputs=outputs,
  96. name=self.model.create_node_name(self.fused_op_type, name_prefix="SkipLayerNorm"),
  97. )
  98. normalize_node.domain = "com.microsoft"
  99. # Pass attribute "epsilon" from layernorm node to SkipLayerNormalization
  100. for att in node.attribute:
  101. if att.name == "epsilon":
  102. normalize_node.attribute.extend([att])
  103. # Set default epsilon if no epsilon exists from layernorm
  104. if len(normalize_node.attribute) == 0:
  105. normalize_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])
  106. self.nodes_to_add.append(normalize_node)
  107. self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
  108. class FusionBiasSkipLayerNormalization(Fusion):
  109. def __init__(self, model: OnnxModel):
  110. super().__init__(model, "SkipLayerNormalization", "SkipLayerNormalization", "add bias")
  111. def fuse(self, node, input_name_to_nodes, output_name_to_node):
  112. if len(node.input) != 4:
  113. return
  114. return_indice = []
  115. nodes = self.model.match_parent_path(node, ["Add", "MatMul"], [None, None], output_name_to_node, return_indice)
  116. if nodes is not None:
  117. (add, _matmul) = nodes
  118. else:
  119. # In case of fp16, we could have a Cast between the MatMul and the bias Add
  120. return_indice = []
  121. nodes = self.model.match_parent_path(
  122. node, ["Add", "Cast", "MatMul"], [None, None, None], output_name_to_node, return_indice
  123. )
  124. if nodes is not None:
  125. (add, _cast, _matmul) = nodes
  126. else:
  127. return
  128. assert len(return_indice) == 2 or len(return_indice) == 3
  129. add_input_index = return_indice[0]
  130. if add_input_index >= 2:
  131. return
  132. sln_input = add.input[return_indice[1]]
  133. bias_input = add.input[1 - return_indice[1]]
  134. skip_input = node.input[1 - add_input_index]
  135. # bias should be one dimension
  136. initializer = self.model.get_initializer(bias_input)
  137. if initializer is None:
  138. return
  139. bias_weight = NumpyHelper.to_array(initializer)
  140. if bias_weight is None:
  141. logger.debug("Bias weight not found")
  142. return
  143. if len(bias_weight.shape) != 1:
  144. logger.debug("Bias weight is not 1D")
  145. return
  146. subgraph_nodes = [node, add]
  147. if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, node.output, input_name_to_nodes, output_name_to_node):
  148. logger.debug("Skip fusing SkipLayerNormalization with Bias since it is not safe")
  149. return
  150. self.nodes_to_remove.extend(subgraph_nodes)
  151. inputs = [
  152. sln_input,
  153. skip_input,
  154. node.input[2],
  155. node.input[3],
  156. bias_input,
  157. ]
  158. new_node = helper.make_node(
  159. "SkipLayerNormalization",
  160. inputs=inputs,
  161. outputs=node.output,
  162. name=self.model.create_node_name("SkipLayerNormalization", "SkipLayerNorm_AddBias_"),
  163. )
  164. new_node.domain = "com.microsoft"
  165. # Pass attribute "epsilon" from skiplayernorm node to skiplayernorm(add bias)
  166. for att in node.attribute:
  167. if att.name == "epsilon":
  168. new_node.attribute.extend([att])
  169. # Set default epsilon if no epsilon exists from skiplayernorm
  170. if len(new_node.attribute) == 0:
  171. new_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])
  172. self.nodes_to_add.append(new_node)
  173. self.node_name_to_graph_name[new_node.name] = self.this_graph_name