fusion_bias_add.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. from fusion_base import Fusion
  7. from numpy import ndarray
  8. from onnx import helper
  9. from onnx_model import OnnxModel
  10. logger = getLogger(__name__)
  11. class FusionBiasAdd(Fusion):
  12. def __init__(self, model: OnnxModel):
  13. super().__init__(model, "BiasAdd", "Add")
  14. def fuse(self, add_node, input_name_to_nodes: dict, output_name_to_node: dict):
  15. """
  16. Fuse Add bias and Add skip connection into BiasAdd
  17. """
  18. nodes = self.model.match_parent_path(
  19. add_node,
  20. ["Add", "MatMul", "BiasSplitGelu", "MatMul", "SkipLayerNormalization"],
  21. [0, None, 0, 0, 0],
  22. output_name_to_node,
  23. )
  24. if nodes is None:
  25. return
  26. bias_node = nodes[0]
  27. skip_layer_norm = nodes[-1]
  28. # Check skip connection is from SkipLayerNormalization output
  29. if add_node.input[1] not in skip_layer_norm.output:
  30. return
  31. bias_index, bias_value = self.model.get_constant_input(bias_node)
  32. if not (isinstance(bias_index, int) and (bias_value is not None) and isinstance(bias_value, ndarray)):
  33. return
  34. if bias_value.ndim != 1:
  35. return
  36. self.nodes_to_remove.extend([add_node, bias_node])
  37. node_name = self.model.create_node_name("BiasAdd")
  38. fused_node = helper.make_node(
  39. "BiasAdd",
  40. inputs=[bias_node.input[1 - bias_index], bias_node.input[bias_index], add_node.input[1]],
  41. outputs=[add_node.output[0]],
  42. name=node_name,
  43. )
  44. fused_node.domain = "com.microsoft"
  45. self.nodes_to_add.append(fused_node)
  46. self.node_name_to_graph_name[node_name] = self.this_graph_name