# fusion_constant_fold.py
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger

from fusion_base import Fusion
from fusion_utils import NumpyHelper
from onnx_model import OnnxModel

# Module-level logger, named after this module, used for the debug/info
# messages emitted by the fusion pass below.
logger = getLogger(__name__)
  10. class FusionConstantFold(Fusion):
  11. def __init__(self, model: OnnxModel):
  12. super().__init__(model, "", ["Transpose"])
  13. self.count = 0
  14. def apply(self):
  15. super().apply()
  16. if self.count > 0:
  17. logger.info(f"Constant Folded: {self.count}")
  18. def fuse(self, node, input_name_to_nodes, output_name_to_node):
  19. """
  20. Apply multiple fusions on Transpose nodes that can be constant folded.
  21. """
  22. self.fuse_1(node, input_name_to_nodes, output_name_to_node)
  23. self.fuse_2(node, input_name_to_nodes, output_name_to_node)
  24. def fuse_1(self, node, input_name_to_nodes, output_name_to_node):
  25. """
  26. Constant fold any initializer data representing a MatMul's
  27. weights that are stored in a Transpose op
  28. Ex: Transpose --> Gemm or Transpose --> MatMul
  29. """
  30. # Check if Transpose node only has one input and one output
  31. if len(node.input) != 1 or len(node.output) != 1:
  32. logger.debug("fuse_constant_fold: node has more than one input or output")
  33. return
  34. # Check if input is initializer data
  35. proto = self.model.get_initializer(node.input[0])
  36. if proto is None:
  37. logger.debug("fuse_constant_fold: failed to identify initializer input")
  38. return
  39. # Check that all nodes using input are Transpose ops that also only use the initializer data as input
  40. skip = False
  41. for child_node in input_name_to_nodes[node.input[0]]:
  42. if not (child_node.op_type == "Transpose" and len(node.input) == 1):
  43. skip = True
  44. break
  45. if skip:
  46. logger.debug("fuse_constant_fold: other non-Transpose nodes use the initializer")
  47. return
  48. # Check that all nodes using output are Gemm or MatMul ops
  49. for child_node in input_name_to_nodes[node.output[0]]:
  50. if not (child_node.op_type == "Gemm" or child_node.op_type == "MatMul"):
  51. skip = True
  52. break
  53. if skip:
  54. logger.debug("fuse_constant_fold: other non-Gemm and non-MatMul nodes use the transposed data")
  55. return
  56. # Check if initializer data is 2D
  57. weight = NumpyHelper.to_array(proto)
  58. if len(weight.shape) != 2:
  59. logger.debug("fuse_constant_fold: shape of initializer data is not 2D")
  60. return
  61. # Remove old TensorProto and add new TensorProto while re-using same name
  62. name = proto.name
  63. dtype = proto.data_type
  64. self.remove_initializer(proto)
  65. self.add_initializer(
  66. name=name,
  67. data_type=dtype,
  68. dims=[weight.shape[1], weight.shape[0]],
  69. vals=weight.T,
  70. )
  71. # Update weights input to be the initializer name and not
  72. # the output of the Transpose op
  73. for child_node in input_name_to_nodes[node.output[0]]:
  74. for i in range(len(child_node.input)):
  75. if child_node.input[i] == node.output[0]:
  76. child_node.input[i] = node.input[0]
  77. if child_node.op_type == "Gemm" and (i == 0 or i == 1):
  78. # Ensure that transA/transB is set to 0 in Gemm
  79. key = "transA" if i == 0 else "transB"
  80. for j, attr_key in enumerate(child_node.attribute):
  81. if attr_key.name == key:
  82. child_node.attribute[j].i = 0
  83. # Add node to list of nodes to remove
  84. self.nodes_to_remove.append(node)
  85. self.count += 1
  86. def fuse_2(self, node, input_name_to_nodes, output_name_to_node):
  87. """
  88. Constant fold any Transpose --> Transpose ops since the root input
  89. is the final result
  90. Ex: root_input --> Transpose --> Transpose --> next_node to root_input --> next_node
  91. """
  92. # Check if Transpose node only has one input and one output
  93. if len(node.input) != 1 or len(node.output) != 1:
  94. logger.debug("fuse_constant_fold: node has more than one input or output")
  95. return
  96. # Check if parent node is Transpose node with only one input and one output
  97. parent_node = self.model.match_parent(node, "Transpose", 0)
  98. if parent_node is None:
  99. logger.debug("fuse_constant_fold: failed to identify parent Transpose node")
  100. return
  101. if len(parent_node.input) != 1 or len(parent_node.output) != 1:
  102. logger.debug("fuse_constant_fold: parent node has more than one input or output")
  103. return
  104. node_perm = node.attribute[0].ints
  105. parent_node_perm = parent_node.attribute[0].ints
  106. if node_perm != parent_node_perm:
  107. logger.debug("fuse_constant_fold: Transpose node permutations aren't identical")
  108. return
  109. # For nodes that use output of child Transpose node as an input,
  110. # replace that input with root_input
  111. root_input = parent_node.input[0]
  112. output_nodes = input_name_to_nodes[node.output[0]]
  113. for output_node in output_nodes:
  114. for i, input_ in enumerate(output_node.input):
  115. if input_ == node.output[0]:
  116. output_node.input[i] = root_input
  117. # Add node to list of nodes to remove
  118. self.nodes_to_remove.append(node)
  119. self.nodes_to_remove.append(parent_node)
  120. self.count += 1