quantize.py 52 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. from __future__ import annotations
  7. import copy
  8. import logging
  9. import tempfile
  10. from collections.abc import Callable
  11. from pathlib import Path
  12. from typing import Any
  13. import onnx
  14. from .calibrate import CalibrationDataReader, CalibrationMethod, TensorsData, create_calibrator
  15. from .onnx_quantizer import ONNXQuantizer
  16. from .qdq_quantizer import QDQQuantizer
  17. from .quant_utils import (
  18. MODEL_SIZE_THRESHOLD,
  19. QuantFormat,
  20. QuantizationMode,
  21. QuantType,
  22. load_model_with_shape_infer,
  23. model_has_pre_process_metadata,
  24. save_and_reload_model_with_shape_infer,
  25. update_opset_version,
  26. )
  27. from .registry import IntegerOpsRegistry, QDQRegistry, QLinearOpsRegistry
  28. from .tensor_quant_overrides import TensorQuantOverridesHelper
  29. class QuantConfig:
  30. def __init__(
  31. self,
  32. activation_type=QuantType.QUInt8,
  33. weight_type=QuantType.QInt8,
  34. op_types_to_quantize=None,
  35. nodes_to_quantize=None,
  36. nodes_to_exclude=None,
  37. per_channel=False,
  38. reduce_range=False,
  39. use_external_data_format=False,
  40. ):
  41. """
  42. This is the Base class for both Static and Dynamic Quantize Configuration
  43. Args:
  44. activation_type:
  45. quantization data type of activation. Please refer to
  46. https://onnxruntime.ai/docs/performance/quantization.html for more details on data type selection
  47. weight_type:
  48. quantization data type of weight. Please refer to
  49. https://onnxruntime.ai/docs/performance/quantization.html for more details on data type selection
  50. op_types_to_quantize:
  51. specify the types of operators to quantize, like ['Conv'] to quantize Conv only.
  52. It quantizes all supported operators by default.
  53. nodes_to_quantize:
  54. List of nodes names to quantize. When this list is not None only the nodes in this list
  55. are quantized.
  56. example:
  57. [
  58. 'Conv__224',
  59. 'Conv__252'
  60. ]
  61. nodes_to_exclude:
  62. List of nodes names to exclude. The nodes in this list will be excluded from quantization
  63. when it is not None.
  64. per_channel: quantize weights per channel
  65. reduce_range:
  66. quantize weights with 7-bits. It may improve the accuracy for some models running on non-VNNI machine,
  67. especially for per-channel mode
  68. use_external_data_format: option used for large size (>2GB) model. Set to False by default.
  69. """
  70. nodes_to_exclude = nodes_to_exclude or []
  71. nodes_to_quantize = nodes_to_quantize or []
  72. op_types_to_quantize = op_types_to_quantize or []
  73. self.op_types_to_quantize = op_types_to_quantize
  74. self.per_channel = per_channel
  75. self.reduce_range = reduce_range
  76. self.weight_type = weight_type
  77. self.activation_type = activation_type
  78. self.nodes_to_quantize = nodes_to_quantize
  79. self.nodes_to_exclude = nodes_to_exclude
  80. self.use_external_data_format = use_external_data_format
  81. class StaticQuantConfig(QuantConfig):
  82. def __init__(
  83. self,
  84. calibration_data_reader: CalibrationDataReader,
  85. calibrate_method=CalibrationMethod.MinMax,
  86. quant_format=QuantFormat.QDQ,
  87. activation_type=QuantType.QInt8,
  88. weight_type=QuantType.QInt8,
  89. op_types_to_quantize=None,
  90. nodes_to_quantize=None,
  91. nodes_to_exclude=None,
  92. per_channel=False,
  93. reduce_range=False,
  94. use_external_data_format=False,
  95. calibration_providers=None,
  96. extra_options=None,
  97. ):
  98. """
  99. This is the derived class for static Quantize Configuration
  100. Args:
  101. calibration_data_reader:
  102. a calibration data reader. It enumerates calibration data and generates inputs for the original model.
  103. calibrate_method:
  104. Current calibration methods supported are MinMax, Entropy and Percentile.
  105. quant_format: QuantFormat{QOperator, QDQ}.
  106. QOperator format quantizes the model with quantized operators directly.
  107. QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
  108. calibration_providers: Execution providers to run the session during calibration. Default is None which uses
  109. [ "CPUExecutionProvider" ].
  110. extra_options:
  111. key value pair dictionary for various options in different case. Current used:
  112. extra.Sigmoid.nnapi = True/False (Default is False)
  113. ActivationSymmetric = True/False: symmetrize calibration data for activations (default is False).
  114. WeightSymmetric = True/False: symmetrize calibration data for weights (default is True).
  115. EnableSubgraph = True/False : Default is False. If enabled, subgraph will be quantized.
  116. Dyanmic mode currently is supported. Will support more in future.
  117. ForceQuantizeNoInputCheck = True/False :
  118. By default, some latent operators like maxpool, transpose, do not quantize if their input is not
  119. quantized already. Setting to True to force such operator always quantize input and so generate
  120. quantized output. Also the True behavior could be disabled per node using the nodes_to_exclude.
  121. MatMulConstBOnly = True/False:
  122. Default is False for static mode. If enabled, only MatMul with const B will be quantized.
  123. AddQDQPairToWeight = True/False :
  124. Default is False which quantizes floating-point weight and feeds it to solely inserted
  125. DeQuantizeLinear node. If True, it remains floating-point weight and inserts both
  126. QuantizeLinear/DeQuantizeLinear nodes to weight.
  127. OpTypesToExcludeOutputQuantization = list of op type :
  128. Default is []. If any op type is specified, it won't quantize the output of ops with this
  129. specific op types.
  130. DedicatedQDQPair = True/False :
  131. Default is False. When inserting QDQ pair, multiple nodes can share a single QDQ pair as their
  132. inputs. If True, it will create identical and dedicated QDQ pair for each node.
  133. QDQOpTypePerChannelSupportToAxis = dictionary :
  134. Default is {}. Set channel axis for specific op type, for example: {'MatMul': 1}, and it's
  135. effective only when per channel quantization is supported and per_channel is True. If specific
  136. op type supports per channel quantization but not explicitly specified with channel axis,
  137. default channel axis will be used.
  138. CalibTensorRangeSymmetric = True/False :
  139. Default is False. If enabled, the final range of tensor during calibration will be explicitly
  140. set to symmetric to central point "0".
  141. CalibMovingAverage = True/False :
  142. Default is False. If enabled, the moving average of the minimum and maximum values will be
  143. computed when the calibration method selected is MinMax.
  144. CalibMovingAverageConstant = float :
  145. Default is 0.01. Constant smoothing factor to use when computing the moving average of the
  146. minimum and maximum values. Effective only when the calibration method selected is MinMax and
  147. when CalibMovingAverage is set to True.
  148. QuantizeBias = True/False :
  149. Default is True which quantizes floating-point biases and it solely inserts
  150. a DeQuantizeLinear node. If False, it remains floating-point bias and does not insert
  151. any quantization nodes associated with biases.
  152. This extra option is only effective when quant_format is QuantFormat.QDQ.
  153. SmoothQuant = True/False :
  154. Default is False. If enabled, SmoothQuant algorithm will be applied before quantization to do
  155. fake input channel quantization.
  156. SmoothQuantAlpha = float :
  157. Default is 0.5. It only works if SmoothQuant is True. It controls the difficulty of weight
  158. and activation quantization. A larger alpha value could be used on models with more significant
  159. activation outliers to migrate more quantization difficulty to weights.
  160. SmoothQuantFolding = True/False :
  161. Default is True. It only works if SmoothQuant is True. If enabled, inserted Mul ops during
  162. SmoothQuant will be folded into the previous op if the previous op is foldable.
  163. UseQDQContribOps = True/False :
  164. Default is False. If enabled, the inserted QuantizeLinear and DequantizeLinear ops will have the
  165. `com.microsoft` domain, which forces use of ONNX Runtime's QuantizeLinear and DequantizeLinear
  166. contrib op implementations. The contrib op implementations may support features not standardized
  167. into the ONNX specification (e.g., 16-bit quantization types).
  168. MinimumRealRange = float|None :
  169. Default is None. If set to a floating-point value, the calculation of the quantization parameters
  170. (i.e., scale and zero point) will enforce a minimum range between rmin and rmax. If (rmax-rmin)
  171. is less than the specified minimum range, rmax will be set to rmin + MinimumRealRange. This is
  172. necessary for EPs like QNN that require a minimum floating-point range when determining
  173. quantization parameters.
  174. TensorQuantOverrides = dictionary :
  175. Default is {}. Set tensor quantization overrides. The key is a tensor name and the value is a
  176. list of dictionaries. For per-tensor quantization, the list contains a single dictionary. For
  177. per-channel quantization, the list contains a dictionary for each channel in the tensor.
  178. Each dictionary contains optional overrides with the following keys and values.
  179. 'quant_type' = QuantType : The tensor's quantization data type.
  180. 'scale' = Float : The scale value to use. Must also specify `zero_point` if set.
  181. 'zero_point' = Int : The zero-point value to use. Must also specify `scale` is set.
  182. 'symmetric' = Bool : If the tensor should use symmetric quantization. Invalid if also
  183. set `scale` or `zero_point`.
  184. 'reduce_range' = Bool : If the quantization range should be reduced. Invalid if also
  185. set `scale` or `zero_point`.
  186. 'rmax' = Float : Override the maximum real tensor value in calibration data.
  187. Invalid if also set `scale` or `zero_point`.
  188. 'rmin' = Float : Override the minimum real tensor value in calibration data.
  189. Invalid if also set `scale` or `zero_point`.
  190. QDQKeepRemovableActivations = True/False:
  191. Default is False. If true, "removable" activations (e.g., Clip or Relu) will not be removed, and
  192. will be explicitly represented in the QDQ model. If false, these activations are automatically
  193. removed if activations are asymmetrically quantized. Keeping these activations is necessary if
  194. optimizations or EP transformations will later remove QuantizeLinear/DequantizeLinear
  195. operators from the model.
  196. QDQDisableWeightAdjustForInt32Bias = True/False:
  197. Default is False. If true, QDQ quantizer will not adjust the weight's scale when the bias
  198. has a scale (input_scale * weight_scale) that is too small.
  199. execution_provider : A enum indicates the Execution Provider such as: CPU, TRT, NNAPI, SNE, etc.
  200. Raises:
  201. ValueError: Raise ValueError if execution provider is unknown
  202. """
  203. super().__init__(
  204. activation_type=activation_type,
  205. weight_type=weight_type,
  206. op_types_to_quantize=op_types_to_quantize,
  207. nodes_to_quantize=nodes_to_quantize,
  208. nodes_to_exclude=nodes_to_exclude,
  209. per_channel=per_channel,
  210. reduce_range=reduce_range,
  211. use_external_data_format=use_external_data_format,
  212. )
  213. self.calibration_data_reader = calibration_data_reader
  214. self.calibrate_method = calibrate_method
  215. self.quant_format = quant_format
  216. self.calibration_providers = calibration_providers
  217. self.extra_options = extra_options or {}
  218. def get_qdq_config(
  219. model_input: str | Path | onnx.ModelProto,
  220. calibration_data_reader: CalibrationDataReader,
  221. calibrate_method=CalibrationMethod.MinMax,
  222. calibrate_args: dict[str, Any] | None = None,
  223. activation_type=QuantType.QUInt8,
  224. weight_type=QuantType.QInt8,
  225. activation_symmetric: bool = False,
  226. weight_symmetric: bool | None = None,
  227. per_channel: bool = False,
  228. reduce_range: bool = False,
  229. keep_removable_activations: bool = False,
  230. min_real_range: float | None = None,
  231. tensor_quant_overrides: dict[str, list[dict[str, Any]]] | None = None,
  232. calibration_providers: list[str] | None = None,
  233. op_types_to_quantize: list[str] | None = None,
  234. nodes_to_exclude: list[str] | Callable[[onnx.ModelProto, onnx.NodeProto], bool] | None = None,
  235. extra_options: dict | None = None,
  236. ) -> StaticQuantConfig:
  237. """
  238. Returns a configuration suitable that quantizes the entire model to integer precision.
  239. Params:
  240. model_input: Path to the input model file or ModelProto.
  241. calibration_data_reader: Calibration data reader.
  242. calibrate_methode: The calibration method. Defaults to MinMax.
  243. activation_type: The default activation quantization type. Defaults to QUInt8.
  244. weight_type: The default weight quantization type. Defaults to QInt8.
  245. activation_symmetric: True if activations should be quantized symmetrically (i.e, rmax == -rmin) by default.
  246. Defaults to false. For int8 and int16, this results in zero-point values of 0. For uint8 and uint16,
  247. the zero-point values are 127 and 32,767, respectively.
  248. weight_symmetric: True if weights should be quantized symmetrically (i.e., rmax == -rmin) by default.
  249. Defaults to None. If set to None, weight_symmetric is assumed true if a weight's quant type is a signed int.
  250. per_channel: Global option that determines if a fixed set of operator types should be quantized per-channel.
  251. Defaults to false. Alternatively, use the tensor-level `tensor_quant_overrides` to select individual operators
  252. and their quantization axes.
  253. reduce_range: quantize weights with 1 less bit of precision (e.g., 7 bits for QInt8). Defaults to false.
  254. May improve the accuracy for some models running on non-VNNI machine, especially for per-channel mode.
  255. keep_removable_activations: Defaults to false. If true, "removable" activations (e.g., Clip or Relu) will not
  256. be removed, and will be explicitly represented in the QDQ model. If false, these activations
  257. are automatically removed if activations are asymmetrically quantized. Keeping these activations
  258. is necessary if optimizations or EP transformations will later remove
  259. QuantizeLinear/DequantizeLinear operators from the model.
  260. min_real_range: Default is None. If set to a floating-point value, the calculation of the quantization parameters
  261. (i.e., scale and zero point) will enforce a minimum range between rmin and rmax. If (rmax - rmin)
  262. is less than the specified minimum range, rmax will be set to rmin + min_real_range.
  263. tensor_quant_overrides: tensor-level quantization overrides. Defaults to None.
  264. The key is a tensor name and the value is a list of dictionaries. For per-tensor quantization, the list
  265. contains a single dictionary. For per-channel quantization, the list contains either a dictionary for
  266. each channel in the tensor or a single dictionary that is assumed to apply to all channels. An 'axis'
  267. key must be present in the first dictionary for per-channel quantization.
  268. Each dictionary contains optional overrides with the following keys and values.
  269. 'quant_type' = QuantType : The tensor's quantization data type.
  270. 'axis' = Int : The per-channel axis. Must be present for per-channel weights.
  271. 'scale' = Float : The scale value to use. Must also specify `zero_point` if set.
  272. 'zero_point' = Int : The zero-point value to use. Must also specify `scale` is set.
  273. 'symmetric' = Bool : If the tensor should use symmetric quantization. Invalid if also
  274. set `scale` or `zero_point`.
  275. 'reduce_range' = Bool : If the quantization range should be reduced. Invalid if also
  276. set `scale` or `zero_point`. Only valid for initializers.
  277. 'rmax' = Float : Override the maximum real tensor value in calibration data.
  278. Invalid if also set `scale` or `zero_point`.
  279. 'rmin' = Float : Override the minimum real tensor value in calibration data.
  280. Invalid if also set `scale` or `zero_point`.
  281. 'convert' = Dict : A nested dictionary with the same keys for an activation
  282. tensor that should be converted to another quantization type.
  283. 'convert["recv_nodes"] = Set : Set of node names that consume the converted activation,
  284. other nodes get the original type. If not specified,
  285. assume all consumer nodes get the converted type.
  286. calibration_providers: Execution providers to run the session during calibration. Default is None which uses
  287. [ "CPUExecutionProvider" ].
  288. op_types_to_quantize: List of operator types to quantize. If None, all operators other than Cast, DequantizeLinear,
  289. and QuantizeLinear are quantized.
  290. nodes_to_exclude: List of nodes names to exclude from quantization. Alternatively, can provide a function that
  291. accepts an onnx.ModelProto and onnx.NodeProto as arguments and returns true if the give onnx.NodeProto
  292. should be excluded from quantization.
  293. extra_options: Additional options specified as string key/value pairs. Refer to the documentation for
  294. `quantize_static` for valid keys and values.
  295. Returns:
  296. A StaticQuantConfig object
  297. """
  298. q16_types = {QuantType.QInt16, QuantType.QUInt16}
  299. q4_types = {QuantType.QInt4, QuantType.QUInt4}
  300. op_types_to_exclude = {"Cast", "DequantizeLinear", "QuantizeLinear"}
  301. model = (
  302. model_input
  303. if isinstance(model_input, onnx.ModelProto)
  304. else onnx.load_model(model_input, load_external_data=False)
  305. )
  306. op_types = set()
  307. model_has_external_data = False
  308. overrides_helper = TensorQuantOverridesHelper(
  309. copy.deepcopy(tensor_quant_overrides) if tensor_quant_overrides else {}
  310. )
  311. # check if the model has external data.
  312. for initializer in model.graph.initializer:
  313. if onnx.external_data_helper.uses_external_data(initializer):
  314. model_has_external_data = True
  315. op_types_to_quantize_set = set(op_types_to_quantize) if op_types_to_quantize else None
  316. nodes_to_exclude_set = set(nodes_to_exclude) if isinstance(nodes_to_exclude, list) else set()
  317. # Iterate through nodes to get all operator types in the model and
  318. # call user's function to filter out nodes from quantization.
  319. for node in model.graph.node:
  320. if op_types_to_quantize_set and node.op_type not in op_types_to_quantize_set:
  321. continue
  322. if node.name in nodes_to_exclude_set:
  323. continue
  324. if callable(nodes_to_exclude) and nodes_to_exclude(model, node):
  325. nodes_to_exclude_set.add(node.name)
  326. else:
  327. op_types.add(node.op_type)
  328. final_extra_options = {
  329. "MinimumRealRange": min_real_range,
  330. "QDQKeepRemovableActivations": keep_removable_activations,
  331. "ActivationSymmetric": activation_symmetric,
  332. "WeightSymmetric": weight_symmetric,
  333. "ForceQuantizeNoInputCheck": True,
  334. "TensorQuantOverrides": overrides_helper.get_dict(),
  335. }
  336. # Pass along known calibration options
  337. if calibrate_args:
  338. calib_extra_options_keys = [
  339. ("symmetric", "CalibTensorRangeSymmetric"),
  340. ("moving_average", "CalibMovingAverage"),
  341. ("averaging_constant", "CalibMovingAverageConstant"),
  342. ("max_intermediate_outputs", "CalibMaxIntermediateOutputs"),
  343. ("percentile", "CalibPercentile"),
  344. ]
  345. calib_extra_options = {
  346. key: calibrate_args.get(name) for (name, key) in calib_extra_options_keys if name in calibrate_args
  347. }
  348. final_extra_options.update(calib_extra_options)
  349. # ONNX opset < 21 does not support 16-bit quantization, so must use 'com.microsoft' domain
  350. # on Q/DQ operators if using 16-bit or 4-bit quantization.
  351. onnx_opset = next(x for x in model.opset_import if x.domain == "" or x.domain == "ai.onnx")
  352. if onnx_opset.version < 21:
  353. opset21_types = q16_types.union(q4_types)
  354. overrides_have_opset21_types = any(t in opset21_types for t in overrides_helper.get_quant_types())
  355. if activation_type in opset21_types or weight_type in opset21_types or overrides_have_opset21_types:
  356. final_extra_options["UseQDQContribOps"] = True
  357. # Allow user's extra_options to override our final_extra_options.
  358. if extra_options:
  359. final_extra_options.update(extra_options)
  360. return StaticQuantConfig(
  361. calibration_data_reader,
  362. calibrate_method=calibrate_method,
  363. quant_format=QuantFormat.QDQ,
  364. activation_type=activation_type,
  365. weight_type=weight_type,
  366. op_types_to_quantize=(
  367. op_types_to_quantize if op_types_to_quantize else list(op_types.difference(op_types_to_exclude))
  368. ),
  369. nodes_to_exclude=list(nodes_to_exclude_set),
  370. per_channel=per_channel,
  371. reduce_range=reduce_range,
  372. use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD),
  373. calibration_providers=calibration_providers,
  374. extra_options=final_extra_options,
  375. )
  376. class DynamicQuantConfig(QuantConfig):
  377. def __init__(
  378. self,
  379. weight_type=QuantType.QInt8,
  380. op_types_to_quantize=None,
  381. nodes_to_quantize=None,
  382. nodes_to_exclude=None,
  383. per_channel=False,
  384. reduce_range=False,
  385. use_external_data_format=False,
  386. extra_options=None,
  387. ):
  388. """
  389. This is a class for dynamic Quant Configuration
  390. Args:
  391. extra_options: key value pair dictionary for various options in different case. Current used:
  392. extra.Sigmoid.nnapi = True/False (Default is False)
  393. ActivationSymmetric = True/False: symmetrize calibration data for activations (default is False).
  394. WeightSymmetric = True/False: symmetrize calibration data for weights (default is True).
  395. EnableSubgraph = True/False :
  396. Default is False. If enabled, subgraph will be quantized. Dynamic mode currently is supported. Will
  397. support more in the future.
  398. ForceQuantizeNoInputCheck = True/False :
  399. By default, some latent operators like maxpool, transpose, do not quantize if their input is not
  400. quantized already. Setting to True to force such operator always quantize input and so generate
  401. quantized output. Also the True behavior could be disabled per node using the nodes_to_exclude.
  402. MatMulConstBOnly = True/False:
  403. Default is True for dynamic mode. If enabled, only MatMul with const B will be quantized.
  404. execution_provider : A enum indicates the Execution Provider such as: CPU, TRT, NNAPI, SNE, etc.
  405. Raises:
  406. ValueError: Raise ValueError if execution provider is unknown
  407. """
  408. super().__init__(
  409. op_types_to_quantize=op_types_to_quantize,
  410. per_channel=per_channel,
  411. reduce_range=reduce_range,
  412. weight_type=weight_type,
  413. nodes_to_quantize=nodes_to_quantize,
  414. nodes_to_exclude=nodes_to_exclude,
  415. use_external_data_format=use_external_data_format,
  416. )
  417. self.extra_options = extra_options or {}
  418. def check_static_quant_arguments(quant_format: QuantFormat, activation_type: QuantType, weight_type: QuantType):
  419. if activation_type == QuantType.QInt8 and weight_type == QuantType.QUInt8:
  420. raise ValueError(
  421. "ONNXRuntime quantization doesn't support data format:"
  422. "activation_type=QuantType.QInt8, weight_type=QuantType.QUInt8"
  423. )
  424. if activation_type != QuantType.QFLOAT8E4M3FN and weight_type == QuantType.QFLOAT8E4M3FN:
  425. raise ValueError(
  426. f"ONNXRuntime quantization doesn't support data format: activation_type={activation_type} "
  427. "!=QuantType.QFLOAT8E4M3FN, weight_type=QuantType.QFLOAT8E4M3FN."
  428. )
  429. if activation_type == QuantType.QFLOAT8E4M3FN and weight_type != QuantType.QFLOAT8E4M3FN:
  430. raise ValueError(
  431. "ONNXRuntime quantization doesn't support data format: activation_type=QuantType.QFLOAT8E4M3FN, "
  432. f"weight_type={weight_type}!=QuantType.QFLOAT8E4M3FN"
  433. )
  434. q16_types = [QuantType.QInt16, QuantType.QUInt16]
  435. if (activation_type in q16_types or weight_type in q16_types) and quant_format != QuantFormat.QDQ:
  436. raise ValueError("Only QuantFormat.QDQ supports 16-bit quantization types.")
  437. if activation_type == QuantType.QInt8 and weight_type == QuantType.QInt8 and quant_format != QuantFormat.QDQ:
  438. logging.warning(
  439. "Please use QuantFormat.QDQ for activation type QInt8 and weight type QInt8. "
  440. "Or it will lead to bad performance on x64."
  441. )
  442. def quantize_static(
  443. model_input: str | Path | onnx.ModelProto,
  444. model_output: str | Path,
  445. calibration_data_reader: CalibrationDataReader,
  446. quant_format=QuantFormat.QDQ,
  447. op_types_to_quantize=None,
  448. per_channel=False,
  449. reduce_range=False,
  450. activation_type=QuantType.QInt8,
  451. weight_type=QuantType.QInt8,
  452. nodes_to_quantize=None,
  453. nodes_to_exclude=None,
  454. use_external_data_format=False,
  455. calibrate_method=CalibrationMethod.MinMax,
  456. calibration_providers=None,
  457. extra_options=None,
  458. ):
  459. """
  460. Given an onnx model and calibration data reader, create a quantized onnx model and save it into a file
  461. It is recommended to use QuantFormat.QDQ format from 1.11 with activation_type = QuantType.QInt8 and weight_type
  462. = QuantType.QInt8. If model is targeted to GPU/TRT, symmetric activation and weight are required. If model is
  463. targeted to CPU, asymmetric activation and symmetric weight are recommended for balance of performance and
  464. accuracy.
  465. Args:
  466. model_input: file path of model or ModelProto to quantize
  467. model_output: file path of quantized model
  468. calibration_data_reader: a calibration data reader. It
  469. enumerates calibration data and generates inputs for the
  470. original model.
  471. quant_format: QuantFormat{QOperator, QDQ}.
  472. QOperator format quantizes the model with quantized operators directly.
  473. QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
  474. activation_type:
  475. quantization data type of activation. Please refer to
  476. https://onnxruntime.ai/docs/performance/quantization.html for more details on data type selection
  477. calibrate_method:
  478. Current calibration methods supported are MinMax and Entropy.
  479. Please use CalibrationMethod.MinMax or CalibrationMethod.Entropy as options.
  480. op_types_to_quantize:
  481. specify the types of operators to quantize, like ['Conv'] to quantize Conv only.
  482. It quantizes all supported operators by default.
  483. per_channel: quantize weights per channel
  484. reduce_range:
  485. quantize weights with 7-bits. It may improve the accuracy for some models running on non-VNNI machine,
  486. especially for per-channel mode
  487. weight_type:
  488. quantization data type of weight. Please refer to
  489. https://onnxruntime.ai/docs/performance/quantization.html for more details on data type selection
  490. nodes_to_quantize:
  491. List of nodes names to quantize. When this list is not None only the nodes in this list
  492. are quantized.
  493. example:
  494. [
  495. 'Conv__224',
  496. 'Conv__252'
  497. ]
  498. nodes_to_exclude:
  499. List of nodes names to exclude. The nodes in this list will be excluded from quantization
  500. when it is not None.
  501. use_external_data_format: option used for large size (>2GB) model. Set to False by default.
  502. calibration_providers: Execution providers to run the session during calibration. Default is None which uses
  503. [ "CPUExecutionProvider" ]
  504. extra_options:
  505. key value pair dictionary for various options in different case. Current used:
  506. extra.Sigmoid.nnapi = True/False (Default is False)
  507. ActivationSymmetric = True/False: symmetrize calibration data for activations (default is False).
  508. WeightSymmetric = True/False: symmetrize calibration data for weights (default is True).
  509. EnableSubgraph = True/False : Default is False. If enabled, subgraph will be quantized.
  510. Dyanmic mode currently is supported. Will support more in the future.
  511. ForceQuantizeNoInputCheck = True/False :
  512. By default, some latent operators like maxpool, transpose, do not quantize if their input is not
  513. quantized already. Setting to True to force such operator always quantize input and so generate
  514. quantized output. Also, the True behavior could be disabled per node using the nodes_to_exclude.
  515. MatMulConstBOnly = True/False:
  516. Default is False for static mode. If enabled, only MatMul with const B will be quantized.
  517. AddQDQPairToWeight = True/False :
  518. Default is False which quantizes floating-point weight and feeds it to solely inserted
  519. DeQuantizeLinear node. If True, it remains floating-point weight and inserts both
  520. QuantizeLinear/DeQuantizeLinear nodes to weight.
  521. OpTypesToExcludeOutputQuantization = list of op type :
  522. Default is []. If any op type is specified, it won't quantize the output of ops with this
  523. specific op types.
  524. DedicatedQDQPair = True/False :
  525. Default is False. When inserting QDQ pair, multiple nodes can share a single QDQ pair as their
  526. inputs. If True, it will create identical and dedicated QDQ pair for each node.
  527. QDQOpTypePerChannelSupportToAxis = dictionary :
  528. Default is {}. Set channel axis for specific op type, for example: {'MatMul': 1}, and it's
  529. effective only when per channel quantization is supported and per_channel is True. If specific
  530. op type supports per channel quantization but not explicitly specified with channel axis,
  531. default channel axis will be used.
  532. CalibTensorRangeSymmetric = True/False :
  533. Default is False. If enabled, the final range of tensor during calibration will be explicitly
  534. set to symmetric to central point "0".
  535. CalibStridedMinMax = Optional[int] :
  536. Default is None. If set to an integer, during calculation of the min-max, only stride amount of
  537. data will be used and then all results will be merged in the end.
  538. CalibMovingAverage = True/False :
  539. Default is False. If enabled, the moving average of the minimum and maximum values will be
  540. computed when the calibration method selected is MinMax.
  541. CalibMovingAverageConstant = float :
  542. Default is 0.01. Constant smoothing factor to use when computing the moving average of the
  543. minimum and maximum values. Effective only when the calibration method selected is MinMax and
  544. when CalibMovingAverage is set to True.
  545. CalibMaxIntermediateOutputs = Optional[int] :
  546. Default is None. If set to an integer, during calculation of the min-max range of the tensors
  547. it will load at max value number of outputs before computing and merging the range. This will
  548. produce the same result as all computing with None, but is more memory efficient.
  549. SmoothQuant = True/False :
  550. Default is False. If enabled, SmoothQuant algorithm will be applied before quantization to do
  551. fake input channel quantization.
  552. SmoothQuantAlpha = float :
  553. Default is 0.5. It only works if SmoothQuant is True. It controls the difficulty of weight
  554. and activation quantization. A larger alpha value could be used on models with more significant
  555. activation outliers to migrate more quantization difficulty to weights.
  556. SmoothQuantFolding = True/False :
  557. Default is True. It only works if SmoothQuant is True. If enabled, inserted Mul ops during
  558. SmoothQuant will be folded into the previous op if the previous op is foldable.
  559. UseQDQContribOps = True/False :
  560. Default is False. If enabled, the inserted QuantizeLinear and DequantizeLinear ops will have the
  561. `com.microsoft` domain, which forces use of ONNX Runtime's QuantizeLinear and DequantizeLinear
  562. contrib op implementations. The contrib op implementations may support features not standardized
  563. into the ONNX specification (e.g., 16-bit quantization types).
  564. MinimumRealRange = float|None :
  565. Default is None. If set to a floating-point value, the calculation of the quantization parameters
  566. (i.e., scale and zero point) will enforce a minimum range between rmin and rmax. If (rmax - rmin)
  567. is less than the specified minimum range, rmax will be set to rmin + MinimumRealRange. This is
  568. necessary for EPs like QNN that require a minimum floating-point range when determining
  569. quantization parameters.
  570. TensorQuantOverrides = dictionary :
  571. Default is {}. Set tensor quantization overrides. The key is a tensor name and the value is a
  572. list of dictionaries. For per-tensor quantization, the list contains a single dictionary. For
  573. per-channel quantization, the list contains a dictionary for each channel in the tensor.
  574. Each dictionary contains optional overrides with the following keys and values.
  575. 'quant_type' = QuantType : The tensor's quantization data type.
  576. 'scale' = Float : The scale value to use. Must also specify `zero_point` if set.
  577. 'zero_point' = Int : The zero-point value to use. Must also specify `scale` is set.
  578. 'symmetric' = Bool : If the tensor should use symmetric quantization. Invalid if also
  579. set `scale` or `zero_point`.
  580. 'reduce_range' = Bool : If the quantization range should be reduced. Invalid if also
  581. set `scale` or `zero_point`.
  582. 'rmax' = Float : Override the maximum real tensor value in calibration data.
  583. Invalid if also set `scale` or `zero_point`.
  584. 'rmin' = Float : Override the minimum real tensor value in calibration data.
  585. Invalid if also set `scale` or `zero_point`.
  586. QDQKeepRemovableActivations = True/False:
  587. Default is False. If true, "removable" activations (e.g., Clip or Relu) will not be removed, and
  588. will be explicitly represented in the QDQ model. If false, these activations are automatically
  589. removed if activations are asymmetrically quantized. Keeping these activations is necessary if
  590. optimizations or EP transformations will later remove QuantizeLinear/DequantizeLinear
  591. operators from the model.
  592. QDQDisableWeightAdjustForInt32Bias = True/False:
  593. Default is False. If true, QDQ quantizer will not adjust the weight's scale when the bias
  594. has a scale (input_scale * weight_scale) that is too small.
  595. """
  596. if activation_type == QuantType.QFLOAT8E4M3FN or weight_type == QuantType.QFLOAT8E4M3FN:
  597. if calibrate_method != CalibrationMethod.Distribution:
  598. raise ValueError("Only Distribution calibration method is supported for float quantization.")
  599. extra_options = extra_options or {}
  600. nodes_to_exclude = nodes_to_exclude or []
  601. nodes_to_quantize = nodes_to_quantize or []
  602. op_types_to_quantize = op_types_to_quantize or []
  603. mode = QuantizationMode.QLinearOps
  604. if not op_types_to_quantize or len(op_types_to_quantize) == 0:
  605. q_linear_ops = list(QLinearOpsRegistry.keys())
  606. qdq_ops = list(QDQRegistry.keys())
  607. op_types_to_quantize = list(set(q_linear_ops + qdq_ops))
  608. model = (
  609. save_and_reload_model_with_shape_infer(model_input)
  610. if isinstance(model_input, onnx.ModelProto)
  611. else load_model_with_shape_infer(Path(model_input))
  612. )
  613. pre_processed: bool = model_has_pre_process_metadata(model)
  614. if not pre_processed:
  615. logging.warning(
  616. "Please consider to run pre-processing before quantization. Refer to example: "
  617. "https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/image_classification"
  618. "/cpu/ReadMe.md "
  619. )
  620. calib_extra_options_keys = [
  621. ("CalibTensorRangeSymmetric", "symmetric"),
  622. ("CalibMovingAverage", "moving_average"),
  623. ("CalibMovingAverageConstant", "averaging_constant"),
  624. ("CalibMaxIntermediateOutputs", "max_intermediate_outputs"),
  625. ("CalibPercentile", "percentile"),
  626. ]
  627. calib_extra_options = {
  628. key: extra_options.get(name) for (name, key) in calib_extra_options_keys if name in extra_options
  629. }
  630. if extra_options.get("SmoothQuant", False):
  631. import importlib # noqa: PLC0415
  632. try:
  633. importlib.import_module("neural_compressor.adaptor.ox_utils.smooth_quant")
  634. except Exception as e:
  635. logging.error(f"{e}.")
  636. raise RuntimeError("neural-compressor is not correctly installed. Please check your environment.") from e
  637. from neural_compressor.adaptor.ox_utils.smooth_quant import ORTSmoothQuant # noqa: PLC0415
  638. def inc_dataloader():
  639. data_reader = copy.deepcopy(calibration_data_reader)
  640. for data in data_reader:
  641. yield data, None
  642. orig_nodes = [i.name for i in model.graph.node]
  643. dataloader = inc_dataloader()
  644. sq = ORTSmoothQuant(model_input, dataloader, reduce_range)
  645. del dataloader
  646. model = sq.transform(extra_options.get("SmoothQuantAlpha", 0.5), extra_options.get("SmoothQuantFolding", True))
  647. sq_path = tempfile.TemporaryDirectory(prefix="ort.quant.")
  648. model_input = Path(sq_path.name).joinpath("sq_model.onnx").as_posix()
  649. model.save(model_input)
  650. nodes_to_exclude.extend([i.name for i in model.model.graph.node if i.name not in orig_nodes])
  651. model = load_model_with_shape_infer(Path(model_input)) # use smooth quant model for calibration
  652. updated_model = update_opset_version(model, weight_type)
  653. is_model_updated = updated_model is not model
  654. if is_model_updated:
  655. model = updated_model
  656. with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir:
  657. if is_model_updated:
  658. # Update model_input and avoid to use the original one
  659. model_input = copy.deepcopy(model)
  660. if isinstance(model_input, onnx.ModelProto):
  661. output_path = Path(quant_tmp_dir).joinpath("model_input.onnx").as_posix()
  662. onnx.save_model(
  663. model_input,
  664. output_path,
  665. save_as_external_data=True,
  666. )
  667. model_input = output_path
  668. calibrator = create_calibrator(
  669. Path(model_input),
  670. op_types_to_quantize,
  671. augmented_model_path=Path(quant_tmp_dir).joinpath("augmented_model.onnx").as_posix(),
  672. calibrate_method=calibrate_method,
  673. use_external_data_format=use_external_data_format,
  674. providers=calibration_providers,
  675. extra_options=calib_extra_options,
  676. )
  677. stride = extra_options.get("CalibStridedMinMax", None)
  678. if stride:
  679. total_data_size = len(calibration_data_reader)
  680. if total_data_size % stride != 0:
  681. raise ValueError(f"Total data size ({total_data_size}) is not divisible by stride size ({stride}).")
  682. for start in range(0, total_data_size, stride):
  683. end_index = start + stride
  684. calibration_data_reader.set_range(start_index=start, end_index=end_index)
  685. calibrator.collect_data(calibration_data_reader)
  686. else:
  687. calibrator.collect_data(calibration_data_reader)
  688. tensors_range = calibrator.compute_data()
  689. if not isinstance(tensors_range, TensorsData):
  690. raise TypeError(
  691. f"Unexpected type {type(tensors_range)} for tensors_range and calibrator={type(calibrator)}."
  692. )
  693. del calibrator
  694. check_static_quant_arguments(quant_format, activation_type, weight_type)
  695. if quant_format is QuantFormat.QOperator:
  696. quantizer = ONNXQuantizer(
  697. model,
  698. per_channel,
  699. reduce_range,
  700. mode,
  701. True, # static
  702. weight_type,
  703. activation_type,
  704. tensors_range,
  705. nodes_to_quantize,
  706. nodes_to_exclude,
  707. op_types_to_quantize,
  708. extra_options,
  709. )
  710. else:
  711. quantizer = QDQQuantizer(
  712. model,
  713. per_channel,
  714. reduce_range,
  715. weight_type,
  716. activation_type,
  717. tensors_range,
  718. nodes_to_quantize,
  719. nodes_to_exclude,
  720. op_types_to_quantize,
  721. extra_options,
  722. )
  723. quantizer.quantize_model()
  724. quantizer.model.save_model_to_file(model_output, use_external_data_format)
  725. if not pre_processed:
  726. logging.warning(
  727. "Please consider pre-processing before quantization. See "
  728. "https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/image_classification"
  729. "/cpu/ReadMe.md "
  730. )
  731. if extra_options.get("SmoothQuant", False):
  732. sq_path.cleanup()
  733. def quantize_dynamic(
  734. model_input: str | Path | onnx.ModelProto,
  735. model_output: str | Path,
  736. op_types_to_quantize=None,
  737. per_channel=False,
  738. reduce_range=False,
  739. weight_type=QuantType.QInt8,
  740. nodes_to_quantize=None,
  741. nodes_to_exclude=None,
  742. use_external_data_format=False,
  743. extra_options=None,
  744. ):
  745. """Given an onnx model, create a quantized onnx model and save it into a file
  746. Args:
  747. model_input: file path of model or ModelProto to quantize
  748. model_output: file path of quantized model
  749. op_types_to_quantize:
  750. specify the types of operators to quantize, like ['Conv'] to quantize Conv only.
  751. It quantizes all supported operators by default.
  752. per_channel: quantize weights per channel
  753. reduce_range:
  754. quantize weights with 7-bits. It may improve the accuracy for some models running on non-VNNI machine,
  755. especially for per-channel mode
  756. weight_type:
  757. quantization data type of weight. Please refer to
  758. https://onnxruntime.ai/docs/performance/quantization.html for more details on data type selection
  759. nodes_to_quantize:
  760. List of nodes names to quantize. When this list is not None only the nodes in this list
  761. are quantized.
  762. example:
  763. [
  764. 'Conv__224',
  765. 'Conv__252'
  766. ]
  767. nodes_to_exclude:
  768. List of nodes names to exclude. The nodes in this list will be excluded from quantization
  769. when it is not None.
  770. use_external_data_format: option used for large size (>2GB) model. Set to False by default.
  771. extra_options:
  772. key value pair dictionary for various options in different case. Current used:
  773. extra.Sigmoid.nnapi = True/False (Default is False)
  774. ActivationSymmetric = True/False: symmetrize calibration data for activations (default is False).
  775. WeightSymmetric = True/False: symmetrize calibration data for weights (default is True).
  776. EnableSubgraph = True/False :
  777. Default is False. If enabled, subgraph will be quantized. Dynamic mode currently is supported. Will
  778. support more in the future.
  779. ForceQuantizeNoInputCheck = True/False :
  780. By default, some latent operators like maxpool, transpose, do not quantize if their input is not
  781. quantized already. Setting to True to force such operator always quantize input and so generate
  782. quantized output. Also the True behavior could be disabled per node using the nodes_to_exclude.
  783. MatMulConstBOnly = True/False:
  784. Default is True for dynamic mode. If enabled, only MatMul with const B will be quantized.
  785. """
  786. extra_options = extra_options or {}
  787. nodes_to_exclude = nodes_to_exclude or []
  788. nodes_to_quantize = nodes_to_quantize or []
  789. op_types_to_quantize = op_types_to_quantize or []
  790. mode = QuantizationMode.IntegerOps
  791. if not op_types_to_quantize or len(op_types_to_quantize) == 0:
  792. op_types_to_quantize = list(IntegerOpsRegistry.keys())
  793. model = (
  794. save_and_reload_model_with_shape_infer(model_input)
  795. if isinstance(model_input, onnx.ModelProto)
  796. else load_model_with_shape_infer(Path(model_input))
  797. )
  798. pre_processed: bool = model_has_pre_process_metadata(model)
  799. if not pre_processed:
  800. logging.warning(
  801. "Please consider to run pre-processing before quantization. Refer to example: "
  802. "https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/image_classification"
  803. "/cpu/ReadMe.md "
  804. )
  805. if "MatMulConstBOnly" not in extra_options:
  806. extra_options["MatMulConstBOnly"] = True
  807. model = update_opset_version(model, weight_type)
  808. quantizer = ONNXQuantizer(
  809. model,
  810. per_channel,
  811. reduce_range,
  812. mode,
  813. False, # static
  814. weight_type,
  815. QuantType.QUInt8, # dynamic activation only supports uint8
  816. None,
  817. nodes_to_quantize,
  818. nodes_to_exclude,
  819. op_types_to_quantize,
  820. extra_options,
  821. )
  822. quantizer.quantize_model()
  823. quantizer.model.save_model_to_file(model_output, use_external_data_format)
  824. def quantize(
  825. model_input: str | Path | onnx.ModelProto,
  826. model_output: str | Path,
  827. quant_config: QuantConfig,
  828. ):
  829. """Quantize a model with QuantConfig.
  830. Args:
  831. model_input (str | Path | ModelProto): Path to the model or ModelProto to quantize.
  832. model_output (str | Path): Path to save the quantized model.
  833. quant_config (QuantConfig | WeightOnlyQuantConfig): Quantization Configuration.
  834. """
  835. if isinstance(quant_config, StaticQuantConfig):
  836. quantize_static(
  837. model_input,
  838. model_output,
  839. quant_config.calibration_data_reader,
  840. calibrate_method=quant_config.calibrate_method,
  841. quant_format=quant_config.quant_format,
  842. activation_type=quant_config.activation_type,
  843. weight_type=quant_config.weight_type,
  844. op_types_to_quantize=quant_config.op_types_to_quantize,
  845. nodes_to_quantize=quant_config.nodes_to_quantize,
  846. nodes_to_exclude=quant_config.nodes_to_exclude,
  847. per_channel=quant_config.per_channel,
  848. reduce_range=quant_config.reduce_range,
  849. use_external_data_format=quant_config.use_external_data_format,
  850. calibration_providers=quant_config.calibration_providers,
  851. extra_options=quant_config.extra_options,
  852. )
  853. elif isinstance(quant_config, DynamicQuantConfig):
  854. quantize_dynamic(
  855. model_input,
  856. model_output,
  857. weight_type=quant_config.weight_type,
  858. op_types_to_quantize=quant_config.op_types_to_quantize,
  859. nodes_to_quantize=quant_config.nodes_to_quantize,
  860. nodes_to_exclude=quant_config.nodes_to_exclude,
  861. per_channel=quant_config.per_channel,
  862. reduce_range=quant_config.reduce_range,
  863. use_external_data_format=quant_config.use_external_data_format,
  864. extra_options=quant_config.extra_options,
  865. )
  866. else:
  867. # training package doesn't has quantize_matmul_4bits, avoid global import
  868. from .matmul_nbits_quantizer import MatMulNBitsQuantizer, WeightOnlyQuantConfig # noqa: PLC0415
  869. if isinstance(quant_config, WeightOnlyQuantConfig):
  870. model = model_input if isinstance(model_input, onnx.ModelProto) else onnx.load(model_input)
  871. quant = MatMulNBitsQuantizer(model, algo_config=quant_config)
  872. quant.process()
  873. quant.model.save_model_to_file(model_output, True)
  874. else:
  875. raise TypeError(
  876. "Invalid quantization config type, it must be either StaticQuantConfig, "
  877. "DynamicQuantConfig, or WeightOnlyQuantConfig."
  878. )