shape_inference.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. # --------------------------------------------------------------------------
  2. # Copyright (c) Microsoft, Intel Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. import logging
  7. import tempfile
  8. import traceback
  9. from pathlib import Path
  10. import onnx
  11. import onnxruntime
  12. from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
  13. from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data
  14. from .fusions import ReplaceUpsampleWithResize
  15. from .onnx_model import ONNXModel
  16. from .quant_utils import add_pre_process_metadata, save_and_reload_model_with_shape_infer
  17. logger = logging.getLogger(__name__)
  18. def quant_pre_process(
  19. input_model: str | Path | onnx.ModelProto | None = None,
  20. output_model_path: str | Path | None = None,
  21. skip_optimization: bool = False,
  22. skip_onnx_shape: bool = False,
  23. skip_symbolic_shape: bool = False,
  24. auto_merge: bool = False,
  25. int_max: int = 2**31 - 1,
  26. guess_output_rank: bool = False,
  27. verbose: int = 0,
  28. save_as_external_data: bool = False,
  29. all_tensors_to_one_file: bool = False,
  30. external_data_location: str | None = None,
  31. external_data_size_threshold: int = 1024,
  32. **deprecated_kwargs,
  33. ) -> None:
  34. """Shape inference and model optimization, in preparation for quantization.
  35. Args:
  36. input_model: Path to the input model file or ModelProto
  37. output_model_path: Path to the output model file
  38. skip_optimization: Skip model optimization step if true. This may result in ONNX shape
  39. inference failure for some models.
  40. skip_onnx_shape: Skip ONNX shape inference. Symbolic shape inference is most effective
  41. with transformer based models. Skipping all shape inferences may
  42. reduce the effectiveness of quantization, as a tensor with unknown
  43. shape can not be quantized.
  44. skip_symbolic_shape: Skip symbolic shape inference. Symbolic shape inference is most
  45. effective with transformer based models. Skipping all shape
  46. inferences may reduce the effectiveness of quantization, as a tensor
  47. with unknown shape can not be quantized.
  48. auto_merge: For symbolic shape inference, automatically merge symbolic dims when
  49. conflict happens.
  50. int_max: For symbolic shape inference, specify the maximum value for integer to be
  51. treated as boundless for ops like slice
  52. guess_output_rank: Guess output rank to be the same as input 0 for unknown ops
  53. verbose: Logs detailed info of inference, 0: turn off, 1: warnings, 3: detailed
  54. save_as_external_data: Saving an ONNX model to external data
  55. all_tensors_to_one_file: Saving all the external data to one file
  56. external_data_location: The file location to save the external file
  57. external_data_size_threshold: The size threshold for external data
  58. """
  59. if input_model is None:
  60. input_model = deprecated_kwargs.pop("input_model_path", None)
  61. assert input_model is not None
  62. assert output_model_path is not None, "output_model_path is required."
  63. with tempfile.TemporaryDirectory(prefix="pre.quant.") as quant_tmp_dir:
  64. temp_path = Path(quant_tmp_dir)
  65. model = input_model if isinstance(input_model, onnx.ModelProto) else onnx.load(input_model)
  66. # Since Upsample is deprecated after opset v10, and the model's opset will
  67. # be upgraded to at least v11 during quantization, we need to replace Upsample
  68. # with Resize first to avoid generating an invalid model.
  69. ai_onnx_domain = [opset for opset in model.opset_import if not opset.domain or opset.domain == "ai.onnx"]
  70. if len(ai_onnx_domain) == 1:
  71. opset_version = ai_onnx_domain[0].version
  72. if opset_version <= 10:
  73. ReplaceUpsampleWithResize(ONNXModel(model), opset_version).apply()
  74. model = onnx.version_converter.convert_version(model, 11)
  75. model = save_and_reload_model_with_shape_infer(model)
  76. if not skip_symbolic_shape:
  77. logger.info("Performing symbolic shape inference...")
  78. model = SymbolicShapeInference.infer_shapes(
  79. model,
  80. int_max,
  81. auto_merge,
  82. guess_output_rank,
  83. verbose,
  84. )
  85. if not skip_optimization:
  86. # Use ORT optimizers (native code) to optimize model
  87. if not skip_symbolic_shape:
  88. # Need to save the inferenced model to file so as to run the optimizer
  89. input_model = str(temp_path / "symbolic_shape_inferred.onnx")
  90. if save_as_external_data:
  91. onnx.save_model(
  92. model,
  93. input_model,
  94. save_as_external_data=True,
  95. all_tensors_to_one_file=all_tensors_to_one_file,
  96. size_threshold=external_data_size_threshold,
  97. convert_attribute=False,
  98. )
  99. else:
  100. onnx.save(model, input_model)
  101. model = None
  102. opt_model_path = str(temp_path / "optimized.onnx")
  103. try:
  104. sess_option = onnxruntime.SessionOptions()
  105. sess_option.optimized_model_filepath = opt_model_path
  106. sess_option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
  107. # For large model, extract external data from model and add to session options
  108. if isinstance(input_model, onnx.ModelProto):
  109. if has_external_data(input_model):
  110. raise ValueError(
  111. "ModelProto has external data not loaded into memory, ORT cannot create session. "
  112. "Please load external data before calling this function. "
  113. "See https://onnx.ai/onnx/repo-docs/ExternalData.html for more information."
  114. )
  115. external_names, external_values = extract_raw_data_from_model(input_model)
  116. sess_option.add_external_initializers(list(external_names), list(external_values))
  117. input_model = input_model.SerializeToString()
  118. # the saved optimized model otherwise points to the original external data file name
  119. # which is not available relative to the optimized model file
  120. elif skip_symbolic_shape and save_as_external_data:
  121. sess_option.add_session_config_entry(
  122. "session.optimized_model_external_initializers_file_name", "optimized.onnx.data"
  123. )
  124. sess = onnxruntime.InferenceSession(input_model, sess_option, providers=["CPUExecutionProvider"])
  125. # Close the session to avoid the cleanup error on Windows for temp folders
  126. # https://github.com/microsoft/onnxruntime/issues/17627
  127. del sess
  128. except Exception:
  129. logger.error(
  130. "ONNX Runtime Model Optimization Failed! Consider rerun with option `--skip_optimization'."
  131. )
  132. logger.error(traceback.format_exc())
  133. input_model = opt_model_path
  134. if not skip_onnx_shape:
  135. # ONNX shape inference.
  136. # According to docs, infer_shapes_path should be used for 2G+ models.
  137. # If the skip optimization is specified, we could be dealing with a
  138. # large model. So be on the safe side, save the model
  139. if model is not None:
  140. input_model = str(temp_path / "symbolic_shape_inferred.onnx")
  141. if save_as_external_data:
  142. onnx.save_model(
  143. model,
  144. input_model,
  145. save_as_external_data=True,
  146. all_tensors_to_one_file=all_tensors_to_one_file,
  147. size_threshold=external_data_size_threshold,
  148. convert_attribute=False,
  149. )
  150. else:
  151. onnx.save(model, input_model)
  152. model = None
  153. if isinstance(input_model, onnx.ModelProto):
  154. input_model = str(Path(quant_tmp_dir) / "model_input.onnx")
  155. onnx.save_model(
  156. model,
  157. input_model,
  158. save_as_external_data=True,
  159. all_tensors_to_one_file=all_tensors_to_one_file,
  160. size_threshold=external_data_size_threshold,
  161. convert_attribute=False,
  162. )
  163. inferred_model_path = str(temp_path / "onnx_shape_inferred.onnx")
  164. onnx.shape_inference.infer_shapes_path(input_model, inferred_model_path)
  165. model = onnx.load(inferred_model_path)
  166. if model is None:
  167. model = input_model if isinstance(input_model, onnx.ModelProto) else onnx.load(input_model)
  168. add_pre_process_metadata(model)
  169. if save_as_external_data:
  170. onnx.save_model(
  171. model,
  172. output_model_path,
  173. save_as_external_data=True,
  174. all_tensors_to_one_file=all_tensors_to_one_file,
  175. location=external_data_location,
  176. size_threshold=external_data_size_threshold,
  177. convert_attribute=False,
  178. )
  179. else:
  180. onnx.save(model, output_model_path)