| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License. See License.txt in the project root for
- # license information.
- # --------------------------------------------------------------------------
- from __future__ import annotations
- import logging
- from dataclasses import dataclass
- from enum import Enum
- from typing import Any
- import numpy as np
- import onnx
- from onnx import TensorProto
- from onnx import onnx_pb as onnx_proto
- from .base_quantizer import BaseQuantizer, QuantizationParams
- from .calibrate import TensorData
- from .quant_utils import (
- DEQUANT_OP_NAME,
- ONNX_TYPE_TO_NP_TYPE,
- QUANT_OP_NAME,
- QuantizedValue,
- QuantizedValueType,
- __producer__,
- __version__,
- add_dequant_output_suffix,
- add_dequant_suffix,
- add_quant_input_suffix,
- add_quant_output_suffix,
- add_quant_suffix,
- compute_data_quant_params,
- compute_scale_zp,
- compute_scale_zp_float8,
- find_by_name,
- get_qmin_qmax_for_qType,
- ms_domain,
- normalize_axis,
- quantize_onnx_initializer,
- tensor_proto_to_array,
- )
- from .registry import CreateQDQQuantizer
class QDQQuantTensorType(Enum):
    """Role of a tensor that has been marked for QDQ quantization."""

    ACTIVATION = 0  # Runtime activation (non-initializer) tensor.
    WEIGHT = 1  # Weight initializer.
    BIAS = 2  # Bias initializer.
@dataclass
class QDQQuantParamProvider:
    """
    Names the node input whose quantization param initializers (scale and zero point) a node output reuses.

    Example: a Transpose node's output is quantized with the same scale/zero-point initializers that are
    used for the Transpose node's input.
    """

    input_name: str  # Node input that provides the quantization params.
    node_name: str  # Node that consumes `input_name`.
class QDQTensorQuantInfo:
    """
    Quantization info for a tensor that an operator quantizer marked for quantization.

    Bias tensors are not tracked by this class; they are described by QDQBiasQuantInfo.

    Attributes:
        tensor_type: QDQQuantTensorType of the tensor.
        quant_para_provider: QDQQuantParamProvider whose scale/zero-point initializers this tensor reuses,
            or None when the tensor gets its own quantization params.
        axis: channel axis for per-channel quantization, or None.
        is_shared: True when ``quant_para_provider`` is set.
        data_type: ONNX element type (``TensorProto.DataType`` value) of the original tensor.

    Raises:
        ValueError: if ``data_type`` is None.
    """

    def __init__(self, tensor_type=QDQQuantTensorType.ACTIVATION, quant_para_provider=None, axis=None, data_type=None):
        # `data_type` keeps its None default for call compatibility, but it is required. Validate it with an
        # explicit raise (an `assert` is stripped under `python -O`) and before any attribute is assigned, so
        # a bad call cannot leave a half-initialized object behind.
        if data_type is None:
            raise ValueError("QDQTensorQuantInfo requires a non-None data_type.")
        self.tensor_type = tensor_type
        self.quant_para_provider = quant_para_provider
        self.axis = axis
        self.is_shared = quant_para_provider is not None
        self.data_type = data_type
@dataclass
class QDQBiasQuantInfo:
    """
    A bias tensor that an operator quantizer marked for quantization.

    The bias is meant to be quantized with zero_point = 0 and scale = input_scale * weight_scale * beta, where
    the input and weight scales come from the tensors named below.
    """

    node_name: str  # Node that consumes the bias, input, and weight tensors.
    input_name: str  # Input whose scale contributes to the bias scale.
    weight_name: str  # Weight whose scale contributes to the bias scale.
    beta: float  # Extra multiplier applied to the bias scale.
@dataclass
class QDQTensorQuantParams:
    """
    Quantization parameter values (scale, zero point, ...) for a tensor.

    A tensor usually has a single set of quantization parameters. The exception is a tensor at a
    "mixed-precision" boundary, where the activation quantization type changes (e.g., from uint8 to uint16):
    the producer emits `original`, and some (or all) consumers receive `converted` instead.
    """

    original: QuantizationParams  # Generated by the producer node.
    converted: QuantizationParams | None  # Converted params, or None when the tensor is not converted.
    converted_recv_nodes: set[str] | None  # Consumers that receive `converted`; None means every consumer.

    def get_for_consumer(self, consumer_node_name) -> QuantizationParams:
        """Return the params that the consumer node named `consumer_node_name` should use."""
        receives_converted = self.converted is not None and (
            self.converted_recv_nodes is None or consumer_node_name in self.converted_recv_nodes
        )
        return self.converted if receives_converted else self.original
@dataclass
class QDQScaleZpInitializers:
    """The scale and zero-point initializer TensorProtos for one set of quantization params."""

    scale: TensorProto  # Scale initializer.
    zero_point: TensorProto  # Zero-point initializer.
@dataclass
class QDQTensorScaleZpInitializers:
    """
    All scale/zero-point initializers for a tensor.

    A tensor usually has a single set of quantization parameters. The exception is a tensor at a
    "mixed-precision" boundary, where the activation quantization type changes (e.g., from uint8 to uint16).
    """

    original: QDQScaleZpInitializers  # Initializers for the producer's quantization params.
    converted: QDQScaleZpInitializers | None  # Initializers for the converted params, if any.
    converted_recv_nodes: set[str] | None  # Consumers of the converted params; None means every consumer.
@dataclass
class QDQTensorQuantizedValue:
    """
    Cached information about a tensor's quantized values (types, scale/zero-point initializer names, etc.).

    A tensor usually has a single set of quantization parameters. The exception is a tensor at a
    "mixed-precision" boundary, where the activation quantization type changes (e.g., from uint8 to uint16).
    """

    original: QuantizedValue  # Quantized value produced by the tensor's producer.
    converted: QuantizedValue | None  # Converted quantized value, or None when not converted.
    converted_recv_nodes: set[str] | None  # Consumers that receive `converted`; None means every consumer.

    def get_for_consumer(self, consumer_node_name) -> QuantizedValue:
        """Return the quantized value that the consumer node named `consumer_node_name` should read."""
        converted = self.converted
        if converted is None:
            return self.original
        recv_nodes = self.converted_recv_nodes
        if recv_nodes is not None and consumer_node_name not in recv_nodes:
            return self.original
        return converted
- class QDQQuantizer(BaseQuantizer):
def __init__(
    self,
    model,
    per_channel,
    reduce_range,
    weight_qType,
    activation_qType,
    tensors_range,
    nodes_to_quantize,
    nodes_to_exclude,
    op_types_to_quantize,
    extra_options=None,
):
    """
    Create a QDQ (QuantizeLinear/DequantizeLinear) quantizer.

    Every argument is forwarded unchanged to ``BaseQuantizer.__init__``. The QDQ-specific settings below are
    also read from ``extra_options``. Passing ``None`` (the default) is handled like an empty dict, so each
    setting then takes its default:

        OpTypesToExcludeOutputQuantization ([]), AddQDQPairToWeight (False), QuantizeBias (True),
        DedicatedQDQPair (False), QDQOpTypePerChannelSupportToAxis ({}), UseQDQContribOps (False),
        QDQKeepRemovableActivations (False), QDQDisableWeightAdjustForInt32Bias (False).
    """
    BaseQuantizer.__init__(
        self,
        model,
        per_channel,
        reduce_range,
        weight_qType,
        activation_qType,
        tensors_range,
        nodes_to_quantize,
        nodes_to_exclude,
        op_types_to_quantize,
        extra_options,
    )
    # `extra_options` defaults to None, but the option lookups below need a mapping. Normalize locally;
    # BaseQuantizer above still receives exactly what the caller passed.
    options = extra_options if extra_options is not None else {}

    self.tensors_to_quantize: dict[str, QDQTensorQuantInfo] = {}
    self.bias_to_quantize: dict[str, QDQBiasQuantInfo] = {}
    self.nodes_to_remove = []

    # Specific op types to exclude qdq quantization for their outputs.
    # In TRT, it's not recommended to quantize outputs for weighted ops such as Conv, Matmul, Gemm
    # because those ops may be followed by nodes that require high resolution inputs.
    # Adding QDQ for those ops' output may end up with worse accuracy.
    # So, we don't recommend to add QDQ to node's output under such condition.
    self.op_types_to_exclude_output_quantization = options.get("OpTypesToExcludeOutputQuantization", [])

    # We do quantization on Dequantizelinear's input to remove Quantizelinear for weight as an optimization.
    # In some cases, for example QDQ BERT model for TensorRT, QDQ should always appear as a pair.
    # Therefore, we need to disable this optimization and add qdq pair to weight.
    self.add_qdq_pair_to_weight = options.get("AddQDQPairToWeight", False)

    # Some scenarios do not need the bias quantized. For example, in the case of Quantization Aware Training,
    # quantizing the bias is not needed. This is because in QAT, all model parameters are expected to be in
    # floating point format. To that end, we can use the FakeQuant operator for weights and activations that
    # can always have QDQ pairs (by using AddQDQPairToWeight). But for biases in a quantized model, we can't use
    # FakeQuant because it only ever appears before a DQ (since it is quantized as int32).
    self.quantize_bias = options.get("QuantizeBias", True)

    # The default behavior is that multiple nodes can share a QDQ pair as their inputs.
    # In TRT, a QDQ pair can't be shared between nodes, so dedicated QDQ pairs are created for each node.
    self.dedicated_qdq_pair = options.get("DedicatedQDQPair", False)
    self.tensor_to_its_receiving_nodes: dict[str, list[onnx.NodeProto]] = {}

    # Maps a tensor to the DequantizeLinear node (in the original input model) that outputs the tensor.
    # Populated for input models with some pre-quantized weights (typically via a different tool).
    self.tensor_to_producing_dq: dict[str, onnx.NodeProto] = {}

    # Let user set channel axis for specific op type. It's effective only when per-channel quantization is
    # supported and per_channel is True.
    self.qdq_op_type_per_channel_support_to_axis = options.get("QDQOpTypePerChannelSupportToAxis", {})

    self.qdq_op_domain = ms_domain if options.get("UseQDQContribOps", False) else None

    # User can specify if removable activations, like Clip/Relu, should be kept in the graph.
    # Used in the QDQRemovableActivation class.
    self.qdq_keep_removable_activations = options.get("QDQKeepRemovableActivations", False)

    # Let user disable adjustment of weight scales for bias inputs that are quantized to int32.
    self.qdq_disable_weight_adjust_for_int32_bias = options.get("QDQDisableWeightAdjustForInt32Bias", False)

    # The ONNX spec did not support 16-bit Q/DQ ops before opset 21.
    # So, may have to override the Q/DQ op domain to 'com.microsoft' if the activation or weight types
    # (or any per-tensor override type) are 16-bit or 4-bit integers.
    if self.opset_version < 21:
        opset21_types = (TensorProto.UINT16, TensorProto.INT16, TensorProto.UINT4, TensorProto.INT4)
        overrides_have_opset21_types = any(
            t.tensor_type in opset21_types for t in self.tensor_quant_override_qtypes
        )
        if not self.qdq_op_domain and (
            self.activation_qType in opset21_types
            or self.weight_qType in opset21_types
            or overrides_have_opset21_types
        ):
            logging.warning(
                "ONNX QuantizeLinear and DequantizeLinear operators do not support "
                "16-bit/4-bit integer quantization types prior to opset 21. "
                f"The domain of QuantizeLinear and DequantizeLinear operators will be set to '{ms_domain}' to "
                "enable support."
            )
            self.qdq_op_domain = ms_domain

    self.quantization_params = self.calc_graph_quant_params()
    self.initializer_quant_params: dict[str, QuantizationParams] = {}

    # Map of all original value names to quantized value names
    self.quantized_value_map = {}
