| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License. See License.txt in the project root for
- # license information.
- # --------------------------------------------------------------------------
- from __future__ import annotations
- from collections import deque
- import onnx
- from ..onnx_model import ONNXModel
class Fusion:
    """
    Base class for fusions.

    ``apply()`` walks the model graph and calls ``fuse()`` for every node whose op_type equals
    ``search_op_type``. Derived classes implement ``fuse()`` and record their edits in
    ``nodes_to_remove`` / ``nodes_to_add``; ``apply()`` commits those lists to the model.
    The remaining methods are helpers for matching parents/children and inspecting constants.
    """

    def __init__(self, model: ONNXModel, fused_op_type: str, search_op_type: str):
        """
        Args:
            model: model wrapper whose graph is searched and modified.
            fused_op_type: op type used in the name prefix of nodes created by this fusion.
            search_op_type: op type of the nodes that are handed to ``fuse()``.
        """
        self.search_op_type: str = search_op_type
        self.fused_op_type: str = fused_op_type
        self.model: ONNXModel = model
        self.nodes_to_remove: list = []
        self.nodes_to_add: list = []

        self._new_node_name_prefix = self.fused_op_type + "_fused_" + self.search_op_type + "_"
        self._new_node_name_suffix = None  # int|None used to create unique node names for the fused ops.

    def fuse(
        self,
        node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ):
        """
        Interface function for derived fusion classes. Tries to fuse a node sequence containing
        the specified node.

        Args:
            node: a node whose op_type equals ``search_op_type``.
            input_name_to_nodes: dictionary with input name as key, and consumer nodes as value.
            output_name_to_node: dictionary with output name as key, and producer node as value.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def apply(self) -> bool:
        """
        Apply graph fusion on the entire model graph.

        Returns:
            True if any node was scheduled for removal or addition, False otherwise.
        """
        # Both lookup tables are built once, before any node is fused.
        input_name_to_nodes = self.model.input_name_to_nodes()
        output_name_to_node = self.model.output_name_to_node()

        for node in self.model.nodes():
            if node.op_type == self.search_op_type:
                self.fuse(node, input_name_to_nodes, output_name_to_node)

        # Edits are only recorded during the scan; they are committed here in one go.
        self.model.remove_nodes(self.nodes_to_remove)
        self.model.add_nodes(self.nodes_to_add)

        graph_updated = bool(self.nodes_to_remove or self.nodes_to_add)
        if graph_updated:
            self.model.remove_unused_constant()

        return graph_updated

    def create_unique_node_name(self):
        """
        Return a new node name of the form ``<fused_op_type>_fused_<search_op_type>_<N>``.

        On the first call N is one more than the largest suffix the model reports for this prefix;
        every later call increments N by one.
        """
        prefix = self._new_node_name_prefix

        if self._new_node_name_suffix is None:
            # Lazily query the model so the numbering continues after existing names.
            largest_suffix: int = self.model.get_largest_node_name_suffix(prefix)
            self._new_node_name_suffix = largest_suffix + 1

        new_name = f"{prefix}{self._new_node_name_suffix!s}"
        self._new_node_name_suffix += 1

        return new_name

    @staticmethod
    def is_safe_to_fuse_nodes(
        nodes_to_remove: list[onnx.NodeProto],
        keep_outputs: list[str],
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ) -> bool:
        """
        Check that removing the given nodes does not leave a consumer without its input.

        Args:
            nodes_to_remove: nodes that would be removed by the fusion.
            keep_outputs: output names that stay available after fusion; they are not checked.
            input_name_to_nodes: dictionary with input name as key, and consumer nodes as value.
            output_name_to_node: not used by this check; part of the signature only.

        Returns:
            False if some output (not in keep_outputs) of a removed node is consumed by a node
            that is not itself removed, True otherwise.
        """
        for node_to_remove in nodes_to_remove:
            for output_to_remove in node_to_remove.output:
                if output_to_remove in keep_outputs:
                    continue

                if output_to_remove in input_name_to_nodes:
                    for impacted_node in input_name_to_nodes[output_to_remove]:
                        if impacted_node not in nodes_to_remove:
                            # Not safe to remove nodes since output is used by impacted_node
                            return False
        return True

    @staticmethod
    def get_node_attribute(node: onnx.NodeProto, attribute_name: str):
        """
        Return the value of the first attribute of ``node`` named ``attribute_name``
        (decoded with ``onnx.helper.get_attribute_value``), or None if there is no such attribute.
        """
        for attr in node.attribute:
            if attr.name == attribute_name:
                value = onnx.helper.get_attribute_value(attr)
                return value
        return None

    @staticmethod
    def input_index(node_output: str, child_node: onnx.NodeProto) -> int:
        """
        Return the first index at which ``node_output`` appears in ``child_node.input``, or -1.
        """
        for index, input_name in enumerate(child_node.input):
            if input_name == node_output:
                return index
        return -1

    @staticmethod
    def tensor_shape_to_list(tensor_type) -> list[int | str]:
        """
        Convert the shape of a tensor type to a list with one entry per dimension.

        Returns:
            For each dimension: its ``dim_value`` (int) when set, else its ``dim_param``
            (symbolic name, str) when set, else the placeholder ``"?"``.
        """
        shape_list = []
        for d in tensor_type.shape.dim:
            if d.HasField("dim_value"):
                shape_list.append(d.dim_value)  # known dimension
            elif d.HasField("dim_param"):
                shape_list.append(d.dim_param)  # unknown dimension with symbolic name
            else:
                shape_list.append("?")  # shall not happen
        return shape_list

    def get_constant_input(self, node: onnx.NodeProto):
        """
        Find the first input of ``node`` for which the model returns a constant value.

        Returns:
            (index, value) of that input, or (None, None) if no input is constant.
        """
        for i, inp in enumerate(node.input):
            value = self.model.get_constant_value(inp)
            if value is not None:
                return i, value

        return None, None

    def find_constant_input(self, node: onnx.NodeProto, expected_value: float, delta: float = 0.000001) -> int:
        """
        Find an input of ``node`` that is a single-element constant close to ``expected_value``.

        Every input is examined, not just the first constant one, so a matching constant that
        follows a non-matching constant is still found.

        Args:
            node: node whose inputs are searched.
            expected_value: value the constant is compared against.
            delta: the constant matches when abs(constant - expected_value) < delta.

        Returns:
            Index of the first matching input, or -1 if there is none.
        """
        for i, inp in enumerate(node.input):
            value = self.model.get_constant_value(inp)
            # The size check comes first so the comparison is only evaluated on one element.
            if value is not None and value.size == 1 and abs(value - expected_value) < delta:
                return i

        return -1

    def has_constant_input(self, node: onnx.NodeProto, expected_value: float, delta: float = 0.000001) -> bool:
        """
        Return True if ``find_constant_input`` finds a matching constant input for ``node``.
        """
        return self.find_constant_input(node, expected_value, delta) >= 0

    def is_constant_with_specified_rank(self, output_name: str, rank: int) -> bool:
        """
        Return True if the model has a constant value for ``output_name`` whose shape has
        exactly ``rank`` dimensions.
        """
        value = self.model.get_constant_value(output_name)
        if value is None:
            return False  # Not an initializer

        if len(value.shape) != rank:
            return False  # Wrong dimensions

        return True

    def match_first_parent(
        self,
        node: onnx.NodeProto,
        parent_op_type: str,
        output_name_to_node: dict[str, onnx.NodeProto] | None = None,
        exclude: list[onnx.NodeProto] = [],  # noqa: B006
    ) -> tuple[onnx.NodeProto | None, int | None]:
        """
        Find parent node based on constraints on op_type.

        Args:
            node: current node.
            parent_op_type (str): constraint of parent node op_type.
            output_name_to_node (dict): dictionary with output name as key, and node as value.
                                        Built from the model when None.
            exclude (list): list of nodes that are excluded (not allowed to match as parent).
                            Only read here, never modified.

        Returns:
            parent: The matched parent node. None if not found.
            index: The input index of matched parent node. None if not found.
        """
        if output_name_to_node is None:
            output_name_to_node = self.model.output_name_to_node()

        for i, inp in enumerate(node.input):
            if inp in output_name_to_node:
                parent = output_name_to_node[inp]
                if parent.op_type == parent_op_type and parent not in exclude:
                    return parent, i

        return None, None

    def match_parent(
        self,
        node: onnx.NodeProto,
        parent_op_type: str,
        input_index: int | None = None,
        output_name_to_node: dict[str, onnx.NodeProto] | None = None,
        exclude: list[onnx.NodeProto] = [],  # noqa: B006
        return_indice: list[int] | None = None,
    ) -> onnx.NodeProto | None:
        """
        Find parent node based on constraints on op_type and index.
        When input_index is None, we will find the first parent node based on constraints,
        and return_indice will be appended the corresponding input index.

        Args:
            node: current node.
            parent_op_type (str): constraint of parent node op_type.
            input_index (int or None): only check the parent given input index of current node.
            output_name_to_node (dict): dictionary with output name as key, and node as value.
                                        Built from the model when None.
            exclude (list): list of nodes that are excluded (not allowed to match as parent).
                            Only read here, never modified.
            return_indice (list): a list to append the input index when input_index is None.
                                  Note that None is appended when no parent matches.

        Returns:
            parent: The matched parent node. None if not found.
        """
        assert node is not None
        assert input_index is None or input_index >= 0

        if output_name_to_node is None:
            output_name_to_node = self.model.output_name_to_node()

        if input_index is None:
            parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node, exclude)
            if return_indice is not None:
                return_indice.append(index)
            return parent

        if input_index >= len(node.input):
            # Input index out of bounds.
            return None

        parent = self.model.get_parent(node, input_index, output_name_to_node)
        if parent is not None and parent.op_type == parent_op_type and parent not in exclude:
            return parent

        return None

    def match_parent_path(
        self,
        node: onnx.NodeProto,
        parent_op_types: list[str],
        parent_input_index: list[int] | None = None,
        output_name_to_node: dict[str, onnx.NodeProto] | None = None,
        return_indice: list[int] | None = None,
    ) -> list[onnx.NodeProto] | None:
        """
        Find a sequence of input edges based on constraints on parent op_type and index.
        When parent_input_index is None, we will find the first parent node based on constraints
        for every edge, and return_indice will be appended the corresponding input index.

        Args:
            node: node at which the path starts.
            parent_op_types (list): constraint of parent node op_type of each input edge.
            parent_input_index (list): constraint of input index of each input edge.
                                       None means no constraint. When given, it must have the
                                       same length as parent_op_types.
            output_name_to_node (dict): dictionary with output name as key, and node as value.
                                        Built from the model when None.
            return_indice (list): a list to append the input index
                                  When there is no constraint on input index of an edge.

        Returns:
            parents: a list of matched parent node, nearest parent first. None if any edge
                     of the path does not match.
        """
        if parent_input_index is not None:
            assert len(parent_input_index) == len(parent_op_types)

        if output_name_to_node is None:
            output_name_to_node = self.model.output_name_to_node()

        current_node = node
        matched_parents = []
        for i, op_type in enumerate(parent_op_types):
            matched_parent = self.match_parent(
                current_node,
                op_type,
                parent_input_index[i] if parent_input_index is not None else None,
                output_name_to_node,
                exclude=[],
                return_indice=return_indice,
            )
            if matched_parent is None:
                return None

            matched_parents.append(matched_parent)
            # The next edge is matched relative to the parent that was just found.
            current_node = matched_parent

        return matched_parents

    def match_parent_paths(
        self,
        node: onnx.NodeProto,
        paths: list[tuple[list[str], list[int]]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ) -> tuple[int, list[onnx.NodeProto] | None, list[int] | None]:
        """
        Find a matching parent path to the given node.

        Args:
            node: node at which every candidate path starts.
            paths: candidates, each a (parent_op_types, parent_input_index) pair as accepted
                   by match_parent_path. They are tried in order.
            output_name_to_node (dict): dictionary with output name as key, and node as value.

        Returns:
            (index of the first matching candidate, its matched parent nodes, the input indices
            collected for it), or (-1, None, None) if no candidate matches.
        """
        for i, path in enumerate(paths):
            # A fresh list per candidate so indices of failed candidates are not kept.
            return_indice = []
            matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice)
            if matched:
                return i, matched, return_indice
        return -1, None, None

    def find_first_child_by_type(
        self,
        node: onnx.NodeProto,
        child_type: str,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]] | None = None,
        recursive: bool = True,
    ) -> onnx.NodeProto | None:
        """
        Search the consumers of ``node`` for a node with op_type ``child_type``.

        Args:
            node: node whose children are searched.
            child_type: op_type to look for.
            input_name_to_nodes: passed through to the model's ``get_children``.
            recursive: when True, children of non-matching nodes are searched as well;
                       when False, only the direct children are examined.

        Returns:
            The first matching node encountered, or None if there is none.
        """
        pending = deque(self.model.get_children(node, input_name_to_nodes))
        while pending:
            # Take from the right, add on the left: nodes queued earlier are examined
            # before the children discovered later.
            current_node = pending.pop()
            if current_node.op_type == child_type:
                return current_node

            if recursive:
                pending.extendleft(self.model.get_children(current_node, input_name_to_nodes))

        return None
|