# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations

from collections import deque

import onnx

from ..onnx_model import ONNXModel


  10. class Fusion:
  11. """
  12. Base class for fusions.
  13. """
  14. def __init__(self, model: ONNXModel, fused_op_type: str, search_op_type: str):
  15. self.search_op_type: str = search_op_type
  16. self.fused_op_type: str = fused_op_type
  17. self.model: ONNXModel = model
  18. self.nodes_to_remove: list = []
  19. self.nodes_to_add: list = []
  20. self._new_node_name_prefix = self.fused_op_type + "_fused_" + self.search_op_type + "_"
  21. self._new_node_name_suffix = None # int|None used to create unique node names for the fused ops.
  22. def fuse(
  23. self,
  24. node: onnx.NodeProto,
  25. input_name_to_nodes: dict[str, list[onnx.NodeProto]],
  26. output_name_to_node: dict[str, onnx.NodeProto],
  27. ):
  28. """
  29. Interface function for derived fusion classes. Tries to fuse a node sequence containing
  30. the specified node.
  31. """
  32. raise NotImplementedError
  33. def apply(self) -> bool:
  34. """
  35. Apply graph fusion on the entire model graph.
  36. """
  37. input_name_to_nodes = self.model.input_name_to_nodes()
  38. output_name_to_node = self.model.output_name_to_node()
  39. for node in self.model.nodes():
  40. if node.op_type == self.search_op_type:
  41. self.fuse(node, input_name_to_nodes, output_name_to_node)
  42. self.model.remove_nodes(self.nodes_to_remove)
  43. self.model.add_nodes(self.nodes_to_add)
  44. graph_updated = bool(self.nodes_to_remove or self.nodes_to_add)
  45. if graph_updated:
  46. self.model.remove_unused_constant()
  47. return graph_updated
  48. def create_unique_node_name(self):
  49. prefix = self._new_node_name_prefix
  50. if self._new_node_name_suffix is None:
  51. largest_suffix: int = self.model.get_largest_node_name_suffix(prefix)
  52. self._new_node_name_suffix = largest_suffix + 1
  53. new_name = f"{prefix}{self._new_node_name_suffix!s}"
  54. self._new_node_name_suffix += 1
  55. return new_name
  56. @staticmethod
  57. def is_safe_to_fuse_nodes(
  58. nodes_to_remove: list[onnx.NodeProto],
  59. keep_outputs: list[str],
  60. input_name_to_nodes: dict[str, list[onnx.NodeProto]],
  61. output_name_to_node: dict[str, onnx.NodeProto],
  62. ) -> bool:
  63. for node_to_remove in nodes_to_remove:
  64. for output_to_remove in node_to_remove.output:
  65. if output_to_remove in keep_outputs:
  66. continue
  67. if output_to_remove in input_name_to_nodes:
  68. for impacted_node in input_name_to_nodes[output_to_remove]:
  69. if impacted_node not in nodes_to_remove:
  70. # Not safe to remove nodes since output is used by impacted_node
  71. return False
  72. return True
  73. @staticmethod
  74. def get_node_attribute(node: onnx.NodeProto, attribute_name: str):
  75. for attr in node.attribute:
  76. if attr.name == attribute_name:
  77. value = onnx.helper.get_attribute_value(attr)
  78. return value
  79. return None
  80. @staticmethod
  81. def input_index(node_output: str, child_node: onnx.NodeProto) -> int:
  82. for index, input_name in enumerate(child_node.input):
  83. if input_name == node_output:
  84. return index
  85. return -1
  86. @staticmethod
  87. def tensor_shape_to_list(tensor_type) -> list[int]:
  88. shape_list = []
  89. for d in tensor_type.shape.dim:
  90. if d.HasField("dim_value"):
  91. shape_list.append(d.dim_value) # known dimension
  92. elif d.HasField("dim_param"):
  93. shape_list.append(d.dim_param) # unknown dimension with symbolic name
  94. else:
  95. shape_list.append("?") # shall not happen
  96. return shape_list
  97. def get_constant_input(self, node: onnx.NodeProto):
  98. for i, inp in enumerate(node.input):
  99. value = self.model.get_constant_value(inp)
  100. if value is not None:
  101. return i, value
  102. return None, None
  103. def find_constant_input(self, node: onnx.NodeProto, expected_value: float, delta: float = 0.000001) -> int:
  104. i, value = self.get_constant_input(node)
  105. if value is not None and value.size == 1 and abs(value - expected_value) < delta:
  106. return i
  107. return -1
  108. def has_constant_input(self, node: onnx.NodeProto, expected_value: float, delta: float = 0.000001) -> bool:
  109. return self.find_constant_input(node, expected_value, delta) >= 0
  110. def is_constant_with_specified_rank(self, output_name: str, rank: int) -> bool:
  111. value = self.model.get_constant_value(output_name)
  112. if value is None:
  113. return False # Not an initializer
  114. if len(value.shape) != rank:
  115. return False # Wrong dimensions
  116. return True
  117. def match_first_parent(
  118. self,
  119. node: onnx.NodeProto,
  120. parent_op_type: str,
  121. output_name_to_node: dict[str, onnx.NodeProto] | None = None,
  122. exclude: list[onnx.NodeProto] = [], # noqa: B006
  123. ) -> tuple[onnx.NodeProto | None, int | None]:
  124. """
  125. Find parent node based on constraints on op_type.
  126. Args:
  127. node: current node.
  128. parent_op_type (str): constraint of parent node op_type.
  129. output_name_to_node (dict): dictionary with output name as key, and node as value.
  130. exclude (list): list of nodes that are excluded (not allowed to match as parent).
  131. Returns:
  132. parent: The matched parent node. None if not found.
  133. index: The input index of matched parent node. None if not found.
  134. """
  135. if output_name_to_node is None:
  136. output_name_to_node = self.model.output_name_to_node()
  137. for i, inp in enumerate(node.input):
  138. if inp in output_name_to_node:
  139. parent = output_name_to_node[inp]
  140. if parent.op_type == parent_op_type and parent not in exclude:
  141. return parent, i
  142. return None, None
  143. def match_parent(
  144. self,
  145. node: onnx.NodeProto,
  146. parent_op_type: str,
  147. input_index: int | None = None,
  148. output_name_to_node: dict[str, onnx.NodeProto] | None = None,
  149. exclude: list[onnx.NodeProto] = [], # noqa: B006
  150. return_indice: list[int] | None = None,
  151. ) -> onnx.NodeProto | None:
  152. """
  153. Find parent node based on constraints on op_type and index.
  154. When input_index is None, we will find the first parent node based on constraints,
  155. and return_indice will be appended the corresponding input index.
  156. Args:
  157. node (str): current node name.
  158. parent_op_type (str): constraint of parent node op_type.
  159. input_index (int or None): only check the parent given input index of current node.
  160. output_name_to_node (dict): dictionary with output name as key, and node as value.
  161. exclude (list): list of nodes that are excluded (not allowed to match as parent).
  162. return_indice (list): a list to append the input index when input_index is None.
  163. Returns:
  164. parent: The matched parent node.
  165. """
  166. assert node is not None
  167. assert input_index is None or input_index >= 0
  168. if output_name_to_node is None:
  169. output_name_to_node = self.model.output_name_to_node()
  170. if input_index is None:
  171. parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node, exclude)
  172. if return_indice is not None:
  173. return_indice.append(index)
  174. return parent
  175. if input_index >= len(node.input):
  176. # Input index out of bounds.
  177. return None
  178. parent = self.model.get_parent(node, input_index, output_name_to_node)
  179. if parent is not None and parent.op_type == parent_op_type and parent not in exclude:
  180. return parent
  181. return None
  182. def match_parent_path(
  183. self,
  184. node: onnx.NodeProto,
  185. parent_op_types: list[str],
  186. parent_input_index: list[int] | None = None,
  187. output_name_to_node: dict[str, onnx.NodeProto] | None = None,
  188. return_indice: list[int] | None = None,
  189. ) -> list[onnx.NodeProto] | None:
  190. """
  191. Find a sequence of input edges based on constraints on parent op_type and index.
  192. When input_index is None, we will find the first parent node based on constraints,
  193. and return_indice will be appended the corresponding input index.
  194. Args:
  195. node (str): current node name.
  196. parent_op_types (str): constraint of parent node op_type of each input edge.
  197. parent_input_index (list): constraint of input index of each input edge. None means no constraint.
  198. output_name_to_node (dict): dictionary with output name as key, and node as value.
  199. return_indice (list): a list to append the input index
  200. When there is no constraint on input index of an edge.
  201. Returns:
  202. parents: a list of matched parent node.
  203. """
  204. if parent_input_index is not None:
  205. assert len(parent_input_index) == len(parent_op_types)
  206. if output_name_to_node is None:
  207. output_name_to_node = self.model.output_name_to_node()
  208. current_node = node
  209. matched_parents = []
  210. for i, op_type in enumerate(parent_op_types):
  211. matched_parent = self.match_parent(
  212. current_node,
  213. op_type,
  214. parent_input_index[i] if parent_input_index is not None else None,
  215. output_name_to_node,
  216. exclude=[],
  217. return_indice=return_indice,
  218. )
  219. if matched_parent is None:
  220. return None
  221. matched_parents.append(matched_parent)
  222. current_node = matched_parent
  223. return matched_parents
  224. def match_parent_paths(
  225. self,
  226. node: onnx.NodeProto,
  227. paths: list[tuple[list[str], list[int]]],
  228. output_name_to_node: dict[str, onnx.NodeProto],
  229. ) -> tuple[int, list[onnx.NodeProto] | None, list[int] | None]:
  230. """
  231. Find a matching parent path to the given node.
  232. """
  233. for i, path in enumerate(paths):
  234. return_indice = []
  235. matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice)
  236. if matched:
  237. return i, matched, return_indice
  238. return -1, None, None
  239. def find_first_child_by_type(
  240. self,
  241. node: onnx.NodeProto,
  242. child_type: str,
  243. input_name_to_nodes: dict[str, list[onnx.NodeProto]] | None = None,
  244. recursive: bool = True,
  245. ) -> onnx.NodeProto | None:
  246. children = self.model.get_children(node, input_name_to_nodes)
  247. dq = deque(children)
  248. while len(dq) > 0:
  249. current_node = dq.pop()
  250. if current_node.op_type == child_type:
  251. return current_node
  252. if recursive:
  253. children = self.model.get_children(current_node, input_name_to_nodes)
  254. for child in children:
  255. dq.appendleft(child)
  256. return None