def _get_tensor_type(self, tensor_name):
    """
    Return the ONNX element type of a tensor, or None if it cannot be determined.

    An initializer with this name is checked first (its ``data_type``). Otherwise the tensor's entry in
    ``self.value_infos`` is used (its ``tensor_type.elem_type``, when that value info describes a tensor).

    Args:
        tensor_name: name of the tensor to look up.
    Returns:
        The element type as a ``TensorProto.DataType`` value, or None if the tensor is neither an
        initializer nor described by a tensor-typed value info.
    """
    initializer = find_by_name(tensor_name, self.model.initializer())
    if initializer is not None:
        return initializer.data_type
    if tensor_name not in self.value_infos:
        return None
    value_info = self.value_infos[tensor_name]
    if value_info.type.HasField("tensor_type"):
        return value_info.type.tensor_type.elem_type
    return None
def _is_tensor_quantizable(self, tensor_name):
    """
    Report whether a tensor is a float32/float16 tensor and can therefore be quantized.

    An initializer with this name is checked first. Otherwise the tensor's value info is consulted. A tensor
    found in neither place is reported as not quantizable, and a warning is logged.
    """
    float_types = (TensorProto.FLOAT, TensorProto.FLOAT16)

    initializer = find_by_name(tensor_name, self.model.initializer())
    if initializer is not None:
        return initializer.data_type in float_types

    if tensor_name not in self.value_infos:
        logging.warning(
            f"failed to infer the type of tensor: {tensor_name}. Skip to quantize it. Please check if it is expected."
        )
        return False

    value_type = self.value_infos[tensor_name].type
    return value_type.HasField("tensor_type") and value_type.tensor_type.elem_type in float_types
def __quantize_tensor(self, tensor_name, quant_sharing_provider=None, tensor_type=QDQQuantTensorType.ACTIVATION):
    """
    Mark a tensor for quantization by recording it in ``self.tensors_to_quantize``.

    Op quantizers reach this indirectly through the public ``quantize_*`` helpers. Tensors that are not
    quantizable (see ``_is_tensor_quantizable``) are ignored.

    When ``quant_sharing_provider`` is given, the tensor reuses the quantization params of the node input it
    names (e.g., a Transpose node's output typically reuses the initializers of the Transpose node's input).
    Such an entry always replaces any existing one. Without a provider, a tensor that is already marked
    keeps its existing entry.

    Args:
        tensor_name: name of the tensor to quantize
        quant_sharing_provider: QDQQuantParamProvider naming the tensor and node that provide the quantization params
        tensor_type: QDQQuantTensorType, ACTIVATION by default
    Raises:
        TypeError: if ``quant_sharing_provider`` is truthy but not a QDQQuantParamProvider.
    """
    if not self._is_tensor_quantizable(tensor_name):
        return

    if quant_sharing_provider:
        if not isinstance(quant_sharing_provider, QDQQuantParamProvider):
            raise TypeError(
                f"quant_sharing_provider must be of type QDQQuantParamProvider, not {type(quant_sharing_provider)}."
            )
        provider = quant_sharing_provider
    elif tensor_name in self.tensors_to_quantize:
        return  # Already marked without sharing; keep the existing entry.
    else:
        provider = None

    self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(
        tensor_type=tensor_type,
        quant_para_provider=provider,
        data_type=self._get_tensor_type(tensor_name),
    )
def quantize_activation_tensor(self, tensor_name: str):
    """
    Mark an activation tensor for quantization with its own quantization params.

    Called by op quantizers that want to quantize a tensor (i.e., "mark" it for quantization).

    Args:
        tensor_name: name of the tensor to quantize
    """
    return self.__quantize_tensor(
        tensor_name,
        quant_sharing_provider=None,
        tensor_type=QDQQuantTensorType.ACTIVATION,
    )
def quantize_output_same_as_input(self, output_name: str, input_name: str, node_name: str):
    """
    Mark a node output for quantization so it reuses the quantization params of one of the node's inputs.

    Example: a Transpose node's output typically uses the same quantization parameter initializers as the
    Transpose node's input.

    Args:
        output_name: name of the node output to quantize so that it uses the same quantization params as an input.
        input_name: name of the node input from which the output tensor will get its quantization params.
        node_name: name of the node that consumes `input_name`.
    """
    provider = QDQQuantParamProvider(input_name=input_name, node_name=node_name)
    return self.__quantize_tensor(output_name, provider, QDQQuantTensorType.ACTIVATION)
def quantize_weight_tensor(self, tensor_name: str):
    """
    Mark a weight tensor for (per-tensor) quantization.

    Called by op quantizers that want to quantize a weight (i.e., "mark" it for quantization).

    Args:
        tensor_name: name of the weight to quantize
    """
    return self.__quantize_tensor(
        tensor_name,
        quant_sharing_provider=None,
        tensor_type=QDQQuantTensorType.WEIGHT,
    )
def quantize_weight_tensor_per_channel(self, tensor_name, axis):
    """
    Mark a weight initializer for per-channel quantization along ``axis``.

    Per-channel quantization is only supported for initializers. A non-initializer tensor is skipped with a
    warning, and an initializer that is not float32/float16 is skipped with an info log.

    Args:
        tensor_name: name of the weight initializer to quantize.
        axis: channel axis along which separate quantization params are computed.
    """
    weight = find_by_name(tensor_name, self.model.initializer())
    # Compare against None explicitly (as the rest of this class does) instead of relying on the
    # truthiness of a protobuf message.
    if weight is None:
        logging.warning(f"only support per-channel quantization on weight. Tensor: {tensor_name} is not quantized.")
        return
    if weight.data_type not in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16):
        # Previously skipped silently; log it so an ignored per-channel request is visible.
        logging.info(f"Expected weight '{tensor_name}' to be a floating-point initializer; it is not quantized.")
        return
    self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(
        tensor_type=QDQQuantTensorType.WEIGHT, axis=axis, data_type=weight.data_type
    )
def _dup_initializer(self, initializer: onnx.TensorProto) -> onnx.TensorProto:
    """
    Add a renamed copy of ``initializer`` to the model and return the copy.

    The copy's name is the original name followed by ``get_largest_initializer_name_suffix(name) + 1``.
    Presumably this yields a name not already used by another initializer (confirm against ONNXModel).
    """
    suffix = self.model.get_largest_initializer_name_suffix(initializer.name) + 1
    duplicate = onnx.TensorProto()
    duplicate.CopyFrom(initializer)
    duplicate.name = f"{initializer.name}{suffix}"
    self.model.add_initializer(duplicate)
    return duplicate
