add_trans_cast.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import json
  6. from argparse import ArgumentParser
  7. import onnx
  8. from onnx import TensorProto, helper
  9. def graph_topological_sort(graph):
  10. deps_count = [0] * len(graph.node) # dependency count of each node
  11. deps_to_nodes = {} # input to node indice
  12. sorted_nodes = [] # initialize sorted_nodes
  13. for node_idx, node in enumerate(graph.node):
  14. # CANNOT use len(node.input) directly because input can be optional
  15. deps_count[node_idx] = sum(1 for _ in node.input if _)
  16. if deps_count[node_idx] == 0: # Constant doesn't depend on any inputs
  17. sorted_nodes.append(graph.node[node_idx])
  18. continue
  19. for input_name in node.input:
  20. if input_name not in deps_to_nodes:
  21. deps_to_nodes[input_name] = [node_idx]
  22. else:
  23. deps_to_nodes[input_name].append(node_idx)
  24. # Note: this logic only applies to top level graph since a sub graph could use intializer from parent graph
  25. initializer_names = [init.name for init in graph.initializer]
  26. graph_input_names = [input.name for input in graph.input]
  27. input_names = initializer_names + graph_input_names
  28. input_names.sort()
  29. prev_input_name = None
  30. for input_name in input_names:
  31. if prev_input_name == input_name:
  32. continue
  33. prev_input_name = input_name
  34. if input_name in deps_to_nodes:
  35. for node_idx in deps_to_nodes[input_name]:
  36. deps_count[node_idx] = deps_count[node_idx] - 1
  37. if deps_count[node_idx] == 0:
  38. sorted_nodes.append(graph.node[node_idx])
  39. start = 0
  40. end = len(sorted_nodes)
  41. while start < end:
  42. for output in sorted_nodes[start].output:
  43. if output in deps_to_nodes:
  44. for node_idx in deps_to_nodes[output]:
  45. deps_count[node_idx] = deps_count[node_idx] - 1
  46. if deps_count[node_idx] == 0:
  47. sorted_nodes.append(graph.node[node_idx])
  48. end = end + 1
  49. start = start + 1
  50. assert end == len(graph.node), "Graph is not a DAG"
  51. graph.ClearField("node")
  52. graph.node.extend(sorted_nodes)
  53. class QnnTensorStruct:
  54. def __init__(self):
  55. self.name = ""
  56. self.onnx_data_type = TensorProto.FLOAT
  57. self.dim = []
  58. def qnn_data_type_to_onnx_data_type(qnn_data_type):
  59. # QNN_DATATYPE_UFIXED_POINT_8 QNN_DATATYPE_UINT_8
  60. if qnn_data_type == 0x0408 or qnn_data_type == 0x0108:
  61. return TensorProto.UINT8
  62. # QNN_DATATYPE_UFIXED_POINT_16 QNN_DATATYPE_UINT_16
  63. elif qnn_data_type == 0x0416 or qnn_data_type == 0x0116:
  64. return TensorProto.UINT16
  65. # QNN_DATATYPE_UFIXED_POINT_32 QNN_DATATYPE_UINT_32
  66. elif qnn_data_type == 0x0432 or qnn_data_type == 0x0132:
  67. return TensorProto.UINT32
  68. # QNN_DATATYPE_UINT_64
  69. elif qnn_data_type == 0x0164:
  70. return TensorProto.UINT64
  71. # QNN_DATATYPE_FIXED_POINT_8 QNN_DATATYPE_INT_8
  72. elif qnn_data_type == 0x0308 or qnn_data_type == 0x0008:
  73. return TensorProto.INT8
  74. # QNN_DATATYPE_FIXED_POINT_16 QNN_DATATYPE_INT_16
  75. elif qnn_data_type == 0x0316 or qnn_data_type == 0x0016:
  76. return TensorProto.INT16
  77. # QNN_DATATYPE_FIXED_POINT_32 QNN_DATATYPE_INT_32
  78. elif qnn_data_type == 0x0332 or qnn_data_type == 0x0032:
  79. return TensorProto.INT32
  80. # QNN_DATATYPE_INT_64
  81. elif qnn_data_type == 0x0064:
  82. return TensorProto.INT64
  83. # QNN_DATATYPE_FLOAT_16
  84. elif qnn_data_type == 0x0216:
  85. return TensorProto.FLOAT16
  86. # QNN_DATATYPE_FLOAT_32
  87. elif qnn_data_type == 0x0232:
  88. return TensorProto.FLOAT
  89. # QNN_DATATYPE_BOOL_8
  90. elif qnn_data_type == 0x0508:
  91. return TensorProto.BOOL
  92. else:
  93. return TensorProto.UNDEFINED
  94. def parse_qnn_json_file(qnn_json_file_path, qnn_input_output_tensor_dic):
  95. with open(qnn_json_file_path) as qnn_json_file:
  96. qnn_json = json.load(qnn_json_file)
  97. assert "graph" in qnn_json, "QNN converted json file not valid. Can't find graph."
  98. assert "tensors" in qnn_json["graph"], "QNN converted json file not valid. Can't find tensors."
  99. for qnn_tensor_name, qnn_tensor_attribute in qnn_json["graph"]["tensors"].items():
  100. # type:0 - QNN input tensor, type:1 - QNN output tensor
  101. assert (
  102. "type" in qnn_tensor_attribute
  103. and "data_type" in qnn_tensor_attribute
  104. and "dims" in qnn_tensor_attribute
  105. ), "QNN converted json file not valid. Can't find some keys from tensors"
  106. if qnn_tensor_attribute["type"] == 0 or qnn_tensor_attribute["type"] == 1:
  107. qnn_tensor = QnnTensorStruct()
  108. qnn_tensor.name = qnn_tensor_name
  109. qnn_tensor.onnx_data_type = qnn_data_type_to_onnx_data_type(qnn_tensor_attribute["data_type"])
  110. qnn_tensor.dim = qnn_tensor_attribute["dims"]
  111. qnn_input_output_tensor_dic[qnn_tensor_name] = qnn_tensor
  112. assert len(qnn_input_output_tensor_dic) > 1, (
  113. "Converted QNN model not valid. It should have at least 1 input & 1 output."
  114. )
  115. def compare_onnx_shape_with_qnn_shape(onnx_dims, qnn_dims):
  116. assert len(onnx_dims) == len(qnn_dims), "Onnx shape and Qnn shape has different rank."
  117. return all(onnx_dims[i].dim_value == qnn_dims[i] for i in range(len(onnx_dims)))
  118. def gen_to_channel_first_perm(rank):
  119. assert rank > 2, "Shape rank should >2 for the Transpose node."
  120. perm = []
  121. perm.append(0)
  122. perm.append(rank - 1)
  123. for i in range(1, rank - 1):
  124. perm.append(i) # noqa: PERF402
  125. return perm
  126. def gen_to_channel_last_perm(rank):
  127. assert rank > 2, "Shape rank should >2 for the Transpose node."
  128. perm = []
  129. perm.append(0)
  130. for i in range(2, rank):
  131. perm.append(i) # noqa: PERF402
  132. perm.append(1)
  133. return perm
