fusion_quickgelu.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import logging
  6. from fusion_base import Fusion
  7. from onnx import helper
  8. from onnx_model import OnnxModel
  9. logger = logging.getLogger(__name__)
  10. class FusionQuickGelu(Fusion):
  11. def __init__(self, model: OnnxModel):
  12. super().__init__(model, "QuickGelu", ["Mul"])
  13. def fuse(self, node, input_name_to_nodes, output_name_to_node):
  14. # Fuse the following subgraph to `QuickGelu`
  15. #
  16. # root_input
  17. # / \
  18. # | Mul ----+
  19. # | (B = ~1.702) |
  20. # \ | |
  21. # \ Sigmoid |---- `QuickGelu`
  22. # \ / |
  23. # \ / |
  24. # Mul ----+
  25. # |
  26. # root_output
  27. if node.op_type != "Mul":
  28. logger.debug("fuse_quickgelu: failed to match second Mul node")
  29. return
  30. second_mul_node = node
  31. root_input = second_mul_node.input[0]
  32. sigmoid_node = self.model.match_parent_path(second_mul_node, ["Sigmoid"], [1])
  33. if sigmoid_node is None:
  34. logger.debug("fuse_quickgelu: failed to match Sigmoid node")
  35. return
  36. sigmoid_node = sigmoid_node[0]
  37. first_mul_node = self.model.match_parent_path(sigmoid_node, ["Mul"], [0])
  38. if first_mul_node is None:
  39. logger.debug("fuse_quickgelu: failed to match first Mul node")
  40. return
  41. first_mul_node = first_mul_node[0]
  42. approximation_value = self.model.get_constant_value(first_mul_node.input[1]).item()
  43. if abs(approximation_value - 1.7021484375) >= 1e-3:
  44. logger.debug("fuse_quickgelu: failed to match approximation value")
  45. return
  46. if first_mul_node.input[0] != root_input:
  47. logger.debug("fuse_quickgelu: failed to match root input with first Mul node's input")
  48. return
  49. new_node = helper.make_node(
  50. "QuickGelu",
  51. inputs=[root_input],
  52. outputs=[second_mul_node.output[0]],
  53. name=self.model.create_node_name("QuickGelu"),
  54. )
  55. new_node.domain = "com.microsoft"
  56. new_node.attribute.extend([helper.make_attribute("alpha", approximation_value)])
  57. self.nodes_to_remove.extend([first_mul_node, sigmoid_node, second_mul_node])
  58. self.nodes_to_add.append(new_node)
  59. self.node_name_to_graph_name[new_node.name] = self.this_graph_name
  60. self.increase_counter("QuickGelu")