def quantize_bias_tensor(self, node_name, bias_name, input_name, weight_name, beta=1.0):
    """
    Mark a bias tensor for quantization. Called by op quantizers.

    Normally the bias is recorded in ``self.bias_to_quantize`` to be quantized with bias_zero_point = 0 and
    bias_scale = input_scale * weight_scale * beta.
    TODO: Explain the reasoning for using this formula.

    Special cases:
      * If the user supplied quantization overrides for the bias, it is quantized like a regular weight
        instead (per-channel when ``is_tensor_per_channel`` says so, with default axis 0).
      * Biases that are not initializers, or not float32/float16, are skipped with a log message.
      * If the bias is already recorded for another node, the initializer is duplicated and this node is
        rewired to the copy, because the bias scale depends on each node's own input/weight scales.

    Args:
        node_name: name of the node that consumes the bias, input, and weight tensors.
        bias_name: name of the bias tensor to quantize.
        input_name: name of the input tensor whose scale is used to compute the bias's scale.
        weight_name: name of the weight tensor whose scale is used to compute the bias's scale.
        beta: Multiplier used to compute the bias's scale.
    """
    if self.tensor_quant_overrides.get(bias_name):
        logging.info(
            f"Quantizing bias tensor '{bias_name}' as a weight due to the presence of user-specified overrides"
        )
        per_channel, channel_axis = self.is_tensor_per_channel(bias_name, default_axis=0)
        if not per_channel:
            self.quantize_weight_tensor(bias_name)
        else:
            self.quantize_weight_tensor_per_channel(bias_name, channel_axis)
        return

    bias_tp = find_by_name(bias_name, self.model.initializer())
    if bias_tp is None:
        logging.warning(f"Expected bias '{bias_name}' to be an initializer")
        return

    float_types = (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16)
    if bias_tp.data_type not in float_types:
        logging.info(f"Expected bias '{bias_name}' to be an floating-point initializer")
        return

    target_name = bias_name
    if bias_name in self.bias_to_quantize:
        # Shared bias: give this node its own copy so its scale can differ from the other consumer's.
        target_name = self._dup_initializer(bias_tp).name
        self.model.replace_input_of_nodes(bias_name, target_name, {node_name})
        logging.info(f"Created a copy of bias input '{bias_name}' called '{target_name}'")

    self.bias_to_quantize[target_name] = QDQBiasQuantInfo(
        node_name=node_name, input_name=input_name, weight_name=weight_name, beta=beta
    )
def _adjust_weight_scale_for_int32_bias(
    self,
    input_scale: np.ndarray,
    weight_scale: np.ndarray,
    weight_name: str,
    bias_tp: onnx.TensorProto,
    is_per_channel: bool,
) -> tuple[bool, np.ndarray | None]:
    """
    Checks if the bias scale (input_scale * weight_scale) that we intend to use is too small.

    A bias scale that is too small leads to quantized bias values that fall outside the range of an int32 and
    have to be clipped, which decreases accuracy. If this function detects such a scenario, the weight_scale
    value is increased to prevent this from happening.

    Although the adjustment method and amount differ, the idea to adjust the weight's scale came from the
    following reference:
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/optimize/quantization_utils.cc#L252

    The caller's ``weight_scale`` array is never modified in place; adjustments are made on a new array.

    :param input_scale: The input's scale (must hold exactly one value).
    :param weight_scale: The weight scale to potentially adjust: a single value when per-tensor, or a 1-D
        array with one value per channel when per-channel. Other per-channel shapes are returned unchanged.
    :param weight_name: The weight initializer's name. Used for logging.
    :param bias_tp: The bias ONNX initializer.
    :param is_per_channel: True if the bias and weight are quantized per-channel.
    :return: A tuple with a bool indicating if the weight's scale was adjusted and the (possibly new) weight
        scale. The scale is None when ``weight_scale`` or the bias is empty.
    """
    if not weight_scale.size:
        return False, None

    # Convert the bias to float64 up front so both branches compute the bound at the same precision. Left in its
    # own dtype, the per-element arithmetic below would stay float16 for a float16 bias under NumPy 2 (NEP 50)
    # promotion, where `2.0 * |bias|` can overflow to inf and inflate the weight scale to inf.
    bias_float_data = tensor_proto_to_array(bias_tp).astype(np.float64)
    if not bias_float_data.size:
        # An empty bias cannot overflow int32 (and .min()/.max() would raise on it).
        return False, None

    int32_info = np.iinfo(np.int32)
    # Small headroom so the resulting bias scale is strictly larger than the bare minimum.
    multiplicative_epsilon = 1.0001
    # Width of the symmetric int32 range [-(2**31 - 1), 2**31 - 1], in float64.
    qrange = np.array(int32_info.max, dtype=np.float64) - np.array(int32_info.min + 1, dtype=np.float64)
    weight_scale_dtype = weight_scale.dtype
    updated_an_elem = False

    if not is_per_channel:
        # Per-tensor: one scale must cover the largest bias magnitude (the range always includes 0).
        rmin = np.minimum(bias_float_data.min(), np.array(0, dtype=np.float64))
        rmax = np.maximum(bias_float_data.max(), np.array(0, dtype=np.float64))
        absmax = np.maximum(np.abs(rmin), np.abs(rmax))
        bias_smallest_valid_scale = multiplicative_epsilon * (2.0 * absmax) / qrange

        input_scale_fp64 = np.array(input_scale.item(), dtype=np.float64)
        weight_scale_fp64 = np.array(weight_scale.item(), dtype=np.float64)
        bias_candidate_scale = input_scale_fp64 * weight_scale_fp64
        if (bias_candidate_scale < bias_smallest_valid_scale) and (bias_candidate_scale > 0.0):
            # The candidate bias scale would be too small, so increase the weight_scale by the necessary ratio.
            ratio = bias_smallest_valid_scale / bias_candidate_scale
            logging.info(
                f"Increasing scale for weight `{weight_name}` by the ratio {ratio} to "
                f"ensure bias input `{bias_tp.name}` has a valid scale."
            )
            new_scale = weight_scale_fp64 * ratio
            weight_scale = new_scale.astype(weight_scale_dtype)
            updated_an_elem = True
    elif weight_scale.ndim == 1:
        # Per-channel: check each channel's bias value against that channel's weight scale.
        weight_scale = weight_scale.copy()  # Adjust a private copy, never the caller's array.
        input_scale_fp64 = np.array(input_scale.item(), dtype=np.float64)  # Loop-invariant.
        for i in range(weight_scale.shape[0]):
            bias_rmax = np.abs(bias_float_data[i])
            bias_smallest_valid_scale = multiplicative_epsilon * (2.0 * bias_rmax) / qrange
            weight_scale_fp64 = np.array(weight_scale[i].item(), dtype=np.float64)
            bias_candidate_scale = input_scale_fp64 * weight_scale_fp64
            if (bias_candidate_scale < bias_smallest_valid_scale) and (bias_candidate_scale > 0.0):
                # The candidate bias scale would be too small, so increase the weight_scale by the necessary ratio.
                ratio = bias_smallest_valid_scale / bias_candidate_scale
                logging.info(
                    f"Increased scale[{i}] for weight `{weight_name}` by ratio {ratio} "
                    f"to ensure bias input `{bias_tp.name}` has a valid scale."
                )
                new_scale = weight_scale_fp64 * ratio
                weight_scale[i] = new_scale.astype(weight_scale_dtype)
                updated_an_elem = True

    return updated_an_elem, weight_scale
def _adjust_weight_quant_params_for_bias_tensors(self):
    """
    Enlarge weight scales where needed so that int32-quantized biases do not overflow.

    For each bias in ``self.bias_to_quantize``, the intended bias scale is input_scale * weight_scale. If that
    scale is too small, the weight's scale in ``self.initializer_quant_params`` is replaced by the enlarged
    scale from ``_adjust_weight_scale_for_int32_bias``.

    A bias is skipped when any of these holds:
      * its input has no quantization params, or is not marked for quantization;
      * its weight has no initializer quantization params;
      * its weight is not quantized to INT8/INT16;
      * its weight's zero point is not all zeros (i.e., not symmetric).

    The whole pass is disabled by the ``QDQDisableWeightAdjustForInt32Bias`` extra option.
    """
    if self.qdq_disable_weight_adjust_for_int32_bias:
        return  # Disabled by the user through extra_options.

    supported_weight_types = (onnx.TensorProto.INT8, onnx.TensorProto.INT16)
    for bias_name, bias_info in self.bias_to_quantize.items():
        input_name = bias_info.input_name
        weight_name = bias_info.weight_name
        params_available = (
            input_name in self.quantization_params
            and input_name in self.tensors_to_quantize
            and weight_name in self.initializer_quant_params
        )
        if not params_available:
            continue

        # Input scale as seen by the node that consumes this bias, in the input's float dtype.
        consumer_qparams = self.quantization_params[input_name].get_for_consumer(bias_info.node_name)
        input_data_type = self.tensors_to_quantize[input_name].data_type
        input_scale = np.asarray(
            consumer_qparams["scale"], dtype=onnx.helper.tensor_dtype_to_np_dtype(input_data_type)
        )

        weight_qparams = self.initializer_quant_params[weight_name]
        if weight_qparams["quant_type"] not in supported_weight_types:
            continue
        weight_zero_point: np.ndarray = weight_qparams["zero_point"]
        if weight_zero_point.any():
            continue  # Only symmetric weights (all-zero zero point) are adjusted.

        weight_scale: np.ndarray = weight_qparams["scale"]
        per_channel = weight_qparams.get("axis") is not None
        bias_tp = find_by_name(bias_name, self.model.initializer())
        scale_changed, adjusted_scale = self._adjust_weight_scale_for_int32_bias(
            input_scale, weight_scale, weight_name, bias_tp, per_channel
        )
        if scale_changed:
            weight_qparams["scale"] = adjusted_scale