# ONNX Runtime's QNN EP supports context binary files generated by the QNN toolchain. However, a QNN-generated
# context binary uses a channel-last data layout and 8-bit or 16-bit data for its inputs and outputs.
# This script reads the QNN model's input and output information from the QNN converter's model_net.json file,
# compares it with the ONNX model, and inserts Cast and Transpose nodes into the ONNX model where required.
  138. def main():
  139. parser = ArgumentParser(
  140. "Insert Cast, Transpose nodes into Onnx model to make it aligned with QNN generated context binary."
  141. )
  142. parser.add_argument("-m", "--onnx_model", help="Required. Path to Onnx model file.", required=True, type=str)
  143. parser.add_argument(
  144. "-q", "--qnn_json", help="Required. Path to Qnn converted model_net.json file.", required=True, type=str
  145. )
  146. args = parser.parse_args()
  147. # Parse Qnn model_net.json file to get the graph input output information
  148. qnn_input_output_tensor_dic = {}
  149. parse_qnn_json_file(args.qnn_json, qnn_input_output_tensor_dic)
  150. model = onnx.load(args.onnx_model)
  151. nodes_to_add = []
  152. # Tranch the tensor name change to update the consumer nodes
  153. graph_input_output_name_dic = {}
  154. for graph_input in model.graph.input:
  155. if graph_input.name in qnn_input_output_tensor_dic:
  156. input_name_fater_node_insert = graph_input.name
  157. qnn_input_tensor = qnn_input_output_tensor_dic[graph_input.name]
  158. # Insert Cast node if Onnx input and Qnn input has different data type
  159. if graph_input.type.tensor_type.elem_type != qnn_input_tensor.onnx_data_type:
  160. # Insert Cast node
  161. cast_input_name = input_name_fater_node_insert
  162. cast_output_name = cast_input_name + "_qnn_cast"
  163. input_cast_node = helper.make_node(
  164. "Cast",
  165. name=cast_output_name,
  166. inputs=[cast_input_name],
  167. outputs=[cast_output_name],
  168. to=graph_input.type.tensor_type.elem_type,
  169. )
  170. # Change input data type to Qnn input data type
  171. graph_input.type.tensor_type.elem_type = qnn_input_tensor.onnx_data_type
  172. nodes_to_add.extend([input_cast_node])
  173. input_name_fater_node_insert = cast_output_name
  174. graph_input_output_name_dic[graph_input.name] = cast_output_name
  175. if not compare_onnx_shape_with_qnn_shape(graph_input.type.tensor_type.shape.dim, qnn_input_tensor.dim):
  176. # Add Transpose node (channel last to channel first)
  177. transpose_perm = gen_to_channel_first_perm(len(graph_input.type.tensor_type.shape.dim))
  178. transpose_input_name = input_name_fater_node_insert
  179. transpose_output_name = transpose_input_name + "_qnn_trans"
  180. input_transpose_node = helper.make_node(
  181. "Transpose",
  182. name=transpose_output_name,
  183. inputs=[transpose_input_name],
  184. outputs=[transpose_output_name],
  185. perm=transpose_perm,
  186. )
  187. nodes_to_add.extend([input_transpose_node])
  188. graph_input_output_name_dic[graph_input.name] = transpose_output_name
  189. # Change input shape to Qnn input shape
  190. for i in range(len(graph_input.type.tensor_type.shape.dim)):
  191. graph_input.type.tensor_type.shape.dim[i].dim_value = qnn_input_tensor.dim[i]
  192. else:
  193. raise AssertionError("Error: Onnx model input: " + graph_input.name + " not exist from QNN model input.")
  194. for graph_output in model.graph.output:
  195. if graph_output.name in qnn_input_output_tensor_dic:
  196. output_name_after_node_insert = graph_output.name
  197. # Insert Cast node if Onnx input and Qnn input has idfferent data type
  198. qnn_output_tensor = qnn_input_output_tensor_dic[graph_output.name]
  199. if graph_output.type.tensor_type.elem_type != qnn_output_tensor.onnx_data_type:
  200. # Insert Cast node
  201. cast_output_name = output_name_after_node_insert
  202. cast_input_name = cast_output_name + "_qnn_cast"
  203. output_cast_node = helper.make_node(
  204. "Cast",
  205. name=cast_input_name,
  206. inputs=[cast_input_name],
  207. outputs=[cast_output_name],
  208. to=qnn_output_tensor.onnx_data_type,
  209. )
  210. # Change output data type to Onn output data type
  211. graph_output.type.tensor_type.elem_type = qnn_output_tensor.onnx_data_type
  212. nodes_to_add.extend([output_cast_node])
  213. output_name_after_node_insert = cast_input_name
  214. graph_input_output_name_dic[graph_output.name] = cast_input_name
  215. if not compare_onnx_shape_with_qnn_shape(graph_output.type.tensor_type.shape.dim, qnn_output_tensor.dim):
  216. # Add Transpose node (channel first to channel last)
  217. transpose_perm = gen_to_channel_last_perm(len(graph_output.type.tensor_type.shape.dim))
  218. transpose_output_name = output_name_after_node_insert
  219. transpose_input_name = transpose_output_name + "_qnn_trans"
  220. output_transpose_node = helper.make_node(
  221. "Transpose",
  222. name=transpose_input_name,
  223. inputs=[transpose_input_name],
  224. outputs=[transpose_output_name],
  225. perm=transpose_perm,
  226. )
  227. nodes_to_add.extend([output_transpose_node])
  228. graph_input_output_name_dic[graph_output.name] = transpose_input_name
  229. # Change output shape to Qnn output shape
  230. for i in range(len(graph_output.type.tensor_type.shape.dim)):
  231. graph_output.type.tensor_type.shape.dim[i].dim_value = qnn_input_output_tensor_dic[
  232. graph_output.name
  233. ].dim[i]
  234. else:
  235. raise AssertionError("Error: Onnx model output: " + graph_output.name + " not exist from QNN model output.")
  236. for node in model.graph.node:
  237. for node_input_index, node_input in enumerate(node.input):
  238. # update consumer node for graph inputs to connect to inserted node
  239. if node_input in graph_input_output_name_dic:
  240. node.input[node_input_index] = graph_input_output_name_dic[node_input]
  241. for node_output_index, node_output in enumerate(node.output):
  242. # update producer node for graph outputs to connect to inserted node
  243. if node_output in graph_input_output_name_dic:
  244. node.output[node_output_index] = graph_input_output_name_dic[node_output]
  245. model.graph.node.extend(nodes_to_add)
  246. graph_topological_sort(model.graph)
  247. # Add extra parameter all_tensors_to_one_file=False, size_threshold=5000 if the model exceeds protobuf 2GB limit e.g below
  248. # onnx.save(model, args.onnx_model.replace(".onnx", "_add_trans.onnx"), all_tensors_to_one_file=False, size_threshold=5000)
  249. onnx.save(model, args.onnx_model.replace(".onnx", "_add_trans.onnx"))
  250. if __name__ == "__main__":
  251. main()