qdq_quantizer.py 68 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. from __future__ import annotations
  7. import logging
  8. from dataclasses import dataclass
  9. from enum import Enum
  10. from typing import Any
  11. import numpy as np
  12. import onnx
  13. from onnx import TensorProto
  14. from onnx import onnx_pb as onnx_proto
  15. from .base_quantizer import BaseQuantizer, QuantizationParams
  16. from .calibrate import TensorData
  17. from .quant_utils import (
  18. DEQUANT_OP_NAME,
  19. ONNX_TYPE_TO_NP_TYPE,
  20. QUANT_OP_NAME,
  21. QuantizedValue,
  22. QuantizedValueType,
  23. __producer__,
  24. __version__,
  25. add_dequant_output_suffix,
  26. add_dequant_suffix,
  27. add_quant_input_suffix,
  28. add_quant_output_suffix,
  29. add_quant_suffix,
  30. compute_data_quant_params,
  31. compute_scale_zp,
  32. compute_scale_zp_float8,
  33. find_by_name,
  34. get_qmin_qmax_for_qType,
  35. ms_domain,
  36. normalize_axis,
  37. quantize_onnx_initializer,
  38. tensor_proto_to_array,
  39. )
  40. from .registry import CreateQDQQuantizer
  41. class QDQQuantTensorType(Enum):
  42. ACTIVATION = 0
  43. WEIGHT = 1
  44. BIAS = 2
  45. # Holds the name of the node input from which a node output will share the
  46. # same quantization param initializers (zero-point and scale initializers).
  47. # Ex: A Transpose node's output will use the same quant param initializers used at the input.
  48. @dataclass
  49. class QDQQuantParamProvider:
  50. input_name: str
  51. node_name: str
  52. # Holds information for tensors that have been marked for quantization by operator quantizers.
  53. # Does not hold information for bias tensors.
  54. class QDQTensorQuantInfo:
  55. def __init__(self, tensor_type=QDQQuantTensorType.ACTIVATION, quant_para_provider=None, axis=None, data_type=None):
  56. self.tensor_type = tensor_type
  57. self.quant_para_provider = quant_para_provider
  58. self.axis = axis
  59. self.is_shared = quant_para_provider is not None
  60. assert data_type is not None
  61. self.data_type = data_type
  62. # Holds information for bias tensors that have been marked for quantization by operator quantizers.
  63. @dataclass
  64. class QDQBiasQuantInfo:
  65. node_name: str
  66. input_name: str
  67. weight_name: str
  68. beta: float
  69. # Holds quantization parameter values (scale, zp) for a tensor.
  70. # A tensor typically has a one set of quantization parameters, unless the tensor is
  71. # at a "mixed-precision" boundary where the activation quantization type changes (e.g., from uint8 to uint16).
  72. @dataclass
  73. class QDQTensorQuantParams:
  74. original: QuantizationParams # Generated by producer node.
  75. converted: QuantizationParams | None # Converted type consumed by some (or all/none) consumer nodes.
  76. converted_recv_nodes: set[str] | None # The name of nodes that consume the converted type.
  77. def get_for_consumer(self, consumer_node_name) -> QuantizationParams:
  78. if self.converted is None: # Quantized value is not converted, return original
  79. return self.original
  80. if self.converted_recv_nodes is None: # All consumers receive the converted value
  81. return self.converted
  82. # Check if consumer node name is in the list of nodes that
  83. # receive the converted quantization value. If not, return the original value generated
  84. # by the tensor's producer.
  85. return self.converted if (consumer_node_name in self.converted_recv_nodes) else self.original
  86. # Holds scale and zero_point initializer TensorProtos.
  87. @dataclass
  88. class QDQScaleZpInitializers:
  89. scale: TensorProto
  90. zero_point: TensorProto
  91. # Holds all scale and zero-point initializers for a tensor.
  92. # A tensor typically has a one set of quantization parameters, unless the tensor is
  93. # at a "mixed-precision" boundary where the activation quantization type changes (e.g., from uint8 to uint16).
  94. @dataclass
  95. class QDQTensorScaleZpInitializers:
  96. original: QDQScaleZpInitializers
  97. converted: QDQScaleZpInitializers | None
  98. converted_recv_nodes: set[str] | None
  99. # Holds cached information of a tensor's quantized values (types, zp/scale initializer names, etc.).
  100. # A tensor typically has a one set of quantization parameters, unless the tensor is
  101. # at a "mixed-precision" boundary where the activation quantization type changes (e.g., from uint8 to uint16).
  102. @dataclass
  103. class QDQTensorQuantizedValue:
  104. original: QuantizedValue
  105. converted: QuantizedValue | None
  106. converted_recv_nodes: set[str] | None
  107. def get_for_consumer(self, consumer_node_name) -> QuantizedValue:
  108. if self.converted is None: # Quantized value is not converted, return original
  109. return self.original
  110. if self.converted_recv_nodes is None: # All consumers receive the converted value
  111. return self.converted
  112. # Check if consumer node name is in the list of nodes that
  113. # receive the converted quantization value. If not, return the original value generated
  114. # by the tensor's producer.
  115. return self.converted if (consumer_node_name in self.converted_recv_nodes) else self.original
  116. class QDQQuantizer(BaseQuantizer):
  117. def __init__(
  118. self,
  119. model,
  120. per_channel,
  121. reduce_range,
  122. weight_qType,
  123. activation_qType,
  124. tensors_range,
  125. nodes_to_quantize,
  126. nodes_to_exclude,
  127. op_types_to_quantize,
  128. extra_options=None,
  129. ):
  130. BaseQuantizer.__init__(
  131. self,
  132. model,
  133. per_channel,
  134. reduce_range,
  135. weight_qType,
  136. activation_qType,
  137. tensors_range,
  138. nodes_to_quantize,
  139. nodes_to_exclude,
  140. op_types_to_quantize,
  141. extra_options,
  142. )
  143. self.tensors_to_quantize: dict[str, QDQTensorQuantInfo] = {}
  144. self.bias_to_quantize: dict[str, QDQBiasQuantInfo] = {}
  145. self.nodes_to_remove = []
  146. # Specific op types to exclude qdq quantization for their outputs.
  147. # In TRT, it's not recommended to quantize outputs for weighted ops such as Conv, Matmul, Gemm
  148. # because those ops may be followed by nodes that require high resolution inputs.
  149. # Adding QDQ for those ops' output may end up with worse accuracy.
  150. # So, we don't recommend to add QDQ to node's output under such condition.
  151. self.op_types_to_exclude_output_quantization = extra_options.get("OpTypesToExcludeOutputQuantization", [])
  152. # We do quantization on Dequantizelinear's input to remove Quantizelinear for weight as an optimization.
  153. # In some cases, for example QDQ BERT model for TensorRT, QDQ should always appear as a pair.
  154. # Therefore, we need to disable this optimization and add qdq pair to weight.
  155. self.add_qdq_pair_to_weight = extra_options.get("AddQDQPairToWeight", False)
  156. # Some scenarios do not need the bias quantized. For example, in the case of Quantization Aware Training,
  157. # quantizing the bias is not needed. This is because in QAT, all model parameters are expected to be in
  158. # floating point format. To that end, we can use the FakeQuant operator for weights and activations that
  159. # can always have QDQ pairs (by using AddQDQPairToWeight). But for biases in a quantized model, we can't use
  160. # FakeQuant because it only ever appears before a DQ (since it is quantized as int32).
  161. self.quantize_bias = extra_options.get("QuantizeBias", True)
  162. # The default behavior is that multiple nodes can share a QDQ pair as their inputs.
  163. # In TRT, QDQ pair can`t be shared between nodes, so it will create dedicated QDQ pairs for each node.
  164. self.dedicated_qdq_pair = extra_options.get("DedicatedQDQPair", False)
  165. self.tensor_to_its_receiving_nodes: dict[str, list[onnx.NodeProto]] = {}
  166. # Maps a tensor to the DequantizeLinear node (in the original input model) that outputs the tensor.
  167. # Populated for input models with some pre-quantized weights (typically via a different tool).
  168. self.tensor_to_producing_dq: dict[str, onnx.NodeProto] = {}
  169. # Let user set channel axis for specific op type and it's effective only when per channel quantization is supported and per_channel is True.
  170. self.qdq_op_type_per_channel_support_to_axis = extra_options.get("QDQOpTypePerChannelSupportToAxis", {})
  171. self.qdq_op_domain = ms_domain if extra_options.get("UseQDQContribOps", False) else None
  172. # User can specify if removable activations, like Clip/Relu, should be kept in the graph.
  173. # Used in the QDQRemovableActivation class.
  174. self.qdq_keep_removable_activations = extra_options.get("QDQKeepRemovableActivations", False)
  175. # Let user disable adjustment of weight scales for bias inputs that are quantized to int32.
  176. self.qdq_disable_weight_adjust_for_int32_bias = extra_options.get("QDQDisableWeightAdjustForInt32Bias", False)
  177. # The ONNX spec did not support 16-bit Q/DQ ops before opset 21.
  178. # So, may have to override the Q/DQ op domain to 'com.microsoft' if the activation or weight types
  179. # are 16-bit or 4-bit integers.
  180. if self.opset_version < 21:
  181. opset21_types = (TensorProto.UINT16, TensorProto.INT16, TensorProto.UINT4, TensorProto.INT4)
  182. overrides_have_opset21_types = any(
  183. t.tensor_type in opset21_types for t in self.tensor_quant_override_qtypes
  184. )
  185. if not self.qdq_op_domain and (
  186. self.activation_qType in opset21_types
  187. or self.weight_qType in opset21_types
  188. or overrides_have_opset21_types
  189. ):
  190. logging.warning(
  191. "ONNX QuantizeLinear and DequantizeLinear operators do not support "
  192. "16-bit/4-bit integer quantization types prior to opset 21. "
  193. f"The domain of QuantizeLinear and DequantizeLinear operators will be set to '{ms_domain}' to "
  194. "enable support."
  195. )
  196. self.qdq_op_domain = ms_domain
  197. self.quantization_params = self.calc_graph_quant_params()
  198. self.initializer_quant_params: dict[str, QuantizationParams] = {}
  199. # Map of all original value names to quantized value names
  200. self.quantized_value_map = {}
  201. def _get_tensor_type(self, tensor_name):
  202. """
  203. Check if tensor can be quantized
  204. """
  205. weight = find_by_name(tensor_name, self.model.initializer())
  206. if weight is not None:
  207. return weight.data_type
  208. elif tensor_name in self.value_infos:
  209. vi = self.value_infos[tensor_name]
  210. if vi.type.HasField("tensor_type"):
  211. return vi.type.tensor_type.elem_type
  212. return None
  213. def _is_tensor_quantizable(self, tensor_name):
  214. """
  215. Check if tensor can be quantized
  216. """
  217. weight = find_by_name(tensor_name, self.model.initializer())
  218. if weight is not None:
  219. if weight.data_type in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16):
  220. return True
  221. elif tensor_name in self.value_infos:
  222. vi = self.value_infos[tensor_name]
  223. if vi.type.HasField("tensor_type") and vi.type.tensor_type.elem_type in (
  224. TensorProto.FLOAT,
  225. TensorProto.FLOAT16,
  226. ):
  227. return True
  228. else:
  229. logging.warning(
  230. f"failed to infer the type of tensor: {tensor_name}. Skip to quantize it. Please check if it is expected."
  231. )
  232. return False
  233. def __quantize_tensor(self, tensor_name, quant_sharing_provider=None, tensor_type=QDQQuantTensorType.ACTIVATION):
  234. """
  235. Adds a tensor to the list (actually a dict) of tensors to quantize. Called indirectly by op quantizers that
  236. want to quantize a tensor (i.e., "mark" a tensor for quantization).
  237. If quant_sharing_provider is not None, tensor with name tensor_name will be quantized with the same
  238. quantization parameters as the node input specified in quant_sharing_provider. Ex: A Tranpose node's output
  239. will typically use the same quantization parameter initializers used at the Transpose node's input.
  240. Args:
  241. tensor_name: name of the tensor to quantize
  242. quant_sharing_provider: name of the tensor and node that provides quantization parameter
  243. tensor_type: QDQQuantTensorType default ACTIVATION
  244. """
  245. if self._is_tensor_quantizable(tensor_name):
  246. if quant_sharing_provider:
  247. if not isinstance(quant_sharing_provider, QDQQuantParamProvider):
  248. raise TypeError(
  249. f"quant_sharing_provider must be of type QDQQuantParamProvider, not {type(quant_sharing_provider)}."
  250. )
  251. data_type = self._get_tensor_type(tensor_name)
  252. self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(
  253. tensor_type=tensor_type, quant_para_provider=quant_sharing_provider, data_type=data_type
  254. )
  255. elif tensor_name not in self.tensors_to_quantize:
  256. data_type = self._get_tensor_type(tensor_name)
  257. self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(tensor_type=tensor_type, data_type=data_type)
  258. def quantize_activation_tensor(self, tensor_name: str):
  259. """
  260. Adds a tensor to the list of tensors to quantize. Called by op quantizers that
  261. want to quantize a tensor (i.e., "mark" a tensor for quantization).
  262. Args:
  263. tensor_name: name of the tensor to quantize
  264. """
  265. return self.__quantize_tensor(tensor_name, None, QDQQuantTensorType.ACTIVATION)
  266. def quantize_output_same_as_input(self, output_name: str, input_name: str, node_name: str):
  267. """
  268. Adds a tensor to the list of tensors to quantize. Called by op quantizers that
  269. want to quantize an output tensor using the same quantization parameters as one of the node's inputs.
  270. Ex: A Tranpose node's output will typically use the same quantization parameter initializers used at
  271. the Transpose node's input.
  272. Args:
  273. output_name: name of the node output to quantize so that it uses the same quantization params as an input.
  274. input_name: name of the node input from which the output tensor will get its quantization params.
  275. node_name: name of the node that consumes `input_name`.
  276. """
  277. return self.__quantize_tensor(
  278. output_name, QDQQuantParamProvider(input_name, node_name), QDQQuantTensorType.ACTIVATION
  279. )
  280. def quantize_weight_tensor(self, tensor_name: str):
  281. """
  282. Adds a tensor to the list of weight tensors to quantize. Called by op quantizers that
  283. want to quantize a weight (i.e., "mark" a weight for quantization).
  284. Args:
  285. tensor_name: name of the weight to quantize
  286. """
  287. return self.__quantize_tensor(tensor_name, None, QDQQuantTensorType.WEIGHT)
  288. def quantize_weight_tensor_per_channel(self, tensor_name, axis):
  289. weight = find_by_name(tensor_name, self.model.initializer())
  290. if weight:
  291. if weight.data_type in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16):
  292. self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(
  293. tensor_type=QDQQuantTensorType.WEIGHT, axis=axis, data_type=weight.data_type
  294. )
  295. else:
  296. logging.warning(f"only support per-channel quantization on weight. Tensor: {tensor_name} is not quantized.")
  297. def _dup_initializer(self, initializer: onnx.TensorProto) -> onnx.TensorProto:
  298. """
  299. Duplicates an existing initializer and adds it to the model. Returns the new initializer.
  300. """
  301. name_suffix: int = self.model.get_largest_initializer_name_suffix(initializer.name) + 1
  302. new_initializer_name = f"{initializer.name}{name_suffix}"
  303. new_initializer = onnx.TensorProto()
  304. new_initializer.CopyFrom(initializer)
  305. new_initializer.name = new_initializer_name
  306. self.model.add_initializer(new_initializer)
  307. return new_initializer
  308. def quantize_bias_tensor(self, node_name, bias_name, input_name, weight_name, beta=1.0):
  309. """
  310. Adds a bias tensor to the list of bias tensors to quantize. Called by op quantizers that
  311. want to quantize a bias with bias_zero_point = 0 and bias_scale = input_scale * weight_scale * beta.
  312. TODO: Explain the reasoning for using this formula.
  313. Args:
  314. node_name: name of the node that consumes the bias, input, and weight tensors.
  315. bias_name: name of the bias tensor to quantize.
  316. input_name: name of the input tensor whose scale is used to compute the bias's scale.
  317. weight_name: name of the weight tensor whose scale is used to compute the bias's scale.
  318. beta: Multiplier used to compute the bias's scale.
  319. """
  320. # If the user provided quantization overrides for this tensor, treat it as a regular weight.
  321. if self.tensor_quant_overrides.get(bias_name):
  322. logging.info(
  323. f"Quantizing bias tensor '{bias_name}' as a weight due to the presence of user-specified overrides"
  324. )
  325. is_per_channel, axis = self.is_tensor_per_channel(bias_name, default_axis=0)
  326. if is_per_channel:
  327. self.quantize_weight_tensor_per_channel(bias_name, axis)
  328. else:
  329. self.quantize_weight_tensor(bias_name)
  330. return
  331. bias_initializer = find_by_name(bias_name, self.model.initializer())
  332. if bias_initializer is None:
  333. logging.warning(f"Expected bias '{bias_name}' to be an initializer")
  334. return
  335. if bias_initializer.data_type not in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16):
  336. logging.info(f"Expected bias '{bias_name}' to be an floating-point initializer")
  337. return
  338. actual_bias_name = bias_name
  339. if bias_name in self.bias_to_quantize:
  340. # This bias input is consumed by two different nodes. We need to duplicate the bias so that
  341. # each node has its own bias input. This is necessary because the bias's scale is computed
  342. # from the node's other input scales.
  343. new_bias_initializer = self._dup_initializer(bias_initializer)
  344. actual_bias_name = new_bias_initializer.name
  345. # Replace this node's bias input
  346. self.model.replace_input_of_nodes(bias_name, actual_bias_name, {node_name})
  347. logging.info(f"Created a copy of bias input '{bias_name}' called '{actual_bias_name}'")
  348. # Add this to our list of biases to quantize.
  349. self.bias_to_quantize[actual_bias_name] = QDQBiasQuantInfo(node_name, input_name, weight_name, beta)
  350. def _adjust_weight_scale_for_int32_bias(
  351. self,
  352. input_scale: np.ndarray,
  353. weight_scale: np.ndarray,
  354. weight_name: str,
  355. bias_tp: onnx.TensorProto,
  356. is_per_channel: bool,
  357. ) -> tuple[bool, np.ndarray | None]:
  358. """
  359. Checks if the bias scale (input_scale * weight_scale) that we intend to use is too small.
  360. A bias scale that is too small leads to quantized bias values that fall outside the range of a int32 and have to
  361. be clipped, which decreases accuracy. If this function detects such a scenario, the weight_scale value will be
  362. increased to prevent this from happening.
  363. Although the adjustment method and amount differs, the idea to adjust the weight's scale came from the following
  364. reference:
  365. https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/optimize/quantization_utils.cc#L252
  366. :param input_scale: The input's scale.
  367. :param weight_scale: The weight scale to potentially adjust.
  368. :param weight_name: The weight initializer's name. Used for logging.
  369. :param bias_tp: The bias ONNX initializer.
  370. :param is_per_channel: True if the bias and weight are quantized per-channel.
  371. :return: A tuple with a bool indicating if the weight's scale was adjusted and the new weight scale.
  372. """
  373. if not weight_scale.size:
  374. return False, None
  375. bias_float_data = tensor_proto_to_array(bias_tp)
  376. int32_info = np.iinfo(np.int32)
  377. multiplicative_epsilon = 1.0001
  378. qrange = np.array(int32_info.max, dtype=np.float64) - np.array(int32_info.min + 1, dtype=np.float64)
  379. weight_scale_dtype = weight_scale.dtype
  380. updated_an_elem = False
  381. if not is_per_channel:
  382. rmin = np.minimum(bias_float_data.min(), np.array(0, dtype=np.float64))
  383. rmax = np.maximum(bias_float_data.max(), np.array(0, dtype=np.float64))
  384. absmax = np.maximum(np.abs(rmin), np.abs(rmax))
  385. bias_smallest_valid_scale = multiplicative_epsilon * (2.0 * absmax) / qrange
  386. input_scale_fp64 = np.array(input_scale.item(), dtype=np.float64)
  387. weight_scale_fp64 = np.array(weight_scale.item(), dtype=np.float64)
  388. bias_candidate_scale = input_scale_fp64 * weight_scale_fp64
  389. if (bias_candidate_scale < bias_smallest_valid_scale) and (bias_candidate_scale > 0.0):
  390. # The candidate bias scale would be too small, so increase the weight_scale by the necessary ratio.
  391. ratio = bias_smallest_valid_scale / bias_candidate_scale
  392. logging.info(
  393. f"Increasing scale for weight `{weight_name}` by the ratio {ratio} to "
  394. f"ensure bias input `{bias_tp.name}` has a valid scale."
  395. )
  396. new_scale = weight_scale_fp64 * ratio
  397. weight_scale = new_scale.astype(weight_scale_dtype)
  398. updated_an_elem = True
  399. elif weight_scale.shape and len(weight_scale.shape) == 1:
  400. # per-channel case
  401. num_elems = weight_scale.shape[0]
  402. for i in range(num_elems):
  403. bias_rmax = np.abs(bias_float_data[i])
  404. bias_smallest_valid_scale = multiplicative_epsilon * (2.0 * bias_rmax) / qrange
  405. input_scale_fp64 = np.array(input_scale.item(), dtype=np.float64)
  406. weight_scale_fp64 = np.array(weight_scale[i].item(), dtype=np.float64)
  407. bias_candidate_scale = input_scale_fp64 * weight_scale_fp64
  408. if (bias_candidate_scale < bias_smallest_valid_scale) and (bias_candidate_scale > 0.0):
  409. # The candidate bias scale would be too small, so increase the weight_scale by the necessary ratio.
  410. ratio = bias_smallest_valid_scale / bias_candidate_scale
  411. logging.info(
  412. f"Increased scale[{i}] for weight `{weight_name}` by ratio {ratio} "
  413. f"to ensure bias input `{bias_tp.name}` has a valid scale."
  414. )
  415. new_scale = weight_scale_fp64 * ratio
  416. weight_scale[i] = new_scale.astype(weight_scale_dtype)
  417. updated_an_elem = True
  418. return updated_an_elem, weight_scale
  419. def _adjust_weight_quant_params_for_bias_tensors(self):
  420. """
  421. Iterates through all bias inputs that should be quantized to int32. If the intended
  422. bias scale (equal to input_scale * weight_scale) is too small, this function will increase
  423. the associated weight's scale to ensure the bias does not overflow the int32 range when quantized.
  424. """
  425. if self.qdq_disable_weight_adjust_for_int32_bias:
  426. # User passed an extra_option to disable this adjustment.
  427. return
  428. for bias_name, bias_info in self.bias_to_quantize.items():
  429. if (
  430. bias_info.input_name not in self.quantization_params
  431. or bias_info.input_name not in self.tensors_to_quantize
  432. or bias_info.weight_name not in self.initializer_quant_params
  433. ):
  434. continue
  435. # Get the associated input's scale.
  436. input_qparams = self.quantization_params[bias_info.input_name].get_for_consumer(bias_info.node_name)
  437. input_info = self.tensors_to_quantize[bias_info.input_name]
  438. input_scale = np.asarray(
  439. input_qparams["scale"], dtype=onnx.helper.tensor_dtype_to_np_dtype(input_info.data_type)
  440. )
  441. weight_quant_params = self.initializer_quant_params[bias_info.weight_name]
  442. weight_quant_type = weight_quant_params["quant_type"]
  443. if weight_quant_type not in (onnx.TensorProto.INT8, onnx.TensorProto.INT16):
  444. continue
  445. weight_zero_point: np.ndarray = weight_quant_params["zero_point"]
  446. if weight_zero_point.any():
  447. # Skip if zero_point(s) are not all zero (i.e., symmetric quant)
  448. continue
  449. weight_scale: np.ndarray = weight_quant_params["scale"]
  450. is_per_channel = weight_quant_params.get("axis", None) is not None
  451. # Get adjusted weight scales.
  452. did_update_weight_scale, new_weight_scale = self._adjust_weight_scale_for_int32_bias(
  453. input_scale,
  454. weight_scale,
  455. bias_info.weight_name,
  456. find_by_name(bias_name, self.model.initializer()),
  457. is_per_channel,
  458. )
  459. if did_update_weight_scale:
  460. weight_quant_params["scale"] = new_weight_scale
  461. def remove_node(self, node):
  462. self.nodes_to_remove.append(node)
  463. def remove_nodes(self):
  464. self.model.remove_nodes(self.nodes_to_remove)
  465. def quantize_model(self):
  466. for node in self.model.nodes():
  467. if self.should_quantize_node(node):
  468. op_quantizer = CreateQDQQuantizer(self, node)
  469. op_quantizer.quantize()
  470. for tensor_name in node.input:
  471. if tensor_name not in self.tensor_to_its_receiving_nodes:
  472. self.tensor_to_its_receiving_nodes[tensor_name] = []
  473. self.tensor_to_its_receiving_nodes[tensor_name].append(node)
  474. if node.op_type == DEQUANT_OP_NAME:
  475. for tensor_name in node.output:
  476. self.tensor_to_producing_dq[tensor_name] = node
  477. self.initializer_quant_params = self._calc_initializer_quant_params()
  478. self._adjust_weight_quant_params_for_bias_tensors()
  479. self._quantize_normal_tensors()
  480. self._quantize_sharing_param_tensors()
  481. if self.quantize_bias:
  482. self._quantize_bias_tensors()
  483. self.remove_nodes()
  484. if not self.add_qdq_pair_to_weight:
  485. self.model.clean_initializers()
  486. self.model.model.producer_name = __producer__
  487. self.model.model.producer_version = __version__
  488. if self.qdq_op_domain == ms_domain:
  489. self.model.set_opset_import(ms_domain, 1)
  490. return self.model.model
  491. def try_replacing_upstream_output(self, upstream_output_name, output_name):
  492. if (
  493. output_name in self.quantization_params
  494. and self.quantization_params[output_name].converted is None
  495. and self.quantization_params[upstream_output_name].converted is None
  496. and len(self.model.input_name_to_nodes()[upstream_output_name]) == 1
  497. and not self.model.is_graph_output(upstream_output_name)
  498. and not self.model.is_graph_input(upstream_output_name)
  499. ):
  500. self.model.replace_output_of_all_nodes(upstream_output_name, output_name)
  501. if upstream_output_name in self.tensors_to_quantize:
  502. del self.tensors_to_quantize[upstream_output_name]
  503. return True
  504. return False
  505. def _create_q_node(
  506. self,
  507. q_input: str,
  508. q_output: str,
  509. quant_node_name: str,
  510. scale_name: str,
  511. zp_name: str,
  512. axis: int | None = None,
  513. ):
  514. """
  515. Creates a QuantizeLinear node and adds it to the model.
  516. """
  517. qlinear_node = onnx.helper.make_node(
  518. QUANT_OP_NAME,
  519. [q_input, scale_name, zp_name],
  520. [q_output],
  521. quant_node_name,
  522. axis=axis,
  523. domain=self.qdq_op_domain,
  524. )
  525. self.model.add_nodes([qlinear_node])
  526. def _create_dq_node(
  527. self,
  528. dq_input: str,
  529. dq_output: str,
  530. dequant_node_name: str,
  531. scale_name: str,
  532. zp_name: str,
  533. axis: int | None = None,
  534. ):
  535. """
  536. Creates a DequantizeLinear node and adds it to the model.
  537. """
  538. dequant_node = onnx.helper.make_node(
  539. DEQUANT_OP_NAME,
  540. [dq_input, scale_name, zp_name],
  541. [dq_output],
  542. dequant_node_name,
  543. axis=axis,
  544. domain=self.qdq_op_domain,
  545. )
  546. self.model.add_nodes([dequant_node])
  547. def _create_qdq_nodes(
  548. self, q_input, q_output, quant_node_name, dq_input, dq_output, dequant_node_name, scale_name, zp_name, axis=None
  549. ):
  550. qlinear_node = onnx.helper.make_node(
  551. QUANT_OP_NAME,
  552. [q_input, scale_name, zp_name],
  553. [q_output],
  554. quant_node_name,
  555. axis=axis,
  556. domain=self.qdq_op_domain,
  557. )
  558. dequant_node = onnx.helper.make_node(
  559. DEQUANT_OP_NAME,
  560. [dq_input, scale_name, zp_name],
  561. [dq_output],
  562. dequant_node_name,
  563. axis=axis,
  564. domain=self.qdq_op_domain,
  565. )
  566. self.model.add_nodes([qlinear_node, dequant_node])
  567. def _add_qdq_nodes_for_initializer(self, weight_proto: onnx.TensorProto):
  568. """
  569. Adds Q/DQ nodes for an initializer. If `self.add_qdq_pair_to_weight` is true, creates
  570. the sequence (weight_f32 -> Q -> DQ -> ). Otherwise, this function quantizes the initializer
  571. and adds the sequence (weight_quant -> DQ ->).
  572. """
  573. weight_name = weight_proto.name
  574. if weight_name in self.quantized_value_map:
  575. return
  576. quant_params: QuantizationParams = self.initializer_quant_params[weight_name]
  577. axis: int = quant_params.get("axis")
  578. scale_zp_initializers = self._make_scale_zp_initializers(weight_name, quant_params)
  579. q_weight_name: str | None = None
  580. weight_dequant_output = add_dequant_output_suffix(weight_name)
  581. self.model.replace_input_of_all_nodes(weight_name, weight_dequant_output)
  582. if self.add_qdq_pair_to_weight:
  583. # Don't actually quantize the weight. Instead, keep floating-point weight and create the node
  584. # sequence (weight_f32 -> Q -> DQ -> weight_dequant)
  585. weight_quant_output = add_quant_output_suffix(weight_name)
  586. self._create_qdq_nodes(
  587. weight_name,
  588. weight_quant_output,
  589. add_quant_suffix(weight_name),
  590. weight_quant_output,
  591. weight_dequant_output,
  592. add_dequant_suffix(weight_name),
  593. scale_zp_initializers.scale.name,
  594. scale_zp_initializers.zero_point.name,
  595. axis,
  596. )
  597. else:
  598. # Quantize the weight and create the node sequence:
  599. # (weight_quantized -> DQ -> weight_dequant)
  600. quant_weight = quantize_onnx_initializer(
  601. weight_proto,
  602. quant_params["quant_type"],
  603. quant_params["zero_point"],
  604. quant_params["scale"],
  605. axis,
  606. )
  607. self.model.add_initializer(quant_weight)
  608. q_weight_name = quant_weight.name
  609. dequant_node = onnx.helper.make_node(
  610. DEQUANT_OP_NAME,
  611. [quant_weight.name, scale_zp_initializers.scale.name, scale_zp_initializers.zero_point.name],
  612. [weight_dequant_output],
  613. add_dequant_suffix(weight_name),
  614. axis=axis,
  615. domain=self.qdq_op_domain,
  616. )
  617. self.model.add_node(dequant_node)
  618. # Log entry for this quantized weight
  619. quantized_value = QuantizedValue(
  620. weight_name,
  621. q_weight_name,
  622. scale_zp_initializers.scale.name,
  623. scale_zp_initializers.zero_point.name,
  624. QuantizedValueType.Initializer,
  625. axis=axis,
  626. )
  627. self.quantized_value_map[weight_name] = QDQTensorQuantizedValue(quantized_value, None, None)
  628. def _add_qdq_pair_for_activation(self, tensor_name, scale_name, zp_name, data_type=None):
  629. if (
  630. self.dedicated_qdq_pair
  631. and tensor_name in self.tensor_to_its_receiving_nodes
  632. and len(self.tensor_to_its_receiving_nodes[tensor_name]) > 1
  633. ):
  634. num_dedicated_qdq_pair = len(self.tensor_to_its_receiving_nodes[tensor_name])
  635. for i in range(num_dedicated_qdq_pair):
  636. postfix = f"_{i + 1}"
  637. tensor_name_quant_output_postfix = add_quant_output_suffix(tensor_name) + postfix
  638. tensor_name_dequant_output_postfix = add_dequant_output_suffix(tensor_name) + postfix
  639. quant_node_name_postfix = add_quant_suffix(tensor_name) + postfix
  640. dequant_node_name_postfix = add_dequant_suffix(tensor_name) + postfix
  641. self._create_qdq_nodes(
  642. tensor_name,
  643. tensor_name_quant_output_postfix,
  644. quant_node_name_postfix,
  645. tensor_name_quant_output_postfix,
  646. tensor_name_dequant_output_postfix,
  647. dequant_node_name_postfix,
  648. scale_name,
  649. zp_name,
  650. )
  651. node = self.tensor_to_its_receiving_nodes[tensor_name][i]
  652. self.model.replace_node_input(node, tensor_name, tensor_name_dequant_output_postfix)
  653. if i == 0:
  654. quantized_value = QuantizedValue(
  655. tensor_name,
  656. tensor_name_dequant_output_postfix,
  657. scale_name,
  658. zp_name,
  659. QuantizedValueType.Input,
  660. scale_type=data_type,
  661. )
  662. self.quantized_value_map[tensor_name] = QDQTensorQuantizedValue(quantized_value, None, None)
  663. else:
  664. q_input = tensor_name
  665. dq_output = add_dequant_output_suffix(tensor_name)
  666. if self.model.is_graph_output(tensor_name):
  667. q_input = add_quant_input_suffix(tensor_name)
  668. dq_output = tensor_name
  669. self.model.replace_output_of_all_nodes(tensor_name, q_input)
  670. else:
  671. self.model.replace_input_of_all_nodes(tensor_name, dq_output)
  672. self._create_qdq_nodes(
  673. q_input,
  674. add_quant_output_suffix(tensor_name),
  675. add_quant_suffix(tensor_name),
  676. add_quant_output_suffix(tensor_name),
  677. dq_output,
  678. add_dequant_suffix(tensor_name),
  679. scale_name,
  680. zp_name,
  681. )
  682. quantized_value = QuantizedValue(
  683. tensor_name,
  684. dq_output,
  685. scale_name,
  686. zp_name,
  687. QuantizedValueType.Input,
  688. scale_type=data_type,
  689. )
  690. self.quantized_value_map[tensor_name] = QDQTensorQuantizedValue(quantized_value, None, None)