def remove_node(self, node):
    """Queue ``node`` for removal; queued nodes are removed from the model by ``remove_nodes``."""
    self.nodes_to_remove.extend([node])
def remove_nodes(self):
    """Remove every node queued through ``remove_node`` from the model."""
    queued_nodes = self.nodes_to_remove
    self.model.remove_nodes(queued_nodes)
def quantize_model(self):
    """
    Insert QDQ quantization into the model and return the resulting ``self.model.model``.

    1. Walk the graph. For each node selected by ``should_quantize_node``, run its op quantizer (which marks
       tensors for quantization) and record the node as a receiver of each of its inputs. Also record the
       outputs of DequantizeLinear nodes already present in the input model.
    2. Compute the initializer quantization params and adjust weight scales for int32 biases. Then run the
       tensor-quantization passes for regular tensors and param-sharing tensors, plus the bias pass when
       ``self.quantize_bias`` is set.
    3. Remove queued nodes, clean initializers (unless Q/DQ pairs were added to weights), stamp the producer
       name/version, and import the contrib domain when contrib Q/DQ ops are used.
    """
    for node in self.model.nodes():
        if self.should_quantize_node(node):
            CreateQDQQuantizer(self, node).quantize()
            for input_name in node.input:
                self.tensor_to_its_receiving_nodes.setdefault(input_name, []).append(node)

        if node.op_type == DEQUANT_OP_NAME:
            self.tensor_to_producing_dq.update(dict.fromkeys(node.output, node))

    self.initializer_quant_params = self._calc_initializer_quant_params()
    # Must run before weights are quantized so the (possibly enlarged) weight scales are the ones used.
    self._adjust_weight_quant_params_for_bias_tensors()
    self._quantize_normal_tensors()
    self._quantize_sharing_param_tensors()
    if self.quantize_bias:
        self._quantize_bias_tensors()
    self.remove_nodes()
    if not self.add_qdq_pair_to_weight:
        self.model.clean_initializers()

    onnx_model = self.model.model
    onnx_model.producer_name = __producer__
    onnx_model.producer_version = __version__
    if self.qdq_op_domain == ms_domain:
        self.model.set_opset_import(ms_domain, 1)

    return self.model.model
def try_replacing_upstream_output(self, upstream_output_name, output_name):
    """
    Try to make the producer of ``upstream_output_name`` write ``output_name`` directly.

    The rewiring only happens when all of these hold:
      * both tensors have quantization params, and neither set is converted (mixed-precision);
      * ``upstream_output_name`` has exactly one consumer;
      * ``upstream_output_name`` is neither a graph input nor a graph output.

    On success, ``upstream_output_name`` is also dropped from ``self.tensors_to_quantize``.

    Args:
        upstream_output_name: tensor currently produced upstream.
        output_name: tensor name the upstream producer should output instead.
    Returns:
        True if the producer's output was replaced, False otherwise.
    """
    qparams = self.quantization_params
    # Without params for either tensor there is nothing to share, so don't rewire (and don't raise KeyError).
    if output_name not in qparams or upstream_output_name not in qparams:
        return False
    if qparams[output_name].converted is not None or qparams[upstream_output_name].converted is not None:
        return False
    if (
        len(self.model.input_name_to_nodes()[upstream_output_name]) != 1
        or self.model.is_graph_output(upstream_output_name)
        or self.model.is_graph_input(upstream_output_name)
    ):
        return False

    self.model.replace_output_of_all_nodes(upstream_output_name, output_name)
    self.tensors_to_quantize.pop(upstream_output_name, None)
    return True
def _create_q_node(
    self,
    q_input: str,
    q_output: str,
    quant_node_name: str,
    scale_name: str,
    zp_name: str,
    axis: int | None = None,
):
    """
    Build a QuantizeLinear node (q_input, scale, zero_point) -> q_output and add it to the model.

    The node is named ``quant_node_name`` and uses ``self.qdq_op_domain`` as its domain. ``axis`` is passed
    through to ``onnx.helper.make_node``, which omits attributes whose value is None.
    """
    node_inputs = [q_input, scale_name, zp_name]
    q_node = onnx.helper.make_node(
        QUANT_OP_NAME,
        inputs=node_inputs,
        outputs=[q_output],
        name=quant_node_name,
        axis=axis,
        domain=self.qdq_op_domain,
    )
    self.model.add_nodes([q_node])
def _create_dq_node(
    self,
    dq_input: str,
    dq_output: str,
    dequant_node_name: str,
    scale_name: str,
    zp_name: str,
    axis: int | None = None,
):
    """
    Build a DequantizeLinear node (dq_input, scale, zero_point) -> dq_output and add it to the model.

    The node is named ``dequant_node_name`` and uses ``self.qdq_op_domain`` as its domain. ``axis`` is passed
    through to ``onnx.helper.make_node``, which omits attributes whose value is None.
    """
    node_inputs = [dq_input, scale_name, zp_name]
    dq_node = onnx.helper.make_node(
        DEQUANT_OP_NAME,
        inputs=node_inputs,
        outputs=[dq_output],
        name=dequant_node_name,
        axis=axis,
        domain=self.qdq_op_domain,
    )
    self.model.add_nodes([dq_node])
def _create_qdq_nodes(
    self, q_input, q_output, quant_node_name, dq_input, dq_output, dequant_node_name, scale_name, zp_name, axis=None
):
    """
    Build a QuantizeLinear/DequantizeLinear pair and add both nodes to the model in a single call.

    Both nodes read the same scale/zero-point initializers and share ``axis`` and ``self.qdq_op_domain``.
    The Q node maps q_input -> q_output and the DQ node maps dq_input -> dq_output. Callers typically pass
    dq_input == q_output to chain them.
    """
    common_kwargs = {"axis": axis, "domain": self.qdq_op_domain}
    q_node = onnx.helper.make_node(
        QUANT_OP_NAME, [q_input, scale_name, zp_name], [q_output], quant_node_name, **common_kwargs
    )
    dq_node = onnx.helper.make_node(
        DEQUANT_OP_NAME, [dq_input, scale_name, zp_name], [dq_output], dequant_node_name, **common_kwargs
    )
    self.model.add_nodes([q_node, dq_node])
def _add_qdq_nodes_for_initializer(self, weight_proto: onnx.TensorProto):
    """
    Insert the dequantization path for a weight initializer and record it in ``self.quantized_value_map``.

    * ``self.add_qdq_pair_to_weight`` is False: the initializer is quantized offline and the graph gets
      (weight_quantized -> DQ -> weight_dequant).
    * ``self.add_qdq_pair_to_weight`` is True: the float initializer is kept and the graph gets
      (weight_f32 -> Q -> DQ -> weight_dequant).

    Every existing consumer of the weight is rewired to read the DQ output. Weights already present in
    ``self.quantized_value_map`` are left untouched.
    """
    weight_name = weight_proto.name
    if weight_name in self.quantized_value_map:
        return

    quant_params: QuantizationParams = self.initializer_quant_params[weight_name]
    axis: int | None = quant_params.get("axis")
    scale_zp = self._make_scale_zp_initializers(weight_name, quant_params)
    scale_name = scale_zp.scale.name
    zp_name = scale_zp.zero_point.name
    weight_dequant_output = add_dequant_output_suffix(weight_name)

    # Rewire existing consumers before adding new nodes, so the Q node created below (which reads
    # `weight_name`) is not rewired as well.
    self.model.replace_input_of_all_nodes(weight_name, weight_dequant_output)

    q_weight_name: str | None = None
    if not self.add_qdq_pair_to_weight:
        # Quantize the weight now; only a DQ node is needed at runtime.
        quant_weight = quantize_onnx_initializer(
            weight_proto,
            quant_params["quant_type"],
            quant_params["zero_point"],
            quant_params["scale"],
            axis,
        )
        self.model.add_initializer(quant_weight)
        q_weight_name = quant_weight.name
        self.model.add_node(
            onnx.helper.make_node(
                DEQUANT_OP_NAME,
                [q_weight_name, scale_name, zp_name],
                [weight_dequant_output],
                add_dequant_suffix(weight_name),
                axis=axis,
                domain=self.qdq_op_domain,
            )
        )
    else:
        # Keep the float weight and fake-quantize it with a Q -> DQ pair.
        weight_quant_output = add_quant_output_suffix(weight_name)
        self._create_qdq_nodes(
            weight_name,
            weight_quant_output,
            add_quant_suffix(weight_name),
            weight_quant_output,
            weight_dequant_output,
            add_dequant_suffix(weight_name),
            scale_name,
            zp_name,
            axis,
        )

    # Log entry for this quantized weight (q_weight_name stays None when a Q/DQ pair was used).
    self.quantized_value_map[weight_name] = QDQTensorQuantizedValue(
        QuantizedValue(
            weight_name,
            q_weight_name,
            scale_name,
            zp_name,
            QuantizedValueType.Initializer,
            axis=axis,
        ),
        None,
        None,
    )
