dynamo_onnx_helper.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from collections.abc import Sequence
  6. from logging import getLogger
  7. from typing import Any
  8. import numpy as np
  9. import onnx
  10. from onnx import helper
  11. from onnx_model import OnnxModel
  12. logger = getLogger(__name__)
  13. class DynamoOnnxHelper:
  14. """
  15. Helper class for processing ONNX models exported by Torch Dynamo.
  16. """
  17. def __init__(self, model: onnx.ModelProto):
  18. self.model = OnnxModel(model)
  19. def update_edges(self, edge_mapping: dict) -> None:
  20. """
  21. Updates the edges in the model according to the given mapping.
  22. """
  23. for node in self.model.model.graph.node:
  24. for i in range(len(node.input)):
  25. if node.input[i] in edge_mapping:
  26. node.input[i] = edge_mapping[node.input[i]]
  27. for i in range(len(node.output)):
  28. if node.output[i] in edge_mapping:
  29. node.output[i] = edge_mapping[node.output[i]]
  30. for graph_input in self.model.model.graph.input:
  31. if graph_input.name in edge_mapping:
  32. graph_input.name = edge_mapping[graph_input.name]
  33. for graph_output in self.model.model.graph.output:
  34. if graph_output.name in edge_mapping:
  35. graph_output.name = edge_mapping[graph_output.name]
  36. def unroll_function(self, func_name: str) -> None:
  37. """
  38. Unrolls the function with the given name in the model.
  39. """
  40. logger.debug(f"Unrolling function {func_name}...")
  41. nodes_to_remove = []
  42. nodes_to_add = []
  43. edges_to_remove = []
  44. edges_to_add = []
  45. for node in self.model.model.graph.node:
  46. if node.op_type == func_name:
  47. nodes_to_remove.append(node)
  48. edges_to_remove.extend(list(node.input) + list(node.output))
  49. func_to_remove = None
  50. for f in self.model.model.functions:
  51. if f.name == func_name:
  52. nodes_to_add.extend(list(f.node))
  53. edges_to_add.extend(list(f.input) + list(f.output))
  54. func_to_remove = f
  55. assert len(edges_to_remove) == len(edges_to_add)
  56. for node in nodes_to_remove:
  57. self.model.model.graph.node.remove(node)
  58. for node in nodes_to_add:
  59. self.model.model.graph.node.append(node)
  60. if func_to_remove is not None:
  61. self.model.model.functions.remove(func_to_remove)
  62. edge_mapping = {}
  63. for i in range(len(edges_to_remove)):
  64. k = edges_to_remove[i]
  65. v = edges_to_add[i]
  66. if k != v:
  67. edge_mapping[k] = v
  68. return self.update_edges(edge_mapping)
  69. def remove_function(self, func_name: str, input_id: int, output_id: int) -> None:
  70. """
  71. Removes the function in the model.
  72. """
  73. edge_mapping = {}
  74. nodes_to_remove = []
  75. for node in self.model.model.graph.node:
  76. if node.op_type.find(func_name) != -1:
  77. edge_mapping[node.input[input_id]] = node.output[output_id]
  78. nodes_to_remove.append(node)
  79. for node in nodes_to_remove:
  80. self.model.model.graph.node.remove(node)
  81. self.update_edges(edge_mapping)
  82. def remove_dropout_layer(self) -> None:
  83. """
  84. Removes the dropout layer in the model.
  85. """
  86. logger.debug("Removing dropout layer...")
  87. self.remove_function("Dropout", 0, 0)
  88. def remove_lm_head_layer(self) -> None:
  89. """
  90. Removes the LM head layer in the model.
  91. """
  92. logger.debug("Removing LM head layer...")
  93. # bugbug: need to copy the right vi over
  94. self.remove_function("Linear_lm_head", 2, 0)
  95. def add_initializer(self, name: str, data_type: int, dims: Sequence[int], vals: Any, raw: bool = True):
  96. if raw:
  97. np_type = helper.tensor_dtype_to_np_dtype(data_type)
  98. if not isinstance(vals, np.ndarray):
  99. bytes = np.array(vals, dtype=np_type).tobytes()
  100. else:
  101. bytes = vals.astype(np_type).tobytes()
  102. tensor = helper.make_tensor(
  103. name=name,
  104. data_type=data_type,
  105. dims=dims,
  106. vals=bytes,
  107. raw=True,
  108. )
  109. else:
  110. tensor = helper.make_tensor(
  111. name=name,
  112. data_type=data_type,
  113. dims=dims,
  114. vals=vals,
  115. raw=False,
  116. )
  117. self.model.add_initializer(tensor)
  118. return tensor
  119. def convert_constants_to_initializers(self, min_size: int = 1) -> None:
  120. """
  121. Converts Constant ops of size [min_size] or higher to initializers
  122. """
  123. logger.debug(f"Converting constants greater than size {min_size} to initializers")
  124. constant_nodes = self.model.get_nodes_by_op_type("Constant")
  125. nodes_to_remove = []
  126. for node in constant_nodes:
  127. # Get info from Constant op
  128. np_data = self.model.get_constant_value(node.output[0])
  129. # Skip if there are less than [min_size] elements
  130. if np_data is None or np_data.size < min_size:
  131. continue
  132. # Add new initializer with same name as Constant op's output
  133. for att in node.attribute:
  134. if att.name == "value":
  135. self.add_initializer(
  136. name=node.output[0],
  137. data_type=att.t.data_type,
  138. dims=list(np_data.shape),
  139. vals=np_data,
  140. )
  141. break
  142. nodes_to_remove.append(node)
  143. # Remove Constant ops from graph
  144. self.model.remove_nodes(nodes_to_remove)
  145. def clear_metadata(self) -> None:
  146. """
  147. Clear metadata fields in all nodes
  148. """
  149. for graph in self.model.graphs():
  150. graph.ClearField("metadata_props")
  151. for node in self.model.nodes():
  152. node.ClearField("metadata_props")
  153. @staticmethod
  154. def fold_transpose_initializers(model) -> None:
  155. """
  156. Constant fold Transpose initializers without changing the initializer names
  157. """
  158. from onnxscript import ir # noqa: PLC0415
  159. for name, initializer in model.graph.initializers.items():
  160. user_nodes = initializer.consumers()
  161. if len(user_nodes) == 1 and user_nodes[0].op_type == "Transpose":
  162. transpose_node = user_nodes[0]
  163. perm = transpose_node.attributes.get("perm")
  164. if perm is None:
  165. transposed_tensor = ir.tensor(initializer.const_value.numpy().transpose())
  166. else:
  167. transposed_tensor = ir.tensor(initializer.const_value.numpy().transpose(perm.as_ints()))
  168. new_initializer = ir.Value(
  169. name=initializer.name,
  170. shape=transposed_tensor.shape,
  171. type=ir.TensorType(transposed_tensor.dtype),
  172. const_value=transposed_tensor,
  173. )
  174. ir.convenience.replace_all_uses_with(transpose_node.outputs[0], new_initializer)
  175. model.graph.initializers[name] = new_initializer
  176. transpose_node.graph.remove(transpose_node, safe=True)