def _add_qdq_ops_for_converted_activation(
    self,
    tensor_name,
    first_scale_name,
    first_zp_name,
    scale_data_type,
    convert_scale_name,
    convert_zp_name,
    convert_recv_nodes,
):
    """
    Adds Q and DQ ops to a tensor whose quantized data type is converted. That is, some consumers may use the
    original data type from the producer, while other consumers use the converted data type.
    This is generally done by adding a sequence of ops that convert from one data type (e.g., uint8) to another
    (e.g., uint16):

        T_float ---> Quant(to u8) ---> Convert(to u16) ---> Dequant(to float) ---> T_float'

    where Convert(to u16) is equivalent to: ---> Dequant(to float) ---> Quant(to u16) --->

    This function handles the following scenarios:

    1) Tensor T is not a graph output; all consumers use the converted type

       <Producer> ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 ---> <Consumers>

    2) Tensor T is not a graph output; some consumers use the original type, others use the converted type

       <Producer> ---> Q1 -+-> DQ1 ---> <Consumers of original type>
                           |
                           +-> DQ1' ---> Q2 ---> DQ2 ---> <Consumers of converted type>

    3) Tensor T is a graph output; all consumers use the converted type

       <Producer> ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 -+-> <Consumers>
                                                     |
                                                     +-> <Graph output>

    4) Tensor T is a graph output; some consumers use the original type, others use the converted type

       <Producer> ---> Q1 -+-> DQ1 -+-> <Consumers of original type>
                           |        |
                           |        +-> <Graph output>
                           |
                           +-> DQ1' ---> Q2 ---> DQ2 ---> <Consumers of converted type>

    5) Tensor T is a graph output that is not consumed by any other nodes.

       <Producer> ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 ---> <Graph output>

    :param tensor_name: The float tensor to quantize.
    :param first_scale_name: Scale initializer for Q1/DQ1 (and the DQ1' clone).
    :param first_zp_name: Zero-point initializer for Q1/DQ1 (and the DQ1' clone).
    :param scale_data_type: Stored as ``scale_type`` on both recorded QuantizedValues.
    :param convert_scale_name: Scale initializer for Q2/DQ2.
    :param convert_zp_name: Zero-point initializer for Q2/DQ2.
    :param convert_recv_nodes: Set of consumer node names that receive the converted type, or None
        to mean "every consumer".

    Raises:
        ValueError: if ``self.dedicated_qdq_pair`` is enabled and the tensor has more than one consumer.

    Side effects: adds the nodes above to ``self.model``, rewires consumers (and the producer, for graph
    outputs), and stores a QDQTensorQuantizedValue holding both the original and converted values in
    ``self.quantized_value_map[tensor_name]``.
    """
    # Names of every node that currently consumes the tensor.
    tensor_recv_nodes = {node.name for node in self.tensor_to_its_receiving_nodes.get(tensor_name, [])}
    if (
        self.dedicated_qdq_pair
        and tensor_name in self.tensor_to_its_receiving_nodes
        and len(self.tensor_to_its_receiving_nodes[tensor_name]) > 1
    ):
        # TODO: Add support for dedicated_qdq_pair if/when needed.
        raise ValueError(
            "Do not currently support converted quant_types in TensorQuantOverrides when the `dedicated_qdq_pair` extra_option is enabled"
        )
    # Determine which nodes consume the original quantized type and which nodes
    # consume the converted quantized type.
    original_recv_nodes = tensor_recv_nodes
    if convert_recv_nodes is None:  # In this case, all consumers receive the converted type.
        convert_recv_nodes = tensor_recv_nodes
        original_recv_nodes = set()
    else:
        original_recv_nodes = original_recv_nodes - convert_recv_nodes
    # NOTE(review): this compares set sizes only. If convert_recv_nodes names a node that does not
    # consume this tensor, the sizes can match even though some consumers still use the original type.
    # TODO confirm that recv_nodes are validated against the real consumers before reaching this point.
    all_use_converted = len(convert_recv_nodes) == len(tensor_recv_nodes)
    is_graph_output = self.model.is_graph_output(tensor_name)
    # Create first Q op.
    # For a graph output, the producer is renamed to feed Q1 so that a DQ below can re-emit the
    # original tensor name and the graph output keeps its name.
    first_q_input = tensor_name
    if is_graph_output:
        first_q_input = add_quant_input_suffix(tensor_name)
        self.model.replace_output_of_all_nodes(tensor_name, first_q_input)
    first_q_output = add_quant_output_suffix(tensor_name)
    self._create_q_node(
        first_q_input, first_q_output, add_quant_suffix(tensor_name), first_scale_name, first_zp_name
    )
    # Create first DQ op.
    # DQ1 takes over the graph-output name only when it also serves original-type consumers (case 4).
    first_dq_output = add_dequant_output_suffix(tensor_name)
    if is_graph_output and not all_use_converted:
        first_dq_output = tensor_name
    # Consumers of the original type are rewired to DQ1 unless DQ1 already produces the original name.
    if original_recv_nodes and first_dq_output != tensor_name:
        self.model.replace_input_of_nodes(tensor_name, first_dq_output, original_recv_nodes)
    self._create_dq_node(
        first_q_output, first_dq_output, add_dequant_suffix(tensor_name), first_scale_name, first_zp_name
    )
    # Create parallel clone of first DQ op if _not all_ consumers use the converted type.
    # --> DQ1' --> Q2 --> DQ2 --> <Consumers of converted type>
    #
    # This DQ clone would only have one consumer Q node (Q2) and could be potentially fused with
    # it by some EPs (e.g., QNN) without breaking other "node units".
    # Ex QNN fusion:
    # --> Convert (fused) --> DQ2 --> <Consumers of converted type>
    second_q_input = first_dq_output
    if not all_use_converted:
        second_q_input = add_quant_input_suffix(f"{tensor_name}_convert")
        self._create_dq_node(
            first_q_output,
            second_q_input,
            add_dequant_suffix(f"{tensor_name}_convert_clone"),
            first_scale_name,
            first_zp_name,
        )
    # Create second Q op.
    second_q_output = add_quant_output_suffix(f"{tensor_name}_convert")
    self._create_q_node(
        second_q_input,
        second_q_output,
        add_quant_suffix(f"{tensor_name}_convert"),
        convert_scale_name,
        convert_zp_name,
    )
    # Create second DQ op.
    # DQ2 takes over the graph-output name when every consumer uses the converted type (cases 3 and 5).
    second_dq_output = add_dequant_output_suffix(f"{tensor_name}_convert")
    if is_graph_output and all_use_converted:
        second_dq_output = tensor_name
    if convert_recv_nodes and second_dq_output != tensor_name:
        self.model.replace_input_of_nodes(tensor_name, second_dq_output, convert_recv_nodes)
    self._create_dq_node(
        second_q_output,
        second_dq_output,
        add_dequant_suffix(f"{tensor_name}_convert"),
        convert_scale_name,
        convert_zp_name,
    )
    # Store in quantized_value_map: the original value for original-type consumers, and the converted
    # value for the nodes in convert_recv_nodes.
    original_quantized_value = QuantizedValue(
        tensor_name,
        first_dq_output,
        first_scale_name,
        first_zp_name,
        QuantizedValueType.Input,
        scale_type=scale_data_type,
    )
    converted_quantized_value = QuantizedValue(
        tensor_name,
        second_dq_output,
        convert_scale_name,
        convert_zp_name,
        QuantizedValueType.Input,
        scale_type=scale_data_type,
    )
    self.quantized_value_map[tensor_name] = QDQTensorQuantizedValue(
        original_quantized_value, converted_quantized_value, convert_recv_nodes
    )
  824. def _quantize_normal_tensors(self):
  825. """
  826. Adds Q/DQ ops to tensors (activations and weights) that have been marked for quantization by op quantizers.
  827. """
  828. for tensor_name, tensor_info in self.tensors_to_quantize.copy().items():
  829. if tensor_name in self.quantized_value_map:
  830. continue
  831. if not tensor_info.is_shared:
  832. # Quantize the input
  833. initializer = find_by_name(tensor_name, self.model.initializer())
  834. if initializer:
  835. self._add_qdq_nodes_for_initializer(initializer)
  836. else:
  837. # Check if this tensor is already a dequantized value. If so, skip it.
  838. # This happens if the original input model already has some pre-quantized weights
  839. # generated by a different tool.
  840. # Ex: (quantized_weight -> DequantizeLinear -> this_tensor)
  841. if tensor_name in self.tensor_to_producing_dq:
  842. del self.tensors_to_quantize[tensor_name]
  843. continue
  844. tensor_qparam_initializers = self._make_tensor_scale_zp_initializers(tensor_name)
  845. if not tensor_qparam_initializers:
  846. raise ValueError(
  847. f"Quantization parameters are not specified for param {tensor_name}. "
  848. "In static mode quantization params for inputs and outputs of nodes to be quantized are required."
  849. )
  850. if tensor_qparam_initializers.converted is None:
  851. # Normal case: <producer> --> Q --> DQ --> <consumers>
  852. self._add_qdq_pair_for_activation(
  853. tensor_name,
  854. tensor_qparam_initializers.original.scale.name,
  855. tensor_qparam_initializers.original.zero_point.name,
  856. data_type=tensor_info.data_type,
  857. )
  858. else:
  859. # Conversion case: <producer> ---> Q1 -+-> DQ1 --> <consumers of original type>
  860. # |
  861. # +-> DQ1' --> Q2 --> DQ2 --> <consumers of converted type>
  862. assert tensor_info.data_type == tensor_qparam_initializers.original.scale.data_type
  863. self._add_qdq_ops_for_converted_activation(
  864. tensor_name,
  865. tensor_qparam_initializers.original.scale.name,
  866. tensor_qparam_initializers.original.zero_point.name,
  867. tensor_info.data_type,
  868. tensor_qparam_initializers.converted.scale.name,
  869. tensor_qparam_initializers.converted.zero_point.name,
  870. tensor_qparam_initializers.converted_recv_nodes,
  871. )
  872. del self.tensors_to_quantize[tensor_name]
  873. def _quantize_sharing_param_tensors(self):
  874. """
  875. Adds Q/DQ ops to tensors that have been marked for quantization by op quantizers.
  876. Only operates on tensors that want to use the quantization parameter initializers from an upstream tensor.
  877. For example, a Transpose node's output tensor will typically want to use the same quantization parameter
  878. initializers as the Transpose node's input.
  879. """
  880. while self.tensors_to_quantize:
  881. for tensor_name, tensor_info in self.tensors_to_quantize.copy().items():
  882. quant_provider = tensor_info.quant_para_provider
  883. if quant_provider and quant_provider.input_name in self.quantized_value_map:
  884. del self.tensors_to_quantize[tensor_name]
  885. quantized_value = self.quantized_value_map[quant_provider.input_name].get_for_consumer(
  886. quant_provider.node_name
  887. )
  888. if self.is_input_a_initializer(tensor_name):
  889. raise ValueError("Quantization parameter shared mode is not supported for weight yet")
  890. if tensor_name in self.tensor_to_producing_dq:
  891. raise ValueError(
  892. f"Quantization parameter sharing is invalid for tensor {tensor_name} "
  893. "because it has already been quantized"
  894. )
  895. # Need to check if this tensor's quant_type is converted for some consumers.
  896. # If so, create new scale/zp initializers for these consumers.
  897. converted_qparam_inits = None
  898. converted_recv_nodes = None
  899. if tensor_name in self.quantization_params:
  900. tensor_params = self.quantization_params[tensor_name]
  901. if tensor_params.converted:
  902. converted_qparam_inits = self._make_scale_zp_initializers(
  903. tensor_name, tensor_params.converted, "_convert"
  904. )
  905. converted_recv_nodes = tensor_params.converted_recv_nodes
  906. if converted_qparam_inits is None:
  907. # Normal case: <producer> --> Q_shared --> DQ_shared --> <consumers>
  908. self._add_qdq_pair_for_activation(
  909. tensor_name, quantized_value.scale_name, quantized_value.zp_name
  910. )
  911. else:
  912. # Conversion case: <producer> ---> Q_shared -+-> DQ_shared --> <consumers of original type>
  913. # |
  914. # +-> DQ_shared' --> Q2 --> DQ2 --> <consumers of converted type>
  915. self._add_qdq_ops_for_converted_activation(
  916. tensor_name,
  917. quantized_value.scale_name,
  918. quantized_value.zp_name,
  919. converted_qparam_inits.scale.data_type,
  920. converted_qparam_inits.scale.name,
  921. converted_qparam_inits.zero_point.name,
  922. converted_recv_nodes,
  923. )
  924. def _quantize_bias_tensors(self):
  925. """
  926. Adds DQ ops (or Cast) for bias tensors that have been marked for quantization by op quantizers.
  927. """
  928. for bias_name, bias_info in self.bias_to_quantize.items():
  929. if bias_name in self.quantized_value_map:
  930. continue
  931. # Quantize the input
  932. self.quantize_bias_static(bias_name, bias_info)
  933. init = find_by_name(bias_name, self.model.initializer())
  934. self.model.remove_initializer(init)
  935. quant_value = self.quantized_value_map[bias_name].original
  936. if quant_value.node_type == "Cast":
  937. # simple cast to float 16 and not DequantizeLinear
  938. # cublasLtMatmul only supports (b)float16, float bias.
  939. if not isinstance(init.data_type, int):
  940. raise TypeError(f"Unexpected type {type(init.data_type)} for input={bias_info.input_name!r}")
  941. node_name = add_dequant_suffix(bias_name)
  942. dequant_node = onnx.helper.make_node(
  943. "Cast",
  944. [quant_value.q_name],
  945. [bias_name],
  946. name=node_name,
  947. to=init.data_type,
  948. )
  949. elif quant_value.node_type in (None, "DequantizeLinear"):
  950. if quant_value.node_qtype in {
  951. onnx.TensorProto.FLOAT16,
  952. onnx.TensorProto.BFLOAT16,
  953. onnx.TensorProto.FLOAT,
  954. }:
  955. raise RuntimeError(f"Unexpected quantize type {quant_value.node_qtype} for DequantizeLinear.")
  956. inputs = [quant_value.q_name, quant_value.scale_name, quant_value.zp_name]
  957. node_name = add_dequant_suffix(bias_name)
  958. if quant_value.axis is not None:
  959. dequant_node = onnx.helper.make_node(
  960. "DequantizeLinear",
  961. inputs,
  962. [bias_name],
  963. node_name,
  964. axis=quant_value.axis,
  965. domain=self.qdq_op_domain,
  966. )
  967. else:
  968. dequant_node = onnx.helper.make_node(
  969. "DequantizeLinear",
  970. inputs,
  971. [bias_name],
  972. node_name,
  973. domain=self.qdq_op_domain,
  974. )
  975. else:
  976. raise RuntimeError(f"Unexpected operator type {quant_value.node_type!r}.")
  977. self.model.add_node(dequant_node)
  978. def is_tensor_quantized(self, tensor_name: str):
  979. return tensor_name in self.tensors_to_quantize or tensor_name in self.bias_to_quantize
  980. def is_tensor_per_channel(
  981. self,
  982. tensor_name: str,
  983. default_axis: int,
  984. op_type: str | None = None,
  985. ) -> tuple[bool, int | None]:
  986. """
  987. Checks if a given tensor is configured to be quantized per-channel. If so, also returns the channel axis.
  988. ORT only supports per-channel quantization on static weights (i.e., ONNX initializers). If the user did not provide
  989. tensor quantization overrides for this tensor, then the value of self.per_channel determines if the weight
  990. is to be quantized per-channel.
  991. Params:
  992. tensor_name: The name of the tensor to check.
  993. default_axis: The default channel axis. This method checks if the normalized axis is within bounds.
  994. Can be overridden via the extra_options 'QDQOpTypePerChannelSupportToAxis'
  995. and 'TensorQuantOverrides'.
  996. op_type: Optional, defaults to None. The operator type that is the only consumer of this weight.
  997. Used to access the extra option 'QDQOpTypePerChannelSupportToAxis'.
  998. Returns:
  999. A tuple (is_per_channel, axis) in which the first element indicates whether the tensor is
  1000. quantized per-channel and the second element is the channel axis.
  1001. The returned axis is only None if the tensor is not per-channel or the axis is out of bounds.
  1002. """
  1003. weight_initializer = self.initializers.get(tensor_name)
  1004. if weight_initializer is None:
  1005. return False, None # Only support per-channel weights
  1006. if self.tensor_quant_overrides.has_per_tensor_overrides(tensor_name):
  1007. return False, None # User provided per-tensor overrides for this initializer
  1008. has_per_chan_overrides = self.tensor_quant_overrides.has_per_channel_overrides(tensor_name)
  1009. if not self.per_channel and not has_per_chan_overrides:
  1010. return False, None # global self.per_channel is off and user did not provide per-channel overrides.
  1011. axis = self.qdq_op_type_per_channel_support_to_axis.get(op_type, default_axis) if op_type else default_axis
  1012. if has_per_chan_overrides:
  1013. per_chan_overrides = self.tensor_quant_overrides.get_per_channel_overrides(tensor_name)
  1014. axis = per_chan_overrides[0]["axis"] # Prefer axis from user-specified tensor-level overrides if available
  1015. weight_rank = len(weight_initializer.dims)
  1016. axis_valid, axis = normalize_axis(axis, weight_rank)
  1017. if not axis_valid:
  1018. logging.warning(f"Axis {axis} is out-of-range for weight '{tensor_name}' with rank {weight_rank}")
  1019. return False, None
  1020. return True, axis
  1021. def _get_tensor_quantization_scale(self, tensor_name: str, consumer_node_name: str) -> np.ndarray | None:
  1022. """
  1023. Returns the quantization scale of a tensor that is consumed by the given node.
  1024. :parameter tensor_name: The name of the tensor.
  1025. :parameter consumer_node_name: The name of the node that consumes the tensor as input. Necessary in case
  1026. the quantization type of the tensor was converted.
  1027. Refer: QDQQuantizer::_add_qdq_ops_for_converted_activation.
  1028. :returns: The quantization scale or None.
  1029. """
  1030. initializers = self.model.initializer()
  1031. scale_initializer: onnx.TensorProto | None = None
  1032. if tensor_name in self.quantized_value_map:
  1033. # Tensor was quantized by this tool, so get scale from initializer created by this tool run.
  1034. scale_name = self.quantized_value_map[tensor_name].get_for_consumer(consumer_node_name).scale_name
  1035. scale_initializer = find_by_name(scale_name, initializers)
  1036. else:
  1037. # Tensor was already quantized in original model, so get scale from DQ node that outputs the tensor.
  1038. dq_node = self.tensor_to_producing_dq.get(tensor_name, None)
  1039. if dq_node:
  1040. scale_initializer = find_by_name(dq_node.input[1], initializers)
  1041. return tensor_proto_to_array(scale_initializer) if scale_initializer is not None else None
  1042. def quantize_bias_static(self, bias_name: str, bias_info: QDQBiasQuantInfo) -> str:
  1043. """
  1044. Quantized the bias. Zero Point == 0 and Scale == Input_Scale * Weight_Scale
  1045. """
  1046. # Handle case where bias already in quantization map
  1047. if bias_name in self.quantized_value_map:
  1048. return self.quantized_value_map[bias_name].original.q_name
  1049. # get scale for weight.
  1050. weight_scale = self._get_tensor_quantization_scale(bias_info.weight_name, bias_info.node_name)
  1051. if weight_scale is None:
  1052. raise ValueError(
  1053. f"Unable to get valid quantization scale for weight input '{bias_info.weight_name}' "
  1054. f"when quantizing bias '{bias_name}' to int32."
  1055. )
  1056. # get scale for input.
  1057. input_scale = self._get_tensor_quantization_scale(bias_info.input_name, bias_info.node_name)
  1058. if input_scale is None:
  1059. raise ValueError(
  1060. f"Unable to get valid quantization scale for input '{bias_info.input_name}' "
  1061. f"when quantizing bias '{bias_name}' to int32."
  1062. )
  1063. (
  1064. quantized_bias_name,
  1065. quantized_bias_scale_name,
  1066. quantized_bias_zp_name,
  1067. bias_scale_data,
  1068. node_type,
  1069. node_qtype,
  1070. ) = self.quantize_bias_static_impl(bias_name, input_scale, weight_scale, bias_info.beta)
  1071. quantized_value = QuantizedValue(
  1072. bias_name,
  1073. quantized_bias_name,
  1074. quantized_bias_scale_name,
  1075. quantized_bias_zp_name,
  1076. QuantizedValueType.Initializer,
  1077. 0 if bias_scale_data.size > 1 else None,
  1078. node_type=node_type,
  1079. node_qtype=node_qtype,
  1080. )
  1081. self.quantized_value_map[bias_name] = QDQTensorQuantizedValue(quantized_value, None, None)
  1082. return quantized_bias_name
  1083. def _make_scale_zp_initializers(
  1084. self, param_name: str, quant_params: QuantizationParams, init_name_suffix: str = ""
  1085. ) -> QDQScaleZpInitializers:
  1086. """
  1087. Creates and returns scale and zero-point initializers for the given quantization params. The initializers are
  1088. named:
  1089. - {param_name}_zero_point{init_name_suffix}
  1090. - {param_name}_scale{init_name_suffix}
  1091. """
  1092. zero_point = quant_params["zero_point"]
  1093. scale = quant_params["scale"]
  1094. zero_point_type = quant_params["quant_type"]
  1095. axis: int | None = quant_params.get("axis")
  1096. assert (axis is not None and len(scale.shape) == 1) or (axis is None and len(scale.shape) == 0), (
  1097. "Wrong scale/zp shapes"
  1098. )
  1099. assert len(scale.shape) == len(zero_point.shape), "Scale and zero-point must have the same rank"
  1100. zero_point_name = param_name + "_zero_point" + init_name_suffix
  1101. scale_name = param_name + "_scale" + init_name_suffix
  1102. # Add initializers to model
  1103. init_zp = onnx.helper.make_tensor(
  1104. zero_point_name, zero_point_type, zero_point.shape, zero_point.ravel().tolist()
  1105. )
  1106. self.model.add_initializer(init_zp)
  1107. if scale.dtype == np.float32:
  1108. scale_type = onnx_proto.TensorProto.FLOAT
  1109. elif scale.dtype == np.float16:
  1110. scale_type = onnx_proto.TensorProto.FLOAT16
  1111. else:
  1112. raise ValueError(f"Unexpected dtype={scale.dtype} for param_name={param_name!r}")
  1113. init_scale = onnx.helper.make_tensor(scale_name, scale_type, scale.shape, scale.ravel().tolist())
  1114. self.model.add_initializer(init_scale)
  1115. return QDQScaleZpInitializers(init_scale, init_zp)
  1116. def _make_tensor_scale_zp_initializers(self, tensor_name: str) -> QDQTensorScaleZpInitializers | None:
  1117. """
  1118. Create and returns all scale/zero_point initializers for a given tensor. If the tensor is converted
  1119. to a different quantization type, this function creates two pairs of zp/scale initializers. Otherwise,
  1120. only one pair of zp/scale initializers is created.
  1121. """
  1122. if self.quantization_params is None or tensor_name not in self.quantization_params:
  1123. logging.info(f'Quantization parameters for tensor:"{tensor_name}" not specified')
  1124. return None
  1125. tensor_params = self.quantization_params[tensor_name]
  1126. if not isinstance(tensor_params, QDQTensorQuantParams):
  1127. raise TypeError(f"Unexpected type {type(tensor_params)} for {tensor_name!r}.")
  1128. original_inits = self._make_scale_zp_initializers(tensor_name, tensor_params.original)
  1129. converted_inits = (
  1130. self._make_scale_zp_initializers(tensor_name, tensor_params.converted, "_convert")
  1131. if tensor_params.converted
  1132. else None
  1133. )
  1134. return QDQTensorScaleZpInitializers(original_inits, converted_inits, tensor_params.converted_recv_nodes)
  1135. def calc_quant_params(self, tensor_data: TensorData, quant_overrides: dict[str, Any]) -> QuantizationParams:
  1136. """
  1137. Calculates quantization parameters (scale/zero-point) given a tensor's min/max range and optional
  1138. user-provided overrides.
  1139. """
  1140. quant_type = self.activation_qType
  1141. if "quant_type" in quant_overrides:
  1142. quant_type = quant_overrides["quant_type"].tensor_type
  1143. if "scale" in quant_overrides and "zero_point" in quant_overrides:
  1144. zero, scale = quant_overrides["zero_point"], quant_overrides["scale"]
  1145. elif quant_type == onnx.TensorProto.FLOAT8E4M3FN:
  1146. zero, scale = compute_scale_zp_float8(quant_type, tensor_data.avg_std[1])
  1147. else:
  1148. rmin = quant_overrides.get("rmin", tensor_data.range_value[0])
  1149. rmax = quant_overrides.get("rmax", tensor_data.range_value[1])
  1150. symmetric = quant_overrides.get("symmetric", self.is_activation_symmetric)
  1151. reduce_range = quant_overrides.get("reduce_range", False)
  1152. qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric)
  1153. zero, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, self.min_real_range)
  1154. return QuantizationParams(zero_point=zero.squeeze(), scale=scale.squeeze(), quant_type=quant_type)
  1155. def calc_graph_quant_params(self) -> dict[str, QDQTensorQuantParams]:
  1156. """
  1157. Calculates quantization parameters (scale/zero-point) for all tensors in the graph using each tensor's min/max range
  1158. and optional user-provided overrides.
  1159. """
  1160. if self.tensors_range is None:
  1161. return {}
  1162. self.adjust_tensor_ranges()
  1163. quantization_params = {}
  1164. for tensor_name in self.tensors_range:
  1165. td = self.tensors_range[tensor_name]
  1166. if not isinstance(td, TensorData):
  1167. raise TypeError(f"Unexpected type {type(td)} for {tensor_name!r}.")
  1168. quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(tensor_name, default_val={})
  1169. original = self.calc_quant_params(td, quant_overrides)
  1170. converted = None
  1171. converted_recv_nodes = None
  1172. if "convert" in quant_overrides:
  1173. converted = self.calc_quant_params(td, quant_overrides["convert"])
  1174. converted_recv_nodes = quant_overrides["convert"].get("recv_nodes")
  1175. quantization_params[tensor_name] = QDQTensorQuantParams(original, converted, converted_recv_nodes)
  1176. return quantization_params
  1177. def _calc_initializer_quant_params(self) -> dict[str, QuantizationParams]:
  1178. """
  1179. Returns quantization parameters (scale/zero_point/quant_type) for all initializers.
  1180. """
  1181. quantization_params: dict[str, QuantizationParams] = {}
  1182. for tensor_name, tensor_info in self.tensors_to_quantize.items():
  1183. initializer = find_by_name(tensor_name, self.model.initializer())
  1184. if not initializer:
  1185. continue
  1186. initializer_data = tensor_proto_to_array(initializer)
  1187. initializer_rank = len(initializer_data.shape)
  1188. # initializers for elementwise ops use the quant_type for activations.
  1189. is_weight = tensor_info.tensor_type is QDQQuantTensorType.WEIGHT
  1190. quant_type = self.weight_qType if is_weight else self.activation_qType
  1191. # Try to get scale/zp directly from user's overrides and avoid computation.
  1192. if self.tensor_quant_overrides.overrides_scale_zp(tensor_name):
  1193. overrides = self.tensor_quant_overrides[tensor_name]
  1194. if "quant_type" in overrides[0]:
  1195. quant_type = overrides[0]["quant_type"].tensor_type
  1196. zp_dtype = ONNX_TYPE_TO_NP_TYPE[quant_type]
  1197. is_per_channel = "axis" in overrides[0]
  1198. if not is_per_channel:
  1199. quantization_params[tensor_name] = QuantizationParams(
  1200. zero_point=np.array(overrides[0]["zero_point"], dtype=zp_dtype),
  1201. scale=np.array(overrides[0]["scale"], initializer_data.dtype),
  1202. quant_type=quant_type,
  1203. )
  1204. else:
  1205. zero_points_list = []
  1206. scales_list = []
  1207. for chan_overrides in overrides:
  1208. zero_points_list.append(np.array(chan_overrides["zero_point"], zp_dtype))
  1209. scales_list.append(np.array(chan_overrides["scale"], dtype=initializer_data.dtype))
  1210. channel_axis = overrides[0]["axis"]
  1211. is_axis_valid, norm_channel_axis = normalize_axis(channel_axis, initializer_rank)
  1212. if not is_axis_valid:
  1213. raise ValueError(
  1214. f"Weight {initializer.name} has a per-channel axis with value {channel_axis} that is "
  1215. f"out-of-bounds for rank {initializer_rank}"
  1216. )
  1217. quantization_params[tensor_name] = QuantizationParams(
  1218. zero_point=np.array(zero_points_list),
  1219. scale=np.array(scales_list),
  1220. quant_type=quant_type,
  1221. axis=norm_channel_axis,
  1222. )
  1223. continue
  1224. # Compute scale/zp normally. User's overrides may still override parameters
  1225. # used to compute the scale/zp (e.g., rmin, rmax, symmetric, etc.)
  1226. overrides = self.tensor_quant_overrides.get(tensor_name, [{}])
  1227. if "quant_type" in overrides[0]:
  1228. quant_type = overrides[0]["quant_type"].tensor_type
  1229. channel_axis = overrides[0].get("axis", tensor_info.axis)
  1230. is_per_channel = channel_axis is not None
  1231. # Note: always quantize per-channel initializers as symmetric because QLinear* ops require the
  1232. # same zero-point in every channel, which is necessarily the case for symmetric quantization.
  1233. is_symmetric_default = is_per_channel or (
  1234. self.is_weight_symmetric(quant_type) if is_weight else self.is_activation_symmetric
  1235. )
  1236. is_symmetric = overrides[0].get("symmetric", is_symmetric_default)
  1237. reduce_range = overrides[0].get("reduce_range", self.reduce_range)
  1238. zero_point: np.ndarray | None = None
  1239. scale: np.ndarray | None = None
  1240. if not is_per_channel:
  1241. zero_point, scale = compute_data_quant_params(
  1242. initializer_data.flatten(),
  1243. quant_type,
  1244. is_symmetric,
  1245. reduce_range=reduce_range,
  1246. min_real_range=self.min_real_range,
  1247. rmin_override=overrides[0].get("rmin"),
  1248. rmax_override=overrides[0].get("rmax"),
  1249. )
  1250. else:
  1251. is_axis_valid, norm_channel_axis = normalize_axis(channel_axis, initializer_rank)
  1252. if not is_axis_valid:
  1253. raise ValueError(
  1254. f"Weight {initializer.name} has a per-channel axis with value {channel_axis} that is "
  1255. f"out-of-bounds for rank {initializer_rank}"
  1256. )
  1257. channel_axis = norm_channel_axis
  1258. channel_count = initializer_data.shape[channel_axis]
  1259. zero_points_list = []
  1260. scales_list = []
  1261. for i in range(channel_count):
  1262. per_channel_data = initializer_data.take(i, channel_axis)
  1263. channel_overrides = overrides[i] if overrides and i < len(overrides) else {}
  1264. channel_zero_point, channel_scale = compute_data_quant_params(
  1265. per_channel_data.ravel(),
  1266. quant_type,
  1267. is_symmetric,
  1268. reduce_range=reduce_range,
  1269. min_real_range=self.min_real_range,
  1270. rmin_override=channel_overrides.get("rmin"),
  1271. rmax_override=channel_overrides.get("rmax"),
  1272. )
  1273. zero_points_list.append(channel_zero_point)
  1274. scales_list.append(channel_scale)
  1275. zero_point = np.asarray(zero_points_list)
  1276. scale = np.asarray(scales_list)
  1277. quantization_params[tensor_name] = QuantizationParams(
  1278. zero_point=zero_point,
  1279. scale=scale,
  1280. quant_type=quant_type,
  1281. axis=channel_axis,
  1282. )
  1283. return quantization_params