# fusion_gemmfastgelu.py
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger

from fusion_base import Fusion
from fusion_utils import NumpyHelper
from onnx import NodeProto, TensorProto, ValueInfoProto, helper
from onnx_model import OnnxModel
  10. logger = getLogger(__name__)
  11. class FusionGemmFastGelu(Fusion):
  12. def __init__(self, model: OnnxModel):
  13. super().__init__(model, "GemmFastGelu", "FastGelu", "GemmFastGelu")
  14. self.shape_infer = None
  15. self.shape_infer_done = False
  16. def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> int | None:
  17. if tensor_proto.type.tensor_type.HasField("shape"):
  18. return len(tensor_proto.type.tensor_type.shape.dim)
  19. else:
  20. return None
  21. def get_dimensions(self, input_name: str) -> int | None:
  22. graph_input = self.model.find_graph_input(input_name)
  23. if graph_input:
  24. return self.get_dimensions_from_tensor_proto(graph_input)
  25. if not self.shape_infer_done:
  26. self.shape_infer = self.model.infer_runtime_shape(update=True)
  27. self.shape_infer_done = True
  28. if self.shape_infer is not None:
  29. return self.get_dimensions_from_tensor_proto(self.shape_infer.known_vi_[input_name])
  30. return None
  31. def fuse(
  32. self,
  33. node: NodeProto,
  34. input_name_to_nodes: dict[str, list[NodeProto]],
  35. output_name_to_node: dict[str, NodeProto],
  36. ):
  37. """
  38. This pattern is from PyTorch bert model
  39. Fuse MatMul with FastGelu into one node:
  40. [root] --> MatMul --> FastGelu -->
  41. """
  42. has_bias = False
  43. if len(node.input) == 2:
  44. has_bias = True
  45. match_nodes = self.model.match_parent_path(node, ["MatMul"], [0])
  46. if match_nodes is None:
  47. return
  48. matmul = match_nodes[0]
  49. # matmul input X should >= two dimension, input weight should be two dimension
  50. weight_index = -1
  51. x_dims = 0
  52. weight = None
  53. for i, input in enumerate(matmul.input):
  54. initializer = self.model.get_initializer(input)
  55. if initializer is None:
  56. x_dims = self.get_dimensions(matmul.input[i])
  57. else:
  58. weight_index = i
  59. weight = NumpyHelper.to_array(initializer)
  60. if weight is None:
  61. return
  62. if len(weight.shape) != 2:
  63. return
  64. if x_dims < len(weight.shape):
  65. return
  66. # bias weight should be one dimension
  67. bias_index = -1
  68. if has_bias:
  69. bias_weight = None
  70. for i, input in enumerate(node.input):
  71. initializer = self.model.get_initializer(input)
  72. if initializer is None:
  73. continue
  74. bias_index = i
  75. bias_weight = NumpyHelper.to_array(initializer)
  76. break
  77. if bias_weight is None:
  78. return
  79. if len(bias_weight.shape) != 1:
  80. return
  81. subgraph_nodes = [node, matmul]
  82. if not self.model.is_safe_to_fuse_nodes(
  83. subgraph_nodes, [node.output[0]], input_name_to_nodes, output_name_to_node
  84. ):
  85. return
  86. self.nodes_to_remove.extend(subgraph_nodes)
  87. inputs = (
  88. [matmul.input[1 - weight_index], matmul.input[weight_index], node.input[bias_index]]
  89. if has_bias
  90. else [matmul.input[1 - weight_index], matmul.input[weight_index]]
  91. )
  92. fused_node = helper.make_node(
  93. "GemmFastGelu",
  94. inputs=inputs,
  95. outputs=node.output,
  96. name=self.model.create_node_name("GemmFastGelu"),
  97. )
  98. fused_node.domain = "com.microsoft"
  99. self.nodes_to_add.append(fused_node)
  100. self.node_name_to_graph_name[fused_node.name] = self.this_graph_name