fusion_shape.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. from fusion_base import Fusion
  7. from fusion_utils import FusionUtils
  8. from numpy import ndarray
  9. from onnx import NodeProto, TensorProto
  10. from onnx_model import OnnxModel
  11. logger = getLogger(__name__)
  12. class FusionShape(Fusion):
  13. def __init__(self, model: OnnxModel):
  14. super().__init__(model, "Shape", "Concat")
  15. self.utils = FusionUtils(model)
  16. self.shape_infer = None
  17. self.shape_infer_done = False
  18. def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> int | None:
  19. if tensor_proto.type.tensor_type.HasField("shape"):
  20. return len(tensor_proto.type.tensor_type.shape.dim)
  21. else:
  22. return None
  23. def get_dimensions(self, input_name: str) -> int | None:
  24. shape = self.model.get_shape(input_name)
  25. if shape is not None:
  26. return len(shape)
  27. if not self.shape_infer_done:
  28. self.shape_infer = self.model.infer_runtime_shape(update=True)
  29. self.shape_infer_done = True
  30. if self.shape_infer is not None:
  31. return self.get_dimensions_from_tensor_proto(self.shape_infer.known_vi_[input_name])
  32. return None
  33. def fuse(
  34. self,
  35. concat_node: NodeProto,
  36. input_name_to_nodes: dict[str, list[NodeProto]],
  37. output_name_to_node: dict[str, NodeProto],
  38. ):
  39. #
  40. # Simplify subgraph like
  41. #
  42. # (2d_input)
  43. # / \
  44. # Shape shape
  45. # / \
  46. # Gather(indices=0) Gather(indices=1)
  47. # | |
  48. # Unsqueeze(axes=0) Unsqueeze(axes=0)
  49. # \ /
  50. # Concat
  51. # |
  52. #
  53. # into (2d_input) --> Shape -->
  54. #
  55. opset_version = self.model.get_opset_version()
  56. inputs = len(concat_node.input)
  57. root = None
  58. shape_output = None
  59. for i in range(inputs):
  60. path = self.model.match_parent_path(
  61. concat_node,
  62. ["Unsqueeze", "Gather", "Shape"],
  63. [i, 0, 0],
  64. output_name_to_node,
  65. )
  66. if path is None:
  67. return
  68. unsqueeze, gather, shape = path
  69. if i == 0:
  70. shape_output = shape.output[0]
  71. if root is None:
  72. root = shape.input[0]
  73. if self.get_dimensions(root) != inputs:
  74. return
  75. elif shape.input[0] != root:
  76. return
  77. if not FusionUtils.check_node_attribute(unsqueeze, "axis", 0, default_value=0):
  78. return
  79. if opset_version < 13:
  80. if not FusionUtils.check_node_attribute(unsqueeze, "axes", [0]):
  81. return
  82. else:
  83. if not self.utils.check_node_input_value(unsqueeze, 1, [0]):
  84. return
  85. value = self.model.get_constant_value(gather.input[1])
  86. if not (isinstance(value, ndarray) and value.size == 1 and value.item() == i):
  87. return
  88. if self.model.find_graph_output(concat_node.output[0]) is None:
  89. self.model.replace_input_of_all_nodes(concat_node.output[0], shape_output)
  90. self.increase_counter("Reshape")
  91. self.prune_graph = True