fusion_biassplitgelu.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. from fusion_base import Fusion
  7. from onnx import helper
  8. from onnx_model import OnnxModel
  9. logger = getLogger(__name__)
  class FusionBiasSplitGelu(Fusion):
      """Fuse an Add -> (Slice, Slice -> Gelu) -> Mul subgraph into one BiasSplitGelu node."""

      def __init__(self, model: OnnxModel):
          # NOTE(review): the two strings are presumably the fused op type and the
          # op type that triggers `fuse` -- confirm against Fusion.__init__.
          super().__init__(model, "BiasSplitGelu", "Gelu")

      def fuse(self, gelu_node, input_name_to_nodes: dict, output_name_to_node: dict):
          """Try to fuse the subgraph around ``gelu_node``; do nothing if it does not match.

          Pattern being matched (the Mul between Div and the Gelu-side Slice start
          input is optional)::

              [root] --->Add -------------------->  Slice ---------------> Mul -->
                         |                            ^                    ^
                         |                            |                    |
                         +----------------------------+---Slice --> Gelu---+
                         |                            |     ^
                         |                            |-----|
                         |                            |     |
                         |                           Mul   Mul
                         |                            ^     ^
                         v                            |     |
                        Shape ---> Gather --> Add --> Div --+

          Args:
              gelu_node: The Gelu node used as the anchor of the search.
              input_name_to_nodes: Maps a tensor name to the nodes that consume it.
              output_name_to_node: Lookup handed to the parent-matching helpers
                  (presumably tensor name -> producing node; verify in OnnxModel).

          Returns:
              None. On a successful match the matched nodes are queued in
              ``self.nodes_to_remove`` and the fused node in ``self.nodes_to_add``.
          """
          model = self.model

          # The Gelu output must feed exactly one consumer, and it must be a Mul.
          gelu_output = gelu_node.output[0]
          if gelu_output not in input_name_to_nodes:
              return
          consumers = input_name_to_nodes[gelu_output]
          if not (len(consumers) == 1 and consumers[0].op_type == "Mul"):
              return
          mul_after_gelu = consumers[0]

          # Gelu must be fed by a Slice on its first input.
          gelu_slice = model.match_parent(gelu_node, "Slice", 0, output_name_to_node)
          if gelu_slice is None:
              return

          # Require the constant -1 to sit at input index 3 of that Slice
          # (presumably the `axes` input, i.e. slicing the last axis -- TODO confirm).
          if model.find_constant_input(gelu_slice, -1, delta=0.001) != 3:
              return

          bias_add_output = gelu_slice.input[0]

          # Path that computes the Slice start index; tried without and then with
          # the optional leading Mul.
          start_path = None
          for op_types, input_indices in (
              (["Div", "Add", "Gather", "Shape", "Add"], [1, 0, 0, 0, 0]),
              (["Mul", "Div", "Add", "Gather", "Shape", "Add"], [1, 0, 0, 0, 0, 0]),
          ):
              start_path = model.match_parent_path(gelu_slice, op_types, input_indices, output_name_to_node)
              if start_path is not None:
                  break
          if start_path is None:
              return

          # start_path[-2] is the Shape node: it has to read the same tensor the Slice reads.
          if start_path[-2].input[0] != bias_add_output:
              return

          # Path that computes the Slice end index.
          end_path = model.match_parent_path(gelu_slice, ["Mul", "Div"], [2, 0], output_name_to_node)
          if end_path is None:
              return
          end_mul = end_path[0]
          shared_div = end_path[1]
          # The Div must be the one already matched on the start-index path.
          if shared_div not in start_path:
              return

          # The other operand of the final Mul is a second Slice ...
          mul_slice = model.match_parent(mul_after_gelu, "Slice", 0, output_name_to_node)
          if mul_slice is None:
              return
          # ... whose end index is the start index of the Gelu-side Slice.
          if mul_slice.input[2] != gelu_slice.input[1]:
              return

          subgraph_nodes = list(start_path)
          subgraph_nodes += [end_mul, mul_after_gelu, gelu_node, mul_slice, gelu_slice]
          fused_output = mul_after_gelu.output[0]

          safe_to_fuse = model.is_safe_to_fuse_nodes(
              subgraph_nodes, [fused_output], input_name_to_nodes, output_name_to_node
          )
          if not safe_to_fuse:
              logger.info("Skip fuse BiasSplitGelu since it is not safe to fuse the subgraph.")
              return

          # The last node on the start-index path is the Add whose constant input is the bias.
          bias_add = start_path[-1]
          bias_index, _ = model.get_constant_input(bias_add)
          if not isinstance(bias_index, int):
              # No constant input was found on the Add, so there is no bias to fuse.
              return

          self.nodes_to_remove.extend(subgraph_nodes)

          node_name = model.create_node_name("BiasSplitGelu", name_prefix="BiasSplitGelu")
          # Inputs of the fused node: the non-constant Add operand first, then the bias.
          fused_node = helper.make_node(
              "BiasSplitGelu",
              inputs=[bias_add.input[1 - bias_index], bias_add.input[bias_index]],
              outputs=[fused_output],
              name=node_name,
          )
          fused_node.domain = "com.microsoft"

          self.nodes_to_add.append(fused_node)
          self.node_name_to_graph_name[node_name] = self.this_graph_name