def _add_qdq_pair_for_activation(self, tensor_name, scale_name, zp_name, data_type=None):
    """
    Inserts Q -> DQ for an activation tensor and records it in `self.quantized_value_map`.

    When `self.dedicated_qdq_pair` is enabled and the tensor has more than one consumer, every consumer
    gets its own Q/DQ pair (names suffixed `_1`, `_2`, ...), and the map records the first pair's DQ output.
    Otherwise a single pair is created. For a graph output, the producer is renamed so that the DQ node
    re-produces the original tensor name; otherwise all consumers are rewired to the DQ output.

    :param tensor_name: Activation tensor to quantize.
    :param scale_name: Scale initializer name.
    :param zp_name: Zero-point initializer name.
    :param data_type: Stored as `scale_type` on the recorded QuantizedValue.
    """
    receivers = self.tensor_to_its_receiving_nodes.get(tensor_name, [])

    if self.dedicated_qdq_pair and len(receivers) > 1:
        base_q_out = add_quant_output_suffix(tensor_name)
        base_dq_out = add_dequant_output_suffix(tensor_name)
        base_q_node = add_quant_suffix(tensor_name)
        base_dq_node = add_dequant_suffix(tensor_name)
        for pair_idx, receiver in enumerate(receivers, start=1):
            suffix = f"_{pair_idx}"
            q_out = base_q_out + suffix
            dq_out = base_dq_out + suffix
            self._create_qdq_nodes(
                tensor_name,
                q_out,
                base_q_node + suffix,
                q_out,
                dq_out,
                base_dq_node + suffix,
                scale_name,
                zp_name,
            )
            # Only this consumer reads the pair just created.
            self.model.replace_node_input(receiver, tensor_name, dq_out)
            if pair_idx == 1:
                first_value = QuantizedValue(
                    tensor_name,
                    dq_out,
                    scale_name,
                    zp_name,
                    QuantizedValueType.Input,
                    scale_type=data_type,
                )
                self.quantized_value_map[tensor_name] = QDQTensorQuantizedValue(first_value, None, None)
        return

    if self.model.is_graph_output(tensor_name):
        # Keep the graph output name on the DQ output; the producer now writes a renamed tensor.
        q_in = add_quant_input_suffix(tensor_name)
        dq_out = tensor_name
        self.model.replace_output_of_all_nodes(tensor_name, q_in)
    else:
        q_in = tensor_name
        dq_out = add_dequant_output_suffix(tensor_name)
        self.model.replace_input_of_all_nodes(tensor_name, dq_out)

    q_out = add_quant_output_suffix(tensor_name)
    self._create_qdq_nodes(
        q_in,
        q_out,
        add_quant_suffix(tensor_name),
        q_out,
        dq_out,
        add_dequant_suffix(tensor_name),
        scale_name,
        zp_name,
    )
    value = QuantizedValue(
        tensor_name,
        dq_out,
        scale_name,
        zp_name,
        QuantizedValueType.Input,
        scale_type=data_type,
    )
    self.quantized_value_map[tensor_name] = QDQTensorQuantizedValue(value, None, None)
def _add_qdq_ops_for_converted_activation(
    self,
    tensor_name,
    first_scale_name,
    first_zp_name,
    scale_data_type,
    convert_scale_name,
    convert_zp_name,
    convert_recv_nodes,
):
    """
    Adds Q and DQ ops to a tensor whose quantized data type is converted. That is, some consumers may use the
    original data type from the producer, while other consumers use the converted data type.
    This is generally done by adding a sequence of ops that convert from one data type (e.g., uint8) to another (e.g., uint16).

    T_float ---> Quant(to u8) ---> Convert(to u16) ---> Dequant(to float) ---> T_float'
    where Convert(to u16) is equivalent to: ---> Dequant(to float) ---> Quant(to u16) --->

    This function handles the following scenarios:

    1) Tensor T is not a graph output; all consumers use the converted type

        <Producer> ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 ---> <Consumers>

    2) Tensor T is not a graph output; some consumers use the original type, others use the converted type

        <Producer> ---> Q1 -+-> DQ1 ---> <Consumers of original type>
                            |
                            +-> DQ1' ---> Q2 ---> DQ2 ---> <Consumers of converted type>

    3) Tensor T is a graph output; all consumers use the converted type

        <Producer> ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 -+-> <Consumers>
                                                      |
                                                      +-> <Graph output>

    4) Tensor T is a graph output; some consumers use the original type, others use the converted type

        <Producer> ---> Q1 -+-> DQ1 -+-> <Consumers of original type>
                            |        |
                            |        +-> <Graph output>
                            |
                            +-> DQ1' ---> Q2 ---> DQ2 ---> <Consumers of converted type>

    5) Tensor T is a graph output that is not consumed by any other nodes.

        <Producer> ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 ---> <Graph output>

    :param tensor_name: The float tensor being quantized.
    :param first_scale_name: Scale initializer for the original quantization type (used by Q1, DQ1 and DQ1').
    :param first_zp_name: Zero-point initializer for the original quantization type.
    :param scale_data_type: Stored as `scale_type` on both recorded QuantizedValue entries.
    :param convert_scale_name: Scale initializer for the converted quantization type (used by Q2 and DQ2).
    :param convert_zp_name: Zero-point initializer for the converted quantization type.
    :param convert_recv_nodes: Names of consumer nodes that receive the converted type, or None to mean
        that every consumer receives the converted type.
    :raises ValueError: if `self.dedicated_qdq_pair` is enabled and the tensor has more than one consumer.
    """
    # Names of all nodes that currently consume `tensor_name`.
    tensor_recv_nodes = {node.name for node in self.tensor_to_its_receiving_nodes.get(tensor_name, [])}

    if (
        self.dedicated_qdq_pair
        and tensor_name in self.tensor_to_its_receiving_nodes
        and len(self.tensor_to_its_receiving_nodes[tensor_name]) > 1
    ):
        # TODO: Add support for dedicated_qdq_pair if/when needed.
        raise ValueError(
            "Do not currently support converted quant_types in TensorQuantOverrides when the `dedicated_qdq_pair` extra_option is enabled"
        )

    # Determine which nodes consume the original quantized type and which nodes
    # consume the converted quantized type.
    original_recv_nodes = tensor_recv_nodes
    if convert_recv_nodes is None:  # In this case, all consumers receive the converted type.
        convert_recv_nodes = tensor_recv_nodes
        original_recv_nodes = set()
    else:
        original_recv_nodes = original_recv_nodes - convert_recv_nodes

    # NOTE(review): this compares set sizes only, so it assumes every name in convert_recv_nodes is an
    # actual consumer of the tensor; confirm that override validation guarantees this.
    all_use_converted = len(convert_recv_nodes) == len(tensor_recv_nodes)
    is_graph_output = self.model.is_graph_output(tensor_name)

    # Create first Q op.
    first_q_input = tensor_name
    if is_graph_output:
        # Rename the producer's output so a DQ node further down can re-produce the graph-output name.
        first_q_input = add_quant_input_suffix(tensor_name)
        self.model.replace_output_of_all_nodes(tensor_name, first_q_input)

    first_q_output = add_quant_output_suffix(tensor_name)
    self._create_q_node(
        first_q_input, first_q_output, add_quant_suffix(tensor_name), first_scale_name, first_zp_name
    )

    # Create first DQ op.
    first_dq_output = add_dequant_output_suffix(tensor_name)
    if is_graph_output and not all_use_converted:
        # Scenario 4: DQ1 re-produces the graph output in the original type.
        first_dq_output = tensor_name
    if original_recv_nodes and first_dq_output != tensor_name:
        # Point consumers of the original type at DQ1 (unnecessary when DQ1 already outputs `tensor_name`).
        self.model.replace_input_of_nodes(tensor_name, first_dq_output, original_recv_nodes)

    self._create_dq_node(
        first_q_output, first_dq_output, add_dequant_suffix(tensor_name), first_scale_name, first_zp_name
    )

    # Create parallel clone of first DQ op if _not all_ consumers use the converted type.
    # --> DQ1' --> Q2 --> DQ2 --> <Consumers of converted type>
    #
    # This DQ clone would only have one consumer Q node (Q2) and could be potentially fused with
    # it by some EPs (e.g., QNN) without breaking other "node units".
    # Ex QNN fusion:
    # --> Convert (fused) --> DQ2 --> <Consumers of converted type>
    second_q_input = first_dq_output
    if not all_use_converted:
        second_q_input = add_quant_input_suffix(f"{tensor_name}_convert")
        self._create_dq_node(
            first_q_output,
            second_q_input,
            add_dequant_suffix(f"{tensor_name}_convert_clone"),
            first_scale_name,
            first_zp_name,
        )

    # Create second Q op.
    second_q_output = add_quant_output_suffix(f"{tensor_name}_convert")
    self._create_q_node(
        second_q_input,
        second_q_output,
        add_quant_suffix(f"{tensor_name}_convert"),
        convert_scale_name,
        convert_zp_name,
    )

    # Create second DQ op.
    second_dq_output = add_dequant_output_suffix(f"{tensor_name}_convert")
    if is_graph_output and all_use_converted:
        # Scenarios 3 and 5: DQ2 re-produces the graph output in the converted type.
        second_dq_output = tensor_name
    if convert_recv_nodes and second_dq_output != tensor_name:
        # Point consumers of the converted type at DQ2.
        self.model.replace_input_of_nodes(tensor_name, second_dq_output, convert_recv_nodes)

    self._create_dq_node(
        second_q_output,
        second_dq_output,
        add_dequant_suffix(f"{tensor_name}_convert"),
        convert_scale_name,
        convert_zp_name,
    )

    # Store in quantized_value_map
    original_quantized_value = QuantizedValue(
        tensor_name,
        first_dq_output,
        first_scale_name,
        first_zp_name,
        QuantizedValueType.Input,
        scale_type=scale_data_type,
    )
    converted_quantized_value = QuantizedValue(
        tensor_name,
        second_dq_output,
        convert_scale_name,
        convert_zp_name,
        QuantizedValueType.Input,
        scale_type=scale_data_type,
    )
    # The consumer-name set lets later lookups pick the original or converted value per consumer.
    self.quantized_value_map[tensor_name] = QDQTensorQuantizedValue(
        original_quantized_value, converted_quantized_value, convert_recv_nodes
    )
