# onnx_model.py
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
  5. from pathlib import Path
  6. import onnx
  7. import onnx.helper as onnx_helper
  8. import onnx.numpy_helper as onnx_numpy_helper
  9. from onnx.onnx_pb import ModelProto
  10. from .quant_utils import attribute_to_kwarg, find_by_name
  11. def _clean_initializers_helper(graph, model):
  12. """Clean unused initializers from graph.
  13. Returns:
  14. A cleaned graph without unused initializers
  15. A list of tensor names, which are not produced by this graph and its subgraphes
  16. """
  17. requesting_tensor_names = set()
  18. requesting_tensor_names.update(input_name for node in graph.node for input_name in node.input if input_name)
  19. requesting_tensor_names.update(g_out.name for g_out in graph.output if g_out.name)
  20. new_nodes = []
  21. for node in graph.node:
  22. new_node = node
  23. graph_attrs = [
  24. attr
  25. for attr in node.attribute
  26. if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS
  27. ]
  28. if graph_attrs:
  29. kwargs = {}
  30. for attr in node.attribute:
  31. new_attribute = {}
  32. if attr.type == onnx.AttributeProto.GRAPH:
  33. (
  34. cleaned_sub_graph,
  35. sub_requesting_tensor_names,
  36. ) = _clean_initializers_helper(attr.g, model)
  37. new_attribute = {attr.name: cleaned_sub_graph}
  38. requesting_tensor_names.update(sub_requesting_tensor_names)
  39. elif attr.type == onnx.AttributeProto.GRAPHS:
  40. cleaned_graphes = []
  41. for subgraph in attr.graphs:
  42. (
  43. cleaned_sub_graph,
  44. sub_requesting_tensor_names,
  45. ) = _clean_initializers_helper(subgraph, model)
  46. cleaned_graphes.append(cleaned_sub_graph)
  47. requesting_tensor_names.update(sub_requesting_tensor_names)
  48. new_attribute = {attr.name: cleaned_graphes}
  49. else:
  50. new_attribute = attribute_to_kwarg(attr)
  51. kwargs.update(new_attribute)
  52. new_node = onnx_helper.make_node(node.op_type, node.input, node.output, name=node.name, **kwargs)
  53. new_nodes.append(new_node)
  54. graph.ClearField("node")
  55. graph.node.extend(new_nodes)
  56. requesting_tensor_names.difference_update(output for node in graph.node for output in node.output)
  57. unused_initializer = []
  58. for initializer in graph.initializer:
  59. if initializer.name in requesting_tensor_names:
  60. requesting_tensor_names.remove(initializer.name)
  61. else:
  62. # mark it to remove, remove here directly will cause mis-behavier
  63. unused_initializer.append(initializer)
  64. name_to_input = {input.name: input for input in graph.input}
  65. for initializer in unused_initializer:
  66. graph.initializer.remove(initializer)
  67. if initializer.name in name_to_input:
  68. try:
  69. graph.input.remove(name_to_input[initializer.name])
  70. except StopIteration:
  71. if model.ir_version < 4:
  72. print(f"Warning: invalid weight name {initializer.name} found in the graph (not a graph input)")
  73. requesting_tensor_names.difference_update(input.name for input in graph.input)
  74. return graph, requesting_tensor_names
  75. class ONNXModel:
  76. def __init__(self, model: ModelProto):
  77. self.model = model
  78. def nodes(self):
  79. return self.model.graph.node
  80. def initializer(self):
  81. return self.model.graph.initializer
  82. def initializer_extend(self, inits):
  83. if len(inits) == 0:
  84. raise ValueError("Can add an empty list.")
  85. for init in self.initializer():
  86. self._check_init(init, "gain")
  87. for init in inits:
  88. self._check_init(init)
  89. self.model.graph.initializer.append(init)
  90. def graph(self):
  91. return self.model.graph
  92. def ir_version(self):
  93. return self.model.ir_version
  94. def opset_import(self):
  95. return self.model.opset_import
  96. def set_opset_import(self, domain, version):
  97. for opset in self.model.opset_import:
  98. if opset.domain == domain:
  99. opset.version = version
  100. return
  101. self.model.opset_import.extend([onnx_helper.make_opsetid(domain, version)])
  102. def remove_node(self, node):
  103. if node in self.model.graph.node:
  104. self.model.graph.node.remove(node)
  105. def remove_nodes(self, nodes_to_remove):
  106. for node in nodes_to_remove:
  107. self.remove_node(node)
  108. def add_node(self, node):
  109. self.model.graph.node.extend([self._check_node(node)])
  110. def add_nodes(self, nodes_to_add):
  111. for node in nodes_to_add:
  112. self.add_node(node)
  113. def add_initializer(self, tensor):
  114. if find_by_name(tensor.name, self.model.graph.initializer) is None:
  115. self._check_init(tensor)
  116. self.model.graph.initializer.extend([tensor])
  117. def get_initializer(self, name):
  118. for tensor in self.model.graph.initializer:
  119. if tensor.name == name:
  120. return tensor
  121. return None
  122. def find_graph_input(self, input_name):
  123. for input in self.model.graph.input:
  124. if input.name == input_name:
  125. return input
  126. return None
  127. def find_graph_output(self, output_name):
  128. for output in self.model.graph.output:
  129. if output.name == output_name:
  130. return output
  131. return None
  132. def get_tensor_type(self, tensor_name: str):
  133. tensor_type_map = {obj.name: obj.type for obj in self.model.graph.value_info}
  134. if tensor_name in tensor_type_map:
  135. return tensor_type_map[tensor_name].tensor_type
  136. g_input = self.find_graph_input(tensor_name)
  137. if g_input:
  138. return g_input.type.tensor_type
  139. g_output = self.find_graph_output(tensor_name)
  140. if g_output:
  141. return g_output.type.tensor_type
  142. return None
  143. def get_constant_value(self, output_name):
  144. for node in self.model.graph.node:
  145. if node.op_type == "Constant":
  146. if node.output[0] == output_name:
  147. for attr in node.attribute:
  148. if attr.name == "value":
  149. return onnx_numpy_helper.to_array(attr.t)
  150. # Fallback to initializer since constant folding may have been applied.
  151. initializer = self.get_initializer(output_name)
  152. if initializer is not None:
  153. return onnx_numpy_helper.to_array(initializer)
  154. return None
  155. def get_initializer_name_set(self):
  156. return {initializer.name for initializer in self.model.graph.initializer}
  157. def remove_initializer(self, tensor):
  158. if tensor in self.model.graph.initializer:
  159. self.model.graph.initializer.remove(tensor)
  160. for input in self.model.graph.input:
  161. if input.name == tensor.name:
  162. self.model.graph.input.remove(input)
  163. break
  164. def remove_initializers(self, init_to_remove):
  165. for initializer in init_to_remove:
  166. self.remove_initializer(initializer)
  167. def get_non_initializer_inputs(self):
  168. initializer_names = self.get_initializer_name_set()
  169. non_initializer_inputs = set()
  170. for input in self.model.graph.input:
  171. if input.name not in initializer_names:
  172. non_initializer_inputs.add(input.name)
  173. return non_initializer_inputs
  174. def input_name_to_nodes(self):
  175. input_name_to_nodes = {}
  176. for node in self.model.graph.node:
  177. for input_name in node.input:
  178. if input_name: # Could be empty when it is optional
  179. if input_name not in input_name_to_nodes:
  180. input_name_to_nodes[input_name] = [node]
  181. else:
  182. input_name_to_nodes[input_name].append(node)
  183. return input_name_to_nodes
  184. def output_name_to_node(self):
  185. output_name_to_node = {}
  186. for node in self.model.graph.node:
  187. for output_name in node.output:
  188. if output_name: # Could be empty when it is optional
  189. output_name_to_node[output_name] = node
  190. return output_name_to_node
  191. def get_children(self, node, input_name_to_nodes=None):
  192. if input_name_to_nodes is None:
  193. input_name_to_nodes = self.input_name_to_nodes()
  194. children = []
  195. for output in node.output:
  196. if output in input_name_to_nodes:
  197. for node in input_name_to_nodes[output]:
  198. children.append(node) # noqa: PERF402
  199. return children
  200. def get_parents(self, node, output_name_to_node=None):
  201. if output_name_to_node is None:
  202. output_name_to_node = self.output_name_to_node()
  203. parents = []
  204. for input in node.input:
  205. if input in output_name_to_node:
  206. parents.append(output_name_to_node[input])
  207. return parents
  208. def get_parent(self, node, idx, output_name_to_node=None):
  209. if output_name_to_node is None:
  210. output_name_to_node = self.output_name_to_node()
  211. if len(node.input) <= idx:
  212. return None
  213. input = node.input[idx]
  214. if input not in output_name_to_node:
  215. return None
  216. return output_name_to_node[input]
  217. def find_node_by_name(self, node_name, new_nodes_list, graph):
  218. """Find out if a node exists in a graph or a node is in the
  219. new set of nodes created during quantization.
  220. Returns:
  221. The node found or None.
  222. """
  223. graph_nodes_list = list(graph.node) # deep copy
  224. graph_nodes_list.extend(new_nodes_list)
  225. node = find_by_name(node_name, graph_nodes_list)
  226. return node
  227. def get_largest_node_name_suffix(self, node_name_prefix):
  228. """
  229. Gets the largest node name (int) suffix for all node names that begin with `node_name_prefix`.
  230. Example: for nodes my_prefix_0 and my_prefix_3, this method returns 3.
  231. """
  232. suffix = -1
  233. for node in self.model.graph.node:
  234. if node.name and node.name.startswith(node_name_prefix):
  235. try:
  236. index = int(node.name[len(node_name_prefix) :])
  237. suffix = max(index, suffix)
  238. except ValueError:
  239. continue
  240. return suffix
  241. def get_largest_initializer_name_suffix(self, initializer_name_prefix):
  242. """
  243. Gets the largest initializer name integer suffix for all initializer names that begin
  244. with `initializer_name_prefix`. This can be used to create unique initializer names.
  245. Example: for initializer names 'my_weight_0' and 'my_weight_3', this method returns 3 if
  246. `initializer_name_prefix` is 'my_weight_'.
  247. """
  248. suffix = -1
  249. for initializer in self.model.graph.initializer:
  250. if initializer.name.startswith(initializer_name_prefix):
  251. try:
  252. index = int(initializer.name[len(initializer_name_prefix) :])
  253. suffix = max(index, suffix)
  254. except ValueError:
  255. continue
  256. return suffix
  257. def find_nodes_by_initializer(self, graph, initializer):
  258. """
  259. Find all nodes with given initializer as an input.
  260. """
  261. nodes = []
  262. for node in graph.node:
  263. for node_input in node.input:
  264. if node_input == initializer.name:
  265. nodes.append(node)
  266. return nodes
  267. @staticmethod
  268. def __get_initializer(name, graph_path):
  269. for gid in range(len(graph_path) - 1, -1, -1):
  270. graph = graph_path[gid]
  271. for tensor in graph.initializer:
  272. if tensor.name == name:
  273. return tensor, graph
  274. return None, None
  275. @staticmethod
  276. def __replace_gemm_with_matmul(graph_path):
  277. new_nodes = []
  278. graph = graph_path[-1]
  279. for node in graph.node:
  280. graph_attrs = [attr for attr in node.attribute if attr.type == 5 or attr.type == 10]
  281. if graph_attrs:
  282. kwargs = {}
  283. for attr in node.attribute:
  284. if attr.type == 5:
  285. graph_path.append(attr.g)
  286. kv = {attr.name: ONNXModel.__replace_gemm_with_matmul(graph_path)}
  287. elif attr.type == 10:
  288. value = []
  289. for subgraph in attr.graphs:
  290. graph_path.append(subgraph)
  291. value.extend([ONNXModel.__replace_gemm_with_matmul(graph_path)])
  292. kv = {attr.name: value}
  293. else:
  294. kv = attribute_to_kwarg(attr)
  295. kwargs.update(kv)
  296. node = onnx_helper.make_node( # noqa: PLW2901
  297. node.op_type, node.input, node.output, name=node.name, **kwargs
  298. )
  299. if node.op_type == "Gemm":
  300. alpha = 1.0
  301. beta = 1.0
  302. transA = 0 # noqa: N806
  303. transB = 0 # noqa: N806
  304. for attr in node.attribute:
  305. if attr.name == "alpha":
  306. alpha = onnx_helper.get_attribute_value(attr)
  307. elif attr.name == "beta":
  308. beta = onnx_helper.get_attribute_value(attr)
  309. elif attr.name == "transA":
  310. transA = onnx_helper.get_attribute_value(attr) # noqa: N806
  311. elif attr.name == "transB":
  312. transB = onnx_helper.get_attribute_value(attr) # noqa: N806
  313. if alpha == 1.0 and beta == 1.0 and transA == 0:
  314. inputB = node.input[1] # noqa: N806
  315. if transB == 1:
  316. B, Bs_graph = ONNXModel.__get_initializer(node.input[1], graph_path) # noqa: N806
  317. if B:
  318. # assume B is not used by any other node
  319. B_array = onnx_numpy_helper.to_array(B) # noqa: N806
  320. B_trans = onnx_numpy_helper.from_array(B_array.T) # noqa: N806
  321. B_trans.name = B.name
  322. Bs_graph.initializer.remove(B)
  323. for input in Bs_graph.input:
  324. if input.name == inputB:
  325. Bs_graph.input.remove(input)
  326. break
  327. Bs_graph.initializer.extend([B_trans])
  328. else:
  329. inputB += "_Transposed" # noqa: N806
  330. transpose_node = onnx_helper.make_node(
  331. "Transpose",
  332. inputs=[node.input[1]],
  333. outputs=[inputB],
  334. name=node.name + "_Transpose" if node.name else "",
  335. )
  336. new_nodes.append(transpose_node)
  337. matmul_node = onnx_helper.make_node(
  338. "MatMul",
  339. inputs=[node.input[0], inputB],
  340. outputs=[node.output[0] + ("_MatMul" if len(node.input) > 2 else "")],
  341. name=node.name + "_MatMul" if node.name else "",
  342. )
  343. new_nodes.append(matmul_node)
  344. if len(node.input) > 2:
  345. add_node = onnx_helper.make_node(
  346. "Add",
  347. inputs=[node.output[0] + "_MatMul", node.input[2]],
  348. outputs=node.output,
  349. name=node.name + "_Add" if node.name else "",
  350. )
  351. new_nodes.append(add_node)
  352. # unsupported
  353. else:
  354. new_nodes.append(node)
  355. # not GEMM
  356. else:
  357. new_nodes.append(node)
  358. graph.ClearField("node")
  359. graph.node.extend(new_nodes)
  360. graph_path.pop()
  361. return graph
  362. def replace_gemm_with_matmul(self):
  363. graph_path = [self.graph()]
  364. ONNXModel.__replace_gemm_with_matmul(graph_path)
  365. def save_model_to_file(self, output_path, use_external_data_format=False):
  366. """
  367. Save model to external data, which is needed for model size > 2GB
  368. """
  369. self.topological_sort()
  370. if use_external_data_format:
  371. onnx.external_data_helper.convert_model_to_external_data(
  372. self.model,
  373. all_tensors_to_one_file=True,
  374. location=Path(output_path).name + ".data",
  375. convert_attribute=True,
  376. )
  377. for init in self.model.graph.initializer:
  378. self._check_init(init, "end")
  379. onnx.save_model(self.model, output_path)
  380. @staticmethod
  381. def replace_node_input(node, old_input_name, new_input_name):
  382. assert isinstance(old_input_name, str) and isinstance(new_input_name, str)
  383. for j in range(len(node.input)):
  384. if node.input[j] == old_input_name:
  385. node.input[j] = new_input_name
  386. def replace_input_of_all_nodes(self, old_input_name, new_input_name):
  387. for node in self.model.graph.node:
  388. ONNXModel.replace_node_input(node, old_input_name, new_input_name)
  389. def replace_input_of_nodes(self, old_input_name, new_input_name, node_names_set):
  390. for node in self.model.graph.node:
  391. if node.name in node_names_set:
  392. ONNXModel.replace_node_input(node, old_input_name, new_input_name)
  393. @staticmethod
  394. def replace_node_output(node, old_output_name, new_output_name):
  395. assert isinstance(old_output_name, str) and isinstance(new_output_name, str)
  396. for j in range(len(node.output)):
  397. if node.output[j] == old_output_name:
  398. node.output[j] = new_output_name
  399. def replace_output_of_all_nodes(self, old_output_name, new_output_name):
  400. for node in self.model.graph.node:
  401. ONNXModel.replace_node_output(node, old_output_name, new_output_name)
  402. def replace_output_of_nodes(self, old_output_name, new_output_name, node_names_set):
  403. for node in self.model.graph.node:
  404. if node.name in node_names_set:
  405. ONNXModel.replace_node_output(node, old_output_name, new_output_name)
  406. def remove_unused_constant(self):
  407. input_name_to_nodes = self.input_name_to_nodes()
  408. # remove unused constant
  409. unused_nodes = []
  410. nodes = self.nodes()
  411. for node in nodes:
  412. if (
  413. node.op_type == "Constant"
  414. and not self.is_graph_output(node.output[0])
  415. and node.output[0] not in input_name_to_nodes
  416. ):
  417. unused_nodes.append(node)
  418. self.remove_nodes(unused_nodes)
  419. ununsed_weights = []
  420. for w in self.initializer():
  421. if w.name not in input_name_to_nodes and not self.is_graph_output(w.name):
  422. ununsed_weights.append(w)
  423. # Remove from graph.input
  424. for graph_input in self.graph().input:
  425. if graph_input.name == w.name:
  426. self.graph().input.remove(graph_input)
  427. self.remove_initializers(ununsed_weights)
  428. def is_graph_output(self, output_name):
  429. return any(output.name == output_name for output in self.model.graph.output)
  430. def is_graph_input(self, tensor_name: str) -> bool:
  431. return any(input.name == tensor_name for input in self.model.graph.input)
  432. # TODO:use OnnxModel.graph_topological_sort(self.model.graph) from transformers.onnx_model
  433. # Currently it breaks Openvino/Linux training gpu pipeline so hold off for 1.8 release
  434. def topological_sort(self):
  435. deps_count = [0] * len(self.nodes()) # dependency count of each node
  436. deps_to_nodes = {} # input to node indice
  437. sorted_nodes = [] # initialize sorted_nodes
  438. for node_idx, node in enumerate(self.nodes()):
  439. # CANNOT use len(node.input) directly because input can be optional
  440. deps_count[node_idx] = sum(1 for _ in node.input if _)
  441. if deps_count[node_idx] == 0: # Constant doesn't depend on any inputs
  442. sorted_nodes.append(self.nodes()[node_idx])
  443. continue
  444. for input_name in node.input:
  445. if not input_name:
  446. continue
  447. if input_name not in deps_to_nodes:
  448. deps_to_nodes[input_name] = [node_idx]
  449. else:
  450. deps_to_nodes[input_name].append(node_idx)
  451. initializer_names = [init.name for init in self.initializer()]
  452. graph_input_names = [input.name for input in self.model.graph.input]
  453. input_names = initializer_names + graph_input_names
  454. input_names.sort()
  455. prev_input_name = None
  456. for input_name in input_names:
  457. if prev_input_name == input_name:
  458. continue
  459. prev_input_name = input_name
  460. if input_name in deps_to_nodes:
  461. for node_idx in deps_to_nodes[input_name]:
  462. deps_count[node_idx] = deps_count[node_idx] - 1
  463. if deps_count[node_idx] == 0:
  464. sorted_nodes.append(self.nodes()[node_idx])
  465. start = 0
  466. end = len(sorted_nodes)
  467. while start < end:
  468. for output in sorted_nodes[start].output:
  469. if output in deps_to_nodes:
  470. for node_idx in deps_to_nodes[output]:
  471. deps_count[node_idx] = deps_count[node_idx] - 1
  472. if deps_count[node_idx] == 0:
  473. sorted_nodes.append(self.nodes()[node_idx])
  474. end = end + 1
  475. start = start + 1
  476. assert end == len(self.graph().node), "Graph is not a DAG"
  477. self.graph().ClearField("node")
  478. self.graph().node.extend(sorted_nodes)
  479. def clean_initializers(self):
  480. return _clean_initializers_helper(self.graph(), self.model)
  481. def _check_init(self, init, test=None):
  482. if init.data_type == onnx.TensorProto.FLOAT8E4M3FN:
  483. if init.HasField("raw_data"):
  484. b = list(init.raw_data)
  485. if any((i & 127) == 127 for i in b):
  486. raise ValueError(f"Initializer {init.name!r} has nan.")
  487. return init
  488. def _check_node(self, node):
  489. """
  490. A quantization to float 8 does not use quantized bias but float 16 bias.
  491. This function checks that DequantizeLinear is not used to
  492. dequantize from float 16.
  493. """
  494. if node.op_type == "DequantizeLinear":
  495. zero_point = node.input[2]
  496. init = self.get_initializer(zero_point)
  497. dtype = init.data_type
  498. if dtype in {
  499. onnx.TensorProto.FLOAT16,
  500. onnx.TensorProto.FLOAT,
  501. onnx.TensorProto.DOUBLE,
  502. onnx.TensorProto.BFLOAT16,
  503. }:
  504. raise RuntimeError(f"Unsupported DequantizeLinear operator, dequantization from {dtype}.")
  505. return node