fusion_gelu_approximation.py 1004 B

12345678910111213141516171819202122232425
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from fusion_base import Fusion
  6. from onnx import helper
  7. from onnx_model import OnnxModel
  8. class FusionGeluApproximation(Fusion):
  9. def __init__(self, model: OnnxModel):
  10. super().__init__(model, "FastGelu", ["Gelu", "BiasGelu"], "GeluApproximation")
  11. def fuse(self, node, input_name_to_nodes, output_name_to_node):
  12. new_node = helper.make_node(
  13. "FastGelu",
  14. inputs=node.input,
  15. outputs=node.output,
  16. name=self.model.create_node_name("FastGelu", node.op_type + "_Approximation"),
  17. )
  18. new_node.domain = "com.microsoft"
  19. self.nodes_to_remove.append(node)
  20. self.nodes_to_add.append(new_node)
  21. self.node_name_to_graph_name[new_node.name] = self.this_graph_name