fusion_transpose.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. from fusion_base import Fusion
  7. from fusion_utils import FusionUtils
  8. from onnx import NodeProto, TensorProto, helper
  9. from onnx_model import OnnxModel
  10. logger = getLogger(__name__)
  11. class FusionTranspose(Fusion):
  12. def __init__(self, model: OnnxModel):
  13. super().__init__(model, "Transpose", "Transpose")
  14. def fuse(
  15. self,
  16. transpose_node: NodeProto,
  17. input_name_to_nodes: dict[str, list[NodeProto]],
  18. output_name_to_node: dict[str, NodeProto],
  19. ):
  20. """
  21. Note that onnxruntime will do comprehensive transpose optimization after loading model.
  22. The purpose of this fusion is to make graph clean before running onnxruntime.
  23. Case 1:
  24. (input)-->Transpose(perm=a)-->Transpose(perm=b)-->
  25. After:
  26. (input)-->Transpose(perm=a)--> (this path can be removed if the output is not used anymore)
  27. |
  28. +----->Transpose(perm=a*b)-->
  29. Case 2 (Cast has only one child):
  30. (input)-->Transpose(perm=a)--> Cast -->Transpose(perm=b)-->
  31. After:
  32. (input)-->Transpose(perm=a)--> (this path can be removed if the output is not used anymore)
  33. |
  34. +----->Cast --> Transpose(perm=a*b)-->
  35. """
  36. transpose_b = transpose_node
  37. if transpose_b.input[0] not in output_name_to_node:
  38. return
  39. transpose_a = output_name_to_node[transpose_b.input[0]]
  40. if transpose_a.op_type != "Cast":
  41. cast_node = None
  42. else:
  43. cast_node = transpose_a
  44. cast_children = self.model.get_children(cast_node, input_name_to_nodes)
  45. if cast_children and len(cast_children) > 1:
  46. return
  47. if cast_node.input[0] not in output_name_to_node:
  48. return
  49. transpose_a = output_name_to_node[cast_node.input[0]]
  50. if transpose_a.op_type != "Transpose":
  51. return
  52. permutation = OnnxModel.get_node_attribute(transpose_b, "perm")
  53. assert isinstance(permutation, list)
  54. parent_permutation = OnnxModel.get_node_attribute(transpose_a, "perm")
  55. assert isinstance(parent_permutation, list)
  56. assert len(parent_permutation) == len(permutation)
  57. output_permutation = []
  58. for _j, index in enumerate(permutation):
  59. output_permutation.append(parent_permutation[index])
  60. if cast_node is None:
  61. if FusionUtils.skip_parent(self.model, transpose_b, transpose_a, input_name_to_nodes):
  62. self.nodes_to_remove.append(transpose_a)
  63. else:
  64. if FusionUtils.skip_parent(self.model, cast_node, transpose_a, input_name_to_nodes):
  65. self.nodes_to_remove.append(transpose_a)
  66. transpose_b.ClearField("attribute")
  67. transpose_b.attribute.extend([helper.make_attribute("perm", output_permutation)])
  68. class FusionInsertTranspose(Fusion):
  69. def __init__(self, model: OnnxModel):
  70. super().__init__(model, "", "GroupNorm")
  71. def create_transpose_node(self, input_name: str, perm: list[int], output_name=None):
  72. """Append a Transpose node after an input"""
  73. node_name = self.model.create_node_name("Transpose")
  74. if output_name is None:
  75. output_name = node_name + "_out" + "-" + input_name
  76. transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
  77. transpose_node.attribute.extend([helper.make_attribute("perm", perm)])
  78. return transpose_node
  79. def fuse(
  80. self,
  81. group_norm_node: NodeProto,
  82. input_name_to_nodes: dict[str, list[NodeProto]],
  83. output_name_to_node: dict[str, NodeProto],
  84. ):
  85. """
  86. This optimization will insert an Transpose, and onnxruntime transpose optimizer will remove it together with
  87. another Transpose so that we can get effect of reducing one Transpose after onnxruntime optimization.
  88. Before:
  89. --> Gemm --> Unsqueeze(axes=[2]) --> Unsqueeze(axes=[3]) --> Add --> Transpose([0,2,3,1]) --> GroupNorm
  90. After:
  91. --> Gemm --> Unsqueeze(axes=[1]) --> Unsqueeze(axes=[2]) -->Transpose([0,3,1,2]) --> Add --> Transpose([0,2,3,1]) --> GroupNorm
  92. """
  93. gemm_path = self.model.match_parent_path(
  94. group_norm_node, ["Transpose", "Add", "Unsqueeze", "Unsqueeze", "Gemm"], [0, 0, None, 0, 0]
  95. )
  96. if gemm_path is None:
  97. return
  98. transpose, add, unsqueeze_3, unsqueeze_2, gemm = gemm_path
  99. if self.model.find_graph_output(unsqueeze_3.output[0]):
  100. return
  101. permutation = OnnxModel.get_node_attribute(transpose, "perm")
  102. assert isinstance(permutation, list)
  103. if permutation != [0, 2, 3, 1]:
  104. return
  105. if not (
  106. len(unsqueeze_3.input) == 2
  107. and self.model.get_constant_value(unsqueeze_3.input[1]) == 3
  108. and len(unsqueeze_2.input) == 2
  109. and self.model.get_constant_value(unsqueeze_2.input[1]) == 2
  110. and len(self.model.get_children(gemm, input_name_to_nodes)) == 1
  111. and len(self.model.get_children(unsqueeze_3, input_name_to_nodes)) == 1
  112. and len(self.model.get_children(unsqueeze_2, input_name_to_nodes)) == 1
  113. ):
  114. return
  115. # Here we use hard-coded name so that it could be shared for the whole model.
  116. axes_1 = "ort_const_unsqueeze_axes_1"
  117. if self.model.get_initializer(axes_1) is None:
  118. self.add_initializer(
  119. name=axes_1,
  120. data_type=TensorProto.INT64,
  121. dims=[1],
  122. vals=[1],
  123. raw=False,
  124. )
  125. axes_2 = "ort_const_unsqueeze_axes_2"
  126. if self.model.get_initializer(axes_2) is None:
  127. self.add_initializer(
  128. name=axes_2,
  129. data_type=TensorProto.INT64,
  130. dims=[1],
  131. vals=[2],
  132. raw=False,
  133. )
  134. unsqueeze_3.input[1] = "ort_const_unsqueeze_axes_2"
  135. unsqueeze_2.input[1] = "ort_const_unsqueeze_axes_1"
  136. transpose_output_name = self.model.create_node_name("Transpose") + "_NCHW"
  137. self.model.replace_input_of_all_nodes(unsqueeze_3.output[0], transpose_output_name)
  138. new_transpose = self.create_transpose_node(unsqueeze_3.output[0], [0, 3, 1, 2], transpose_output_name)
  139. self.model.add_node(new_transpose, self.this_graph_name)
  140. self.increase_counter("Insert Transpose")