def _quantize_normal_tensors(self):
    """
    Adds Q/DQ ops to tensors (activations and weights) marked for quantization by the op quantizers.
    Tensors that share quantization parameters with an upstream tensor are skipped here.

    Initializers go through `_add_qdq_nodes_for_initializer`. Activations get their own scale/zero-point
    initializers and either a plain Q -> DQ pair or, when a converted quantization type is configured, the
    Q1/DQ1/Q2/DQ2 conversion sequence. A tensor already produced by a DequantizeLinear node in the input
    model (pre-quantized by another tool) gets no new ops.

    Every tensor handled here, including skipped pre-quantized ones, is removed from
    `self.tensors_to_quantize`. Shared tensors, and tensors already present in `self.quantized_value_map`,
    are left in that dict.

    :raises ValueError: if an activation has no quantization parameters.
    """
    for tensor_name, tensor_info in list(self.tensors_to_quantize.items()):
        if tensor_name in self.quantized_value_map or tensor_info.is_shared:
            continue

        weight_init = find_by_name(tensor_name, self.model.initializer())
        if weight_init:
            self._add_qdq_nodes_for_initializer(weight_init)
        elif tensor_name in self.tensor_to_producing_dq:
            # Already the output of (quantized_weight -> DequantizeLinear -> this_tensor) in the input model.
            pass
        else:
            qparam_inits = self._make_tensor_scale_zp_initializers(tensor_name)
            if not qparam_inits:
                raise ValueError(
                    f"Quantization parameters are not specified for param {tensor_name}. "
                    "In static mode quantization params for inputs and outputs of nodes to be quantized are required."
                )

            orig_inits = qparam_inits.original
            if qparam_inits.converted is not None:
                # Conversion case: <producer> ---> Q1 -+-> DQ1 --> <consumers of original type>
                #                                      |
                #                                      +-> DQ1' --> Q2 --> DQ2 --> <consumers of converted type>
                assert tensor_info.data_type == orig_inits.scale.data_type
                converted_inits = qparam_inits.converted
                self._add_qdq_ops_for_converted_activation(
                    tensor_name,
                    orig_inits.scale.name,
                    orig_inits.zero_point.name,
                    tensor_info.data_type,
                    converted_inits.scale.name,
                    converted_inits.zero_point.name,
                    qparam_inits.converted_recv_nodes,
                )
            else:
                # Normal case: <producer> --> Q --> DQ --> <consumers>
                self._add_qdq_pair_for_activation(
                    tensor_name,
                    orig_inits.scale.name,
                    orig_inits.zero_point.name,
                    data_type=tensor_info.data_type,
                )

        del self.tensors_to_quantize[tensor_name]
def _quantize_sharing_param_tensors(self):
    """
    Adds Q/DQ ops to tensors that have been marked for quantization by op quantizers.

    Only operates on tensors that want to use the quantization parameter initializers from an upstream tensor.
    For example, a Transpose node's output tensor will typically want to use the same quantization parameter
    initializers as the Transpose node's input.

    Tensors are processed in passes. A tensor is handled as soon as its provider tensor
    (`quant_para_provider.input_name`) is in `self.quantized_value_map`. A provider may itself be a sharing
    tensor, so several passes can be needed. If a whole pass makes no progress, the remaining tensors can
    never be resolved, and an error is raised instead of looping forever.

    :raises ValueError: if a sharing tensor is an initializer or was already quantized in the input model,
        or if some tensors have no provider or their provider is never quantized.
    """
    while self.tensors_to_quantize:
        made_progress = False
        for tensor_name, tensor_info in self.tensors_to_quantize.copy().items():
            quant_provider = tensor_info.quant_para_provider
            if not (quant_provider and quant_provider.input_name in self.quantized_value_map):
                continue  # Provider not quantized yet; retry on a later pass.

            made_progress = True
            del self.tensors_to_quantize[tensor_name]

            quantized_value = self.quantized_value_map[quant_provider.input_name].get_for_consumer(
                quant_provider.node_name
            )
            if self.is_input_a_initializer(tensor_name):
                raise ValueError("Quantization parameter shared mode is not supported for weight yet")

            if tensor_name in self.tensor_to_producing_dq:
                raise ValueError(
                    f"Quantization parameter sharing is invalid for tensor {tensor_name} "
                    "because it has already been quantized"
                )

            # Need to check if this tensor's quant_type is converted for some consumers.
            # If so, create new scale/zp initializers for these consumers.
            converted_qparam_inits = None
            converted_recv_nodes = None
            if tensor_name in self.quantization_params:
                tensor_params = self.quantization_params[tensor_name]
                if tensor_params.converted:
                    converted_qparam_inits = self._make_scale_zp_initializers(
                        tensor_name, tensor_params.converted, "_convert"
                    )
                    converted_recv_nodes = tensor_params.converted_recv_nodes

            if converted_qparam_inits is None:
                # Normal case: <producer> --> Q_shared --> DQ_shared --> <consumers>
                self._add_qdq_pair_for_activation(tensor_name, quantized_value.scale_name, quantized_value.zp_name)
            else:
                # Conversion case: <producer> ---> Q_shared -+-> DQ_shared --> <consumers of original type>
                #                                            |
                #                                            +-> DQ_shared' --> Q2 --> DQ2 --> <consumers of converted type>
                self._add_qdq_ops_for_converted_activation(
                    tensor_name,
                    quantized_value.scale_name,
                    quantized_value.zp_name,
                    converted_qparam_inits.scale.data_type,
                    converted_qparam_inits.scale.name,
                    converted_qparam_inits.zero_point.name,
                    converted_recv_nodes,
                )

        if not made_progress:
            # Nothing changed during this pass, so every later pass would be identical: fail loudly.
            unresolved = ", ".join(sorted(self.tensors_to_quantize))
            raise ValueError(
                f"Unable to add Q/DQ ops for tensor(s) {unresolved} because their quantization parameter "
                "provider was never quantized (or no provider was set)."
            )
def _quantize_bias_tensors(self):
    """
    Adds a DequantizeLinear node (or a Cast node) that produces each bias marked for quantization by the
    op quantizers.

    Each bias not yet in `self.quantized_value_map` is quantized via `quantize_bias_static`, and its float
    initializer is removed from the model. A node whose output keeps the original bias name is then added,
    so existing consumers need no rewiring.

    :raises TypeError: if a Cast-type bias initializer reports a non-int data type.
    :raises RuntimeError: if a DequantizeLinear-type bias has a float quantized type, or the node type is unknown.
    """
    for bias_name, bias_info in self.bias_to_quantize.items():
        if bias_name in self.quantized_value_map:
            continue

        self.quantize_bias_static(bias_name, bias_info)
        float_init = find_by_name(bias_name, self.model.initializer())
        self.model.remove_initializer(float_init)
        qvalue = self.quantized_value_map[bias_name].original
        node_name = add_dequant_suffix(bias_name)

        if qvalue.node_type == "Cast":
            # Plain cast back to the bias' float type rather than DequantizeLinear:
            # cublasLtMatmul only supports (b)float16, float bias.
            if not isinstance(float_init.data_type, int):
                raise TypeError(f"Unexpected type {type(float_init.data_type)} for input={bias_info.input_name!r}")
            bias_node = onnx.helper.make_node(
                "Cast",
                [qvalue.q_name],
                [bias_name],
                name=node_name,
                to=float_init.data_type,
            )
        elif qvalue.node_type in (None, "DequantizeLinear"):
            float_qtypes = {
                onnx.TensorProto.FLOAT16,
                onnx.TensorProto.BFLOAT16,
                onnx.TensorProto.FLOAT,
            }
            if qvalue.node_qtype in float_qtypes:
                raise RuntimeError(f"Unexpected quantize type {qvalue.node_qtype} for DequantizeLinear.")
            # onnx.helper.make_node drops attributes whose value is None, so a per-tensor bias
            # (axis is None) gets no `axis` attribute, exactly as if it had been omitted.
            bias_node = onnx.helper.make_node(
                "DequantizeLinear",
                [qvalue.q_name, qvalue.scale_name, qvalue.zp_name],
                [bias_name],
                node_name,
                axis=qvalue.axis,
                domain=self.qdq_op_domain,
            )
        else:
            raise RuntimeError(f"Unexpected operator type {qvalue.node_type!r}.")

        self.model.add_node(bias_node)
