| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License.
- # --------------------------------------------------------------------------
- from logging import getLogger
- from fusion_base import Fusion
- from fusion_utils import NumpyHelper
- from onnx import helper
- from onnx_model import OnnxModel
- logger = getLogger(__name__)
class FusionBiasGelu(Fusion):
    """Fusion pass that folds a bias ``Add`` into the activation that follows it.

    The pattern handled is ``MatMul -> Add(bias) -> Gelu`` (emitted as a
    ``BiasGelu`` node) or ``MatMul -> Add(bias) -> FastGelu`` (emitted as a
    two-input ``FastGelu`` node).
    """

    def __init__(self, model: OnnxModel, is_fastgelu):
        """Configure the pass for either the FastGelu or the Gelu variant.

        Args:
            model: The model wrapper handed to the ``Fusion`` base class.
            is_fastgelu: Truthy to target ``FastGelu`` nodes, falsy to target
                ``Gelu`` nodes.
        """
        # Positional arguments forwarded to Fusion.__init__ after ``model``;
        # their exact meaning is defined by the base class (not visible here).
        if is_fastgelu:
            base_args = ("FastGelu", "FastGelu", "add bias")
        else:
            base_args = ("BiasGelu", "Gelu")
        super().__init__(model, *base_args)

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Try to fuse ``node`` with the bias ``Add`` that feeds it.

        The method returns ``None`` in every case. When the pattern matches it
        queues the activation and ``Add`` nodes for removal and queues one
        fused node for insertion; otherwise it leaves the pass state untouched.

        Args:
            node: Candidate activation node (``Gelu`` or ``FastGelu``).
            input_name_to_nodes: Lookup passed through to the safety check.
            output_name_to_node: Lookup passed through to the safety check.
        """
        activation_type = node.op_type
        if activation_type == "Gelu":
            fused_type = "BiasGelu"
        else:
            fused_type = "FastGelu"

        # Only an activation with exactly one input is a fusion candidate.
        if len(node.input) != 1:
            return

        matched = self.model.match_parent_path(node, ["Add", "MatMul"], [0, None])
        if matched is None:
            return
        add_node, matmul_node = matched

        # The bias is the first input of the Add that is backed by an initializer.
        bias_position = -1
        bias_array = None
        for position, tensor_name in enumerate(add_node.input):
            tensor = self.model.get_initializer(tensor_name)
            if tensor is not None:
                bias_position = position
                bias_array = NumpyHelper.to_array(tensor)
                break

        # bias should be one dimension
        if bias_array is None or len(bias_array.shape) != 1:
            return

        replaced_nodes = [node, add_node]
        can_fuse = self.model.is_safe_to_fuse_nodes(
            replaced_nodes, [node.output[0]], input_name_to_nodes, output_name_to_node
        )
        if not can_fuse:
            return

        self.nodes_to_remove.extend(replaced_nodes)

        # The fused node consumes the MatMul result and the bias tensor and
        # takes over the activation's outputs.
        fused_inputs = [matmul_node.output[0], add_node.input[bias_position]]
        fused_name = self.model.create_node_name(fused_type, activation_type + "_AddBias_")
        fused_node = helper.make_node(
            fused_type,
            inputs=fused_inputs,
            outputs=node.output,
            name=fused_name,
        )
        fused_node.domain = "com.microsoft"

        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
|