fusion_utils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. import numpy
  7. from numpy import array_equal, ndarray
  8. from onnx import NodeProto, TensorProto, helper, numpy_helper
  9. from onnx_model import OnnxModel
  10. logger = getLogger(__name__)
  11. class FusionUtils:
  12. def __init__(self, model: OnnxModel):
  13. self.model: OnnxModel = model
  14. def cast_graph_input_to_int32(self, input_name: str) -> tuple[bool, str]:
  15. graph_input = self.model.find_graph_input(input_name)
  16. if graph_input is not None and graph_input.type.tensor_type.elem_type != TensorProto.INT32:
  17. cast_output, cast_node = self.cast_input_to_int32(input_name)
  18. logger.debug(f"Casted graph input {input_name} to int32")
  19. return True, cast_output
  20. logger.debug(f"Did not cast graph input {input_name} to int32: found {graph_input is not None}")
  21. return False, input_name
  22. def cast_input(self, input_name: str, target_type="int32"):
  23. output_name = input_name + "_" + target_type
  24. if target_type == "int32":
  25. to_type = int(TensorProto.INT32)
  26. elif target_type == "float32":
  27. to_type = int(TensorProto.FLOAT)
  28. elif target_type == "float16":
  29. to_type = int(TensorProto.FLOAT16)
  30. else:
  31. raise ValueError("Invalid target_type: {target_type}")
  32. cast_node = self.add_cast_node(input_name, to_type, output_name)
  33. return output_name, cast_node
  34. def add_cast_node(
  35. self,
  36. input_name: str,
  37. to_type: int,
  38. output_name: str | None = None,
  39. output_name_to_node=None,
  40. graph_name: str | None = None,
  41. ):
  42. if output_name is None:
  43. output_name = input_name + f"_cast_to_{to_type}"
  44. # Avoid consequent Cast nodes.
  45. inputs = [input_name]
  46. if output_name_to_node is None:
  47. output_name_to_node = self.model.output_name_to_node()
  48. if input_name in output_name_to_node:
  49. parent_node = output_name_to_node[input_name]
  50. if parent_node and parent_node.op_type == "Cast":
  51. inputs = [parent_node.input[0]]
  52. cast_node = helper.make_node("Cast", inputs=inputs, outputs=[output_name])
  53. cast_node.attribute.extend([helper.make_attribute("to", to_type)])
  54. self.model.add_node(cast_node, graph_name=graph_name)
  55. return cast_node
  56. def cast_input_to_int32(self, input_name: str):
  57. return self.cast_input(input_name, "int32")
  58. def remove_cast_int32(self, input_name: str):
  59. input_name_to_nodes = self.model.input_name_to_nodes()
  60. nodes = input_name_to_nodes[input_name]
  61. for node in nodes:
  62. if node.op_type == "Cast":
  63. is_int32 = False
  64. for att in node.attribute:
  65. if att.name == "to" and att.i == int(TensorProto.INT32):
  66. is_int32 = True
  67. break
  68. if is_int32:
  69. output_name = node.output[0]
  70. self.model.remove_node(node)
  71. self.model.replace_input_of_all_nodes(output_name, input_name)
  72. @staticmethod
  73. def update_node_input(node, i, new_input_name, input_name_to_nodes):
  74. old_input_reference = 0
  75. if (node.input[i] in input_name_to_nodes) and node in input_name_to_nodes[node.input[i]]:
  76. input_name_to_nodes[node.input[i]].remove(node)
  77. old_input_reference = len(input_name_to_nodes[node.input[i]])
  78. node.input[i] = new_input_name
  79. if new_input_name in input_name_to_nodes:
  80. input_name_to_nodes[new_input_name].append(node)
  81. else:
  82. input_name_to_nodes[new_input_name] = [node]
  83. return old_input_reference
  84. @staticmethod
  85. def skip_parent(model: OnnxModel, node, parent_node, input_name_to_nodes, node_input_index=0, parent_input_index=0):
  86. """
  87. Before:
  88. (input)-->parent-->node-->(output)
  89. After:
  90. (input)-->parent-->
  91. |
  92. +----->node-->(output)
  93. This function returns a flag whether the parent node can be removed.
  94. """
  95. old_input_name = node.input[node_input_index]
  96. new_input_name = parent_node.input[parent_input_index]
  97. old_input_reference = FusionUtils.update_node_input(node, node_input_index, new_input_name, input_name_to_nodes)
  98. # We can remove the first Transpose if its output is not used (linked to graph output or other nodes) anymore.
  99. parent_can_be_removed = (old_input_reference == 0) and not model.find_graph_output(old_input_name)
  100. return parent_can_be_removed
  101. def get_squeeze_or_unsqueeze_axes(self, node: NodeProto) -> ndarray | None:
  102. assert node.op_type in ["Squeeze", "Unsqueeze"]
  103. # For opset >= 13, axes is an input instead of an attribute.
  104. if len(node.input) > 1:
  105. return self.model.get_constant_value(node.input[1])
  106. axes = None
  107. for attr in node.attribute:
  108. if attr.name == "axes":
  109. axes = helper.get_attribute_value(attr)
  110. return axes
  111. @staticmethod
  112. def check_node_attribute(node, attribute_name: str, expected_value, default_value=None):
  113. """Verify that a node has expected value for an attribute.
  114. Args:
  115. node (NodeProto): a node to check
  116. attribute_name (str): name of attribute
  117. expected_value (Any): expected value of the attribute
  118. default_value (Any, optional): default value if the attribute does not exist. Defaults to None.
  119. Returns:
  120. bool: whether the check is passed or not
  121. """
  122. value = default_value
  123. for attr in node.attribute:
  124. if attr.name == attribute_name:
  125. value = helper.get_attribute_value(attr)
  126. if isinstance(expected_value, list):
  127. return (isinstance(value, (ndarray, list))) and array_equal(expected_value, value, equal_nan=False)
  128. else:
  129. return value == expected_value
  130. @staticmethod
  131. def transpose_2d_int8_tensor(tensor: TensorProto):
  132. """Transpose a 2-D INT8 TensorProto
  133. Args:
  134. tensor (TensorProto): tensor to be transposed
  135. Returns:
  136. tensor (TensorProto): transposed tensor
  137. """
  138. if not isinstance(tensor, TensorProto):
  139. raise TypeError(f"Expected input type is an ONNX TensorProto but got {type(tensor)}")
  140. if len(tensor.dims) != 2 or tensor.data_type != TensorProto.INT8:
  141. raise ValueError("Only INT8 2-D tensors can be transposed")
  142. if tensor.raw_data:
  143. int32_data = numpy.reshape(numpy.frombuffer(tensor.raw_data, dtype="int8"), tensor.dims)
  144. int32_transposed_data = numpy.transpose(int32_data, [1, 0])
  145. tensor.raw_data = int32_transposed_data.tobytes()
  146. else:
  147. raise ValueError("only raw buffer supported")
  148. return tensor
  149. @staticmethod
  150. def check_qdq_node_for_fusion(node: NodeProto, model: OnnxModel, allow_per_tensor_quantization_only=True):
  151. """Verify if a provided QuantizeLinear (Q) / DequantizeLinear (DQ) node is a good candidate for fusion.
  152. It is a good candidate for fusion if:
  153. (1) The Q/DQ node is for per-tensor quantization if allow_per_tensor_quantization_only is `True`
  154. (2) The Q/DQ node should have constant scale
  155. (3) The Q/DQ node should have a zero point of 0
  156. Args:
  157. node (NodeProto): a Q/DQ node to check
  158. Returns:
  159. bool: whether the check is passed or not
  160. """
  161. if node.op_type not in {"QuantizeLinear", "DequantizeLinear"}:
  162. logger.debug(f"Provided node is not a Q/DQ node. Op Type: {node.op_type}")
  163. scale = model.get_constant_value(node.input[1])
  164. # Scale is not constant
  165. if scale is None:
  166. return False
  167. # Not per-tensor quantization
  168. scale_has_single_element = scale.ndim == 0 or (scale.ndim == 1 and scale.shape[0] == 1)
  169. if allow_per_tensor_quantization_only and not scale_has_single_element:
  170. return False
  171. # If the Q/DQ node has no zero point input, it is assumed to be 0 (per ONNX spec)
  172. if len(node.input) == 2:
  173. return True
  174. # Zero point should be constant and should have a value of 0
  175. zero_point = model.get_constant_value(node.input[2])
  176. # Zero point and scale should have same number of dims
  177. if scale.ndim != zero_point.ndim:
  178. return False
  179. # Zero point is not constant or zero point is not zero
  180. if zero_point is None:
  181. return False
  182. return numpy.all(zero_point == 0)
  183. def check_node_input_value(self, node, input_index: int, expected_value):
  184. """Verify that a node has expected input value
  185. Args:
  186. node (NodeProto): a node to check
  187. input_index (int): index of its input to be verified
  188. expected_value (Any): expected value of the input
  189. Returns:
  190. bool: whether the check is passed or not
  191. """
  192. assert len(node.input) > input_index
  193. value = self.model.get_constant_value(node.input[input_index])
  194. if isinstance(expected_value, list):
  195. return (isinstance(value, (ndarray, list))) and array_equal(expected_value, value, equal_nan=False)
  196. else:
  197. return value == expected_value
  198. def remove_identity_nodes(self):
  199. """Remove Identity nodes, except those right before graph output."""
  200. nodes_to_remove = []
  201. graph_output_names = self.model.get_graphs_output_names()
  202. for node in self.model.nodes():
  203. if node.op_type == "Identity":
  204. if node.output[0] not in graph_output_names:
  205. self.model.replace_input_of_all_nodes(node.output[0], node.input[0])
  206. nodes_to_remove.append(node)
  207. if nodes_to_remove:
  208. self.model.remove_nodes(nodes_to_remove)
  209. logger.info(f"Removed {len(nodes_to_remove)} Identity nodes")
  210. def remove_cascaded_cast_nodes(self):
  211. self.model.remove_cascaded_cast_nodes()
  212. def remove_useless_cast_nodes(self):
  213. self.model.remove_useless_cast_nodes()
  214. def remove_useless_reshape_nodes(self):
  215. """Remove reshape node that is not needed based on symbolic shape inference: input and output has same shape"""
  216. shape_infer = self.model.infer_runtime_shape(update=True)
  217. if shape_infer is None:
  218. return
  219. nodes_to_remove = []
  220. for node in self.model.nodes():
  221. if node.op_type == "Reshape":
  222. input_shape = shape_infer.get_edge_shape(node.input[0])
  223. output_shape = shape_infer.get_edge_shape(node.output[0])
  224. if input_shape and output_shape and input_shape == output_shape:
  225. logger.info(
  226. f"Remove reshape node {node.name} since its input shape is same as output: {input_shape}"
  227. )
  228. nodes_to_remove.append(node)
  229. if nodes_to_remove:
  230. graph_input_names = set(self.model.get_graphs_input_names())
  231. graph_output_names = set(self.model.get_graphs_output_names())
  232. for node in nodes_to_remove:
  233. if bool(set(node.output) & graph_output_names):
  234. if (
  235. not bool(set(node.input) & graph_input_names)
  236. and len(self.model.input_name_to_nodes()[node.input[0]]) == 1 # parent has only one child
  237. ):
  238. self.model.replace_output_of_all_nodes(node.input[0], node.output[0])
  239. else:
  240. continue
  241. else:
  242. self.model.replace_input_of_all_nodes(node.output[0], node.input[0])
  243. self.model.remove_node(node)
  244. class NumpyHelper:
  245. @staticmethod
  246. def to_array(tensor: TensorProto, fill_zeros: bool = False) -> ndarray:
  247. # When weights are in external data format but not presented, we can still test the optimizer with two changes:
  248. # (1) set fill_zeros = True (2) change load_external_data=False in optimizer.py
  249. if fill_zeros:
  250. return ndarray(
  251. shape=tensor.dims,
  252. dtype=helper.tensor_dtype_to_np_dtype(tensor.data_type),
  253. )
  254. if tensor.data_type == TensorProto.BFLOAT16:
  255. import onnx_ir as ir # noqa: PLC0415
  256. # Use onnx_ir to correctly handle bfloat16 tensors
  257. return ir.from_proto(tensor).numpy()
  258. return numpy_helper.to_array(tensor)