def is_tensor_quantized(self, tensor_name: str):
    """
    Returns True if `tensor_name` is scheduled for quantization, as a regular tensor or as a bias.
    """
    if tensor_name in self.tensors_to_quantize:
        return True
    return tensor_name in self.bias_to_quantize
def is_tensor_per_channel(
    self,
    tensor_name: str,
    default_axis: int,
    op_type: str | None = None,
) -> tuple[bool, int | None]:
    """
    Checks if a given tensor is configured to be quantized per-channel. If so, also returns the channel axis.

    ORT only supports per-channel quantization on static weights (i.e., ONNX initializers). If the user did not provide
    tensor quantization overrides for this tensor, then the value of self.per_channel determines if the weight
    is to be quantized per-channel.

    Params:
        tensor_name: The name of the tensor to check.
        default_axis: The default channel axis. This method checks if the normalized axis is within bounds.
                      Can be overridden via the extra_options 'QDQOpTypePerChannelSupportToAxis'
                      and 'TensorQuantOverrides'.
        op_type: Optional, defaults to None. The operator type that is the only consumer of this weight.
                 Used to access the extra option 'QDQOpTypePerChannelSupportToAxis'.
    Returns:
        A tuple (is_per_channel, axis) in which the first element indicates whether the tensor is
        quantized per-channel and the second element is the (normalized) channel axis.
        The returned axis is only None if the tensor is not per-channel or the axis is out of bounds.
        An out-of-bounds axis is logged as a warning using the axis value that was requested.
    """
    weight_initializer = self.initializers.get(tensor_name)
    if weight_initializer is None:
        return False, None  # Only support per-channel weights

    if self.tensor_quant_overrides.has_per_tensor_overrides(tensor_name):
        return False, None  # User provided per-tensor overrides for this initializer

    has_per_chan_overrides = self.tensor_quant_overrides.has_per_channel_overrides(tensor_name)
    if not self.per_channel and not has_per_chan_overrides:
        return False, None  # global self.per_channel is off and user did not provide per-channel overrides.

    axis = self.qdq_op_type_per_channel_support_to_axis.get(op_type, default_axis) if op_type else default_axis
    if has_per_chan_overrides:
        per_chan_overrides = self.tensor_quant_overrides.get_per_channel_overrides(tensor_name)
        axis = per_chan_overrides[0]["axis"]  # Prefer axis from user-specified tensor-level overrides if available

    weight_rank = len(weight_initializer.dims)
    axis_valid, normalized_axis = normalize_axis(axis, weight_rank)
    if not axis_valid:
        # Report the axis as requested (not the normalize_axis() result) so the message matches the
        # user's configuration.
        logging.warning("Axis %s is out-of-range for weight '%s' with rank %s", axis, tensor_name, weight_rank)
        return False, None

    return True, normalized_axis
def _get_tensor_quantization_scale(self, tensor_name: str, consumer_node_name: str) -> np.ndarray | None:
    """
    Looks up the quantization scale that the given consumer node sees for a tensor.

    :parameter tensor_name: The name of the tensor.
    :parameter consumer_node_name: The name of the node that consumes the tensor as input. This matters when
                                   the tensor's quantization type was converted for some consumers
                                   (see QDQQuantizer::_add_qdq_ops_for_converted_activation).
    :returns: The scale as a numpy array, or None if no scale initializer can be found.
    """
    if tensor_name in self.quantized_value_map:
        # Quantized during this run: the scale initializer was created by this tool.
        per_consumer_value = self.quantized_value_map[tensor_name].get_for_consumer(consumer_node_name)
        scale_name = per_consumer_value.scale_name
    else:
        # Quantized in the original model: the scale is input #1 of the DQ node that outputs the tensor.
        producing_dq = self.tensor_to_producing_dq.get(tensor_name, None)
        if not producing_dq:
            return None
        scale_name = producing_dq.input[1]

    scale_init = find_by_name(scale_name, self.model.initializer())
    if scale_init is None:
        return None
    return tensor_proto_to_array(scale_init)
def quantize_bias_static(self, bias_name: str, bias_info: QDQBiasQuantInfo) -> str:
    """
    Quantizes a bias with zero point 0 and scale = input_scale * weight_scale.

    The weight and input scales are read for the consumer `bias_info.node_name`. The actual quantization is
    delegated to `quantize_bias_static_impl`, and the result is recorded in `self.quantized_value_map`
    (per-channel axis 0 when the bias scale has more than one element).

    :returns: The name of the quantized bias (the existing one if the bias was already quantized).
    :raises ValueError: if the weight scale or the input scale cannot be found.
    """
    if bias_name in self.quantized_value_map:
        return self.quantized_value_map[bias_name].original.q_name

    # Fetch (and validate) the weight scale first, then the input scale.
    scale_by_role = {}
    for role, source_name, description in (
        ("weight", bias_info.weight_name, "weight input"),
        ("input", bias_info.input_name, "input"),
    ):
        found_scale = self._get_tensor_quantization_scale(source_name, bias_info.node_name)
        if found_scale is None:
            raise ValueError(
                f"Unable to get valid quantization scale for {description} '{source_name}' "
                f"when quantizing bias '{bias_name}' to int32."
            )
        scale_by_role[role] = found_scale

    impl_result = self.quantize_bias_static_impl(
        bias_name, scale_by_role["input"], scale_by_role["weight"], bias_info.beta
    )
    q_name, q_scale_name, q_zp_name, bias_scale_data, node_type, node_qtype = impl_result

    per_channel_axis = 0 if bias_scale_data.size > 1 else None
    bias_value = QuantizedValue(
        bias_name,
        q_name,
        q_scale_name,
        q_zp_name,
        QuantizedValueType.Initializer,
        per_channel_axis,
        node_type=node_type,
        node_qtype=node_qtype,
    )
    self.quantized_value_map[bias_name] = QDQTensorQuantizedValue(bias_value, None, None)
    return q_name
def _make_scale_zp_initializers(
    self, param_name: str, quant_params: QuantizationParams, init_name_suffix: str = ""
) -> QDQScaleZpInitializers:
    """
    Adds a zero-point initializer and a scale initializer for `quant_params` to the model and returns both.

    The initializers are named:
      - {param_name}_zero_point{init_name_suffix}
      - {param_name}_scale{init_name_suffix}

    The zero point is stored with the params' "quant_type". The scale is stored as FLOAT or FLOAT16
    depending on its numpy dtype. The scale must be 1-D when an "axis" is present and 0-D otherwise, and
    the zero point must have the same rank as the scale.

    :raises ValueError: if the scale dtype is neither float32 nor float16. The zero-point initializer has
        already been added to the model at that point.
    """
    zp_values = quant_params["zero_point"]
    scale_values = quant_params["scale"]
    zp_elem_type = quant_params["quant_type"]
    axis: int | None = quant_params.get("axis")

    expected_scale_rank = 0 if axis is None else 1
    assert len(scale_values.shape) == expected_scale_rank, "Wrong scale/zp shapes"
    assert len(scale_values.shape) == len(zp_values.shape), "Scale and zero-point must have the same rank"

    zp_init = onnx.helper.make_tensor(
        param_name + "_zero_point" + init_name_suffix,
        zp_elem_type,
        zp_values.shape,
        zp_values.ravel().tolist(),
    )
    self.model.add_initializer(zp_init)

    supported_scale_types = (
        (np.float32, onnx_proto.TensorProto.FLOAT),
        (np.float16, onnx_proto.TensorProto.FLOAT16),
    )
    for np_type, onnx_type in supported_scale_types:
        if scale_values.dtype == np_type:
            scale_elem_type = onnx_type
            break
    else:
        raise ValueError(f"Unexpected dtype={scale_values.dtype} for param_name={param_name!r}")

    scale_init = onnx.helper.make_tensor(
        param_name + "_scale" + init_name_suffix,
        scale_elem_type,
        scale_values.shape,
        scale_values.ravel().tolist(),
    )
    self.model.add_initializer(scale_init)

    return QDQScaleZpInitializers(scale_init, zp_init)
def _make_tensor_scale_zp_initializers(self, tensor_name: str) -> QDQTensorScaleZpInitializers | None:
    """
    Creates every scale/zero-point initializer pair needed by a tensor and returns them.

    One pair is always created for the original quantization type. A second pair (suffix "_convert") is
    created only when the tensor's params define a converted quantization type.

    :returns: None (and logs at INFO level) if no quantization parameters exist for the tensor.
    :raises TypeError: if the stored parameters are not a QDQTensorQuantParams instance.
    """
    params_by_tensor = self.quantization_params
    if params_by_tensor is None or tensor_name not in params_by_tensor:
        logging.info(f'Quantization parameters for tensor:"{tensor_name}" not specified')
        return None

    tensor_params = params_by_tensor[tensor_name]
    if not isinstance(tensor_params, QDQTensorQuantParams):
        raise TypeError(f"Unexpected type {type(tensor_params)} for {tensor_name!r}.")

    original_inits = self._make_scale_zp_initializers(tensor_name, tensor_params.original)
    converted_inits = None
    if tensor_params.converted:
        converted_inits = self._make_scale_zp_initializers(tensor_name, tensor_params.converted, "_convert")

    return QDQTensorScaleZpInitializers(original_inits, converted_inits, tensor_params.converted_recv_nodes)
def calc_quant_params(self, tensor_data: TensorData, quant_overrides: dict[str, Any]) -> QuantizationParams:
    """
    Computes scale/zero-point (and quant type) for one tensor from its calibration data and optional
    user-provided overrides.

    The quant type is `self.activation_qType` unless "quant_type" is overridden. Then:
      1. if both "scale" and "zero_point" are overridden, they are used verbatim;
      2. else for FLOAT8E4M3FN, `compute_scale_zp_float8` is applied to `tensor_data.avg_std[1]`;
      3. else `compute_scale_zp` is applied over [rmin, rmax] (the calibrated range unless overridden),
         using qmin/qmax for the quant type and the (possibly overridden) symmetric/reduce_range flags.
    Scale and zero point are squeezed before being returned.
    """
    if "quant_type" in quant_overrides:
        quant_type = quant_overrides["quant_type"].tensor_type
    else:
        quant_type = self.activation_qType

    has_explicit_scale_zp = "scale" in quant_overrides and "zero_point" in quant_overrides
    if has_explicit_scale_zp:
        zero_point, scale = quant_overrides["zero_point"], quant_overrides["scale"]
    elif quant_type == onnx.TensorProto.FLOAT8E4M3FN:
        zero_point, scale = compute_scale_zp_float8(quant_type, tensor_data.avg_std[1])
    else:
        range_min = quant_overrides.get("rmin", tensor_data.range_value[0])
        range_max = quant_overrides.get("rmax", tensor_data.range_value[1])
        use_symmetric = quant_overrides.get("symmetric", self.is_activation_symmetric)
        use_reduced_range = quant_overrides.get("reduce_range", False)
        qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=use_reduced_range, symmetric=use_symmetric)
        zero_point, scale = compute_scale_zp(range_min, range_max, qmin, qmax, use_symmetric, self.min_real_range)

    return QuantizationParams(zero_point=zero_point.squeeze(), scale=scale.squeeze(), quant_type=quant_type)
def calc_graph_quant_params(self) -> dict[str, QDQTensorQuantParams]:
    """
    Computes quantization parameters for every tensor that has calibration range data.

    Returns an empty dict when `self.tensors_range` is None. Otherwise the ranges are first adjusted via
    `self.adjust_tensor_ranges()`. Each tensor's per-tensor overrides are then applied when computing its
    "original" parameters. If those overrides contain a "convert" entry, a second parameter set is computed
    from it, together with its optional "recv_nodes".

    :raises TypeError: if a range entry is not a TensorData instance.
    """
    if self.tensors_range is None:
        return {}

    self.adjust_tensor_ranges()

    params_by_tensor = {}
    for name in self.tensors_range:
        range_data = self.tensors_range[name]
        if not isinstance(range_data, TensorData):
            raise TypeError(f"Unexpected type {type(range_data)} for {name!r}.")

        overrides = self.tensor_quant_overrides.get_per_tensor_overrides(name, default_val={})
        base_params = self.calc_quant_params(range_data, overrides)

        if "convert" in overrides:
            convert_overrides = overrides["convert"]
            convert_params = self.calc_quant_params(range_data, convert_overrides)
            convert_recv_nodes = convert_overrides.get("recv_nodes")
        else:
            convert_params = None
            convert_recv_nodes = None

        params_by_tensor[name] = QDQTensorQuantParams(base_params, convert_params, convert_recv_nodes)

    return params_by_tensor
def _calc_initializer_quant_params(self) -> dict[str, QuantizationParams]:
    """
    Returns quantization parameters (scale/zero_point/quant_type, plus axis when per-channel) for every
    tensor in `self.tensors_to_quantize` that is an initializer. Non-initializer tensors are skipped.

    The quant type defaults to `self.weight_qType` for weights and `self.activation_qType` otherwise. It may
    be overridden per tensor through TensorQuantOverrides. If the overrides supply scale/zero_point directly,
    those values are used as-is: per-tensor, or per-channel when the first override entry has an "axis".
    Otherwise scale/zero_point are computed from the initializer data, honoring any rmin/rmax/symmetric/
    reduce_range overrides. Per-channel computation is used when an axis comes from the overrides or from
    the tensor info.

    :raises ValueError: if a per-channel axis is out of bounds for the initializer's rank.
    """
    quantization_params: dict[str, QuantizationParams] = {}
    for tensor_name, tensor_info in self.tensors_to_quantize.items():
        initializer = find_by_name(tensor_name, self.model.initializer())
        if not initializer:
            continue  # Only initializers are handled here.

        initializer_data = tensor_proto_to_array(initializer)
        initializer_rank = len(initializer_data.shape)

        # initializers for elementwise ops use the quant_type for activations.
        is_weight = tensor_info.tensor_type is QDQQuantTensorType.WEIGHT
        quant_type = self.weight_qType if is_weight else self.activation_qType

        # Try to get scale/zp directly from user's overrides and avoid computation.
        if self.tensor_quant_overrides.overrides_scale_zp(tensor_name):
            overrides = self.tensor_quant_overrides[tensor_name]
            if "quant_type" in overrides[0]:
                quant_type = overrides[0]["quant_type"].tensor_type

            zp_dtype = ONNX_TYPE_TO_NP_TYPE[quant_type]
            # An "axis" in the first override entry marks per-channel overrides (one entry per channel).
            is_per_channel = "axis" in overrides[0]
            if not is_per_channel:
                quantization_params[tensor_name] = QuantizationParams(
                    zero_point=np.array(overrides[0]["zero_point"], dtype=zp_dtype),
                    scale=np.array(overrides[0]["scale"], initializer_data.dtype),
                    quant_type=quant_type,
                )
            else:
                # NOTE(review): the number of override entries is not checked against the channel count
                # along the axis; confirm that override validation guarantees they match.
                zero_points_list = []
                scales_list = []
                for chan_overrides in overrides:
                    zero_points_list.append(np.array(chan_overrides["zero_point"], zp_dtype))
                    scales_list.append(np.array(chan_overrides["scale"], dtype=initializer_data.dtype))

                channel_axis = overrides[0]["axis"]
                is_axis_valid, norm_channel_axis = normalize_axis(channel_axis, initializer_rank)
                if not is_axis_valid:
                    raise ValueError(
                        f"Weight {initializer.name} has a per-channel axis with value {channel_axis} that is "
                        f"out-of-bounds for rank {initializer_rank}"
                    )

                quantization_params[tensor_name] = QuantizationParams(
                    zero_point=np.array(zero_points_list),
                    scale=np.array(scales_list),
                    quant_type=quant_type,
                    axis=norm_channel_axis,
                )

            continue

        # Compute scale/zp normally. User's overrides may still override parameters
        # used to compute the scale/zp (e.g., rmin, rmax, symmetric, etc.)
        overrides = self.tensor_quant_overrides.get(tensor_name, [{}])
        if "quant_type" in overrides[0]:
            quant_type = overrides[0]["quant_type"].tensor_type

        # The override axis wins over the axis chosen by the op quantizer; None means per-tensor.
        channel_axis = overrides[0].get("axis", tensor_info.axis)
        is_per_channel = channel_axis is not None

        # Note: always quantize per-channel initializers as symmetric because QLinear* ops require the
        # same zero-point in every channel, which is necessarily the case for symmetric quantization.
        is_symmetric_default = is_per_channel or (
            self.is_weight_symmetric(quant_type) if is_weight else self.is_activation_symmetric
        )
        is_symmetric = overrides[0].get("symmetric", is_symmetric_default)
        reduce_range = overrides[0].get("reduce_range", self.reduce_range)
        zero_point: np.ndarray | None = None
        scale: np.ndarray | None = None

        if not is_per_channel:
            # Per-tensor: one scale/zp over the flattened data.
            zero_point, scale = compute_data_quant_params(
                initializer_data.flatten(),
                quant_type,
                is_symmetric,
                reduce_range=reduce_range,
                min_real_range=self.min_real_range,
                rmin_override=overrides[0].get("rmin"),
                rmax_override=overrides[0].get("rmax"),
            )
        else:
            is_axis_valid, norm_channel_axis = normalize_axis(channel_axis, initializer_rank)
            if not is_axis_valid:
                raise ValueError(
                    f"Weight {initializer.name} has a per-channel axis with value {channel_axis} that is "
                    f"out-of-bounds for rank {initializer_rank}"
                )

            channel_axis = norm_channel_axis
            channel_count = initializer_data.shape[channel_axis]
            zero_points_list = []
            scales_list = []
            for i in range(channel_count):
                # Slice i along the channel axis; channel i uses override entry i if one exists.
                per_channel_data = initializer_data.take(i, channel_axis)
                channel_overrides = overrides[i] if overrides and i < len(overrides) else {}
                channel_zero_point, channel_scale = compute_data_quant_params(
                    per_channel_data.ravel(),
                    quant_type,
                    is_symmetric,
                    reduce_range=reduce_range,
                    min_real_range=self.min_real_range,
                    rmin_override=channel_overrides.get("rmin"),
                    rmax_override=channel_overrides.get("rmax"),
                )
                zero_points_list.append(channel_zero_point)
                scales_list.append(channel_scale)

            zero_point = np.asarray(zero_points_list)
            scale = np.asarray(scales_list)

        quantization_params[tensor_name] = QuantizationParams(
            zero_point=zero_point,
            scale=scale,
            quant_type=quant_type,
            axis=channel_axis,
        )

    return quantization_params
|