quant_utils.py 39 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. from __future__ import annotations
  7. import copy
  8. import logging
  9. import os
  10. import tempfile
  11. from enum import Enum
  12. from pathlib import Path
  13. import numpy
  14. import onnx
  15. from ml_dtypes import float8_e4m3fn, int4, uint4
  16. from onnx import ModelProto, TensorProto, external_data_helper
  17. from onnx import onnx_pb as onnx_proto
  18. from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
  19. from onnx.reference import ReferenceEvaluator
  20. from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
  21. try:
  22. from onnx.reference.op_run import to_array_extended
  23. except ImportError:
  24. # old version of onnx.
  25. to_array_extended = None
  26. __producer__ = "onnx.quantize"
  27. __version__ = "0.1.0"
  28. onnx_domain = "ai.onnx"
  29. ms_domain = "com.microsoft"
  30. QUANT_OP_NAME = "QuantizeLinear"
  31. QUANT_INPUT_SUFFIX = "_QuantizeLinear_Input"
  32. DEQUANT_OP_NAME = "DequantizeLinear"
  33. DEQUANT_OUTPUT_SUFFIX = "_DequantizeLinear_Output"
  34. TENSOR_NAME_QUANT_SUFFIX = "_quantized"
  35. MODEL_SIZE_THRESHOLD = 2147483648 # Quant model should use external data if >= 2GB
  36. FLOAT8_DISTRIBUTIONS = {}
  37. type_to_name = {getattr(TensorProto, k): k for k in dir(TensorProto) if isinstance(getattr(TensorProto, k), int)}
  38. # Quantization mode
  39. # IntegerOps: Use IntegerOps in quantized model. Only ConvInteger and MatMulInteger ops are supported now.
  40. # QLinearOps: Use QLinearOps in quantized model. Only QLinearConv and QLinearMatMul ops are supported now.
  41. class QuantizationMode(Enum):
  42. IntegerOps = 0
  43. QLinearOps = 1
  44. def __str__(self):
  45. return self.name
  46. @staticmethod
  47. def from_string(mode):
  48. try:
  49. return QuantizationMode[mode]
  50. except KeyError:
  51. raise ValueError() # noqa: B904
  52. class QuantizedValueType(Enum):
  53. Input = 0
  54. Initializer = 1
  55. def __str__(self):
  56. return self.name
  57. @staticmethod
  58. def from_string(v):
  59. try:
  60. return QuantizedValueType[v]
  61. except KeyError:
  62. raise ValueError() # noqa: B904
  63. class QuantType(Enum):
  64. QInt8 = 0
  65. QUInt8 = 1
  66. QFLOAT8E4M3FN = 2
  67. QInt16 = 3
  68. QUInt16 = 4
  69. QInt4 = 5
  70. QUInt4 = 6
  71. def __str__(self):
  72. return self.name
  73. @staticmethod
  74. def from_string(t):
  75. try:
  76. return QuantType[t]
  77. except KeyError:
  78. raise ValueError() # noqa: B904
  79. @property
  80. def tensor_type(self):
  81. if self == QuantType.QInt8:
  82. return TensorProto.INT8
  83. if self == QuantType.QUInt8:
  84. return TensorProto.UINT8
  85. if self == QuantType.QUInt16:
  86. return TensorProto.UINT16
  87. if self == QuantType.QInt16:
  88. return TensorProto.INT16
  89. if self == QuantType.QFLOAT8E4M3FN:
  90. return TensorProto.FLOAT8E4M3FN
  91. if self == QuantType.QUInt4:
  92. return TensorProto.UINT4
  93. if self == QuantType.QInt4:
  94. return TensorProto.INT4
  95. raise ValueError(f"Unexpected value qtype={self!r}.")
  96. class QuantFormat(Enum):
  97. QOperator = 0
  98. QDQ = 1
  99. def __str__(self):
  100. return self.name
  101. @staticmethod
  102. def from_string(format):
  103. try:
  104. return QuantFormat[format]
  105. except KeyError:
  106. raise ValueError() # noqa: B904
  107. ONNX_TYPE_TO_NP_TYPE = {
  108. onnx_proto.TensorProto.INT8: numpy.dtype("int8"),
  109. onnx_proto.TensorProto.UINT8: numpy.dtype("uint8"),
  110. onnx_proto.TensorProto.INT16: numpy.dtype("int16"),
  111. onnx_proto.TensorProto.UINT16: numpy.dtype("uint16"),
  112. onnx_proto.TensorProto.FLOAT8E4M3FN: float8_e4m3fn,
  113. onnx_proto.TensorProto.INT4: int4,
  114. onnx_proto.TensorProto.UINT4: uint4,
  115. }
  116. ONNX_INT_TYPE_RANGE = {
  117. onnx_proto.TensorProto.UINT8: (numpy.array(0, dtype=numpy.uint8), numpy.array(255, dtype=numpy.uint8)),
  118. onnx_proto.TensorProto.INT8: (numpy.array(-128, dtype=numpy.int8), numpy.array(127, dtype=numpy.int8)),
  119. onnx_proto.TensorProto.UINT16: (numpy.array(0, dtype=numpy.uint16), numpy.array(65535, dtype=numpy.uint16)),
  120. onnx_proto.TensorProto.INT16: (numpy.array(-32768, dtype=numpy.int16), numpy.array(32767, dtype=numpy.int16)),
  121. onnx_proto.TensorProto.UINT4: (numpy.array(0, dtype=uint4), numpy.array(15, dtype=uint4)),
  122. onnx_proto.TensorProto.INT4: (numpy.array(-8, dtype=int4), numpy.array(7, dtype=int4)),
  123. }
  124. ONNX_INT_TYPE_SYMMETRIC_RANGE = {
  125. onnx_proto.TensorProto.INT8: (numpy.array(-127, dtype=numpy.int8), numpy.array(127, dtype=numpy.int8)),
  126. onnx_proto.TensorProto.INT16: (numpy.array(-32767, dtype=numpy.int16), numpy.array(32767, dtype=numpy.int16)),
  127. }
  128. ONNX_INT_TYPE_REDUCED_RANGE = {
  129. onnx_proto.TensorProto.UINT8: (numpy.array(0, dtype=numpy.uint8), numpy.array(127, dtype=numpy.uint8)),
  130. onnx_proto.TensorProto.INT8: (numpy.array(-64, dtype=numpy.int8), numpy.array(64, dtype=numpy.int8)),
  131. onnx_proto.TensorProto.UINT16: (numpy.array(0, dtype=numpy.uint16), numpy.array(32767, dtype=numpy.uint16)),
  132. onnx_proto.TensorProto.INT16: (numpy.array(-16384, dtype=numpy.int16), numpy.array(16384, dtype=numpy.int16)),
  133. onnx_proto.TensorProto.UINT4: (numpy.array(0, dtype=uint4), numpy.array(7, dtype=uint4)),
  134. onnx_proto.TensorProto.INT4: (numpy.array(-4, dtype=int4), numpy.array(3, dtype=int4)),
  135. }
  136. def _check_type(*args, zero_point_index=-1):
  137. new_args = []
  138. for i, a in enumerate(args):
  139. if numpy.issubdtype(type(a), numpy.number):
  140. new_args.append(numpy.array(a))
  141. elif isinstance(a, numpy.ndarray):
  142. new_args.append(a)
  143. else:
  144. raise TypeError(f"arg {i} is not an array: {a}")
  145. if i == zero_point_index:
  146. v = new_args[-1]
  147. if v.dtype == numpy.float32 or v.dtype == numpy.float16:
  148. raise TypeError(f"zero_point cannot be {v.dtype}")
  149. return tuple(new_args) if len(new_args) > 1 else new_args[0]
  150. def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None):
  151. assert qType in ONNX_TYPE_TO_NP_TYPE, (
  152. f"Unexpected data type {qType} requested. Only INT8, UINT8, INT16, and UINT16 are supported."
  153. )
  154. if qType in (
  155. onnx_proto.TensorProto.FLOAT8E4M3FN,
  156. onnx_proto.TensorProto.FLOAT8E4M3FNUZ,
  157. onnx_proto.TensorProto.FLOAT8E5M2,
  158. onnx_proto.TensorProto.FLOAT8E5M2FNUZ,
  159. ):
  160. if zero_point != 0:
  161. raise NotImplementedError(f"zero_point is expected to be null for float 8 not {zero_point!r}.")
  162. if arr.dtype == numpy.float32:
  163. onnx_type = TensorProto.FLOAT
  164. elif arr.dtype == numpy.float16:
  165. onnx_type = TensorProto.FLOAT16
  166. else:
  167. raise ValueError(f"Unexpected dtype {arr.dtype}.")
  168. onnx_model = make_model(
  169. make_graph(
  170. [
  171. make_node(
  172. "Constant", [], ["zero_point"], value=onnx.helper.make_tensor("zero_point", qType, [], [0])
  173. ),
  174. make_node("QuantizeLinear", ["X", "scale", "zero_point"], ["Y"]),
  175. ],
  176. "qu",
  177. [
  178. make_tensor_value_info("X", onnx_type, None),
  179. make_tensor_value_info("scale", onnx_type, None),
  180. ],
  181. [make_tensor_value_info("Y", qType, None)],
  182. )
  183. )
  184. ref = ReferenceEvaluator(onnx_model)
  185. return _check_type(ref.run(None, {"X": arr, "scale": scale})[0])
  186. else:
  187. # Quantizes data for all integer types.
  188. #
  189. # For int4 types, the quantized data is returned as either np.int8 or np.uint8,
  190. # which matches the python reference ONNX implementation of QuantizeLinear.
  191. # This data can be packed into 4-bit elements by using pack_bytes_to_4bit().
  192. dtype = ONNX_TYPE_TO_NP_TYPE[qType]
  193. qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=False)
  194. cliplow = max(qmin, low) if low is not None else qmin
  195. cliphigh = min(qmax, high) if high is not None else qmax
  196. arr_fp32 = numpy.asarray((arr.astype(numpy.float32) / scale).round() + zero_point)
  197. numpy.clip(arr_fp32, cliplow, cliphigh, out=arr_fp32)
  198. return _check_type(arr_fp32.astype(dtype))
  199. def compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=False, min_real_range=None):
  200. """Calculate the scale s and zero point z for the quantization relation
  201. r = s(q-z), where r are the original values and q are the corresponding
  202. quantized values.
  203. r and z are calculated such that every value within [rmin,rmax] has an
  204. approximate representation within [qmin,qmax]. In addition, qmin <= z <=
  205. qmax is enforced. If the symmetric flag is set to True, the interval
  206. [rmin,rmax] is symmetrized to [-absmax, +absmax], where
  207. absmax = max(abs(rmin), abs(rmax)).
  208. :parameter rmin: minimum value of r
  209. :parameter rmax: maximum value of r
  210. :parameter qmin: minimum value representable by the target quantization data type
  211. :parameter qmax: maximum value representable by the target quantization data type
  212. :parameter symmetric: True if the floating-point range should be made symmetric. Defaults to False.
  213. :parameter min_real_range: Minimum floating-point range (i.e., rmax - rmin) to enforce. Defaults to None.
  214. :return: zero and scale [z, s]
  215. """
  216. if qmin > 0 or qmax < 0:
  217. raise ValueError(f"qmin and qmax must meet requirement: qmin <= 0 <= qmax while qmin:{qmin}, qmmax:{qmax}")
  218. # Adjust rmin and rmax such that 0 is included in the range. This is
  219. # required to make sure zero can be represented by the quantization data
  220. # type (i.e. to make sure qmin <= zero_point <= qmax)
  221. rmin = numpy.minimum(rmin, numpy.array(0, dtype=rmin.dtype))
  222. rmax = numpy.maximum(rmax, numpy.array(0, dtype=rmax.dtype))
  223. # Ensure a minimum float-point range if specified.
  224. if min_real_range is not None:
  225. rmax = max(rmax, rmin + numpy.asarray(min_real_range, dtype=rmin.dtype))
  226. if symmetric:
  227. absmax = numpy.maximum(numpy.abs(rmin), numpy.abs(rmax))
  228. rmin = -absmax
  229. rmax = +absmax
  230. assert qmin <= qmax, f"qmin={rmin} > qmax={rmax}"
  231. dr = numpy.array(rmax - rmin, dtype=numpy.float64)
  232. dq = numpy.array(qmax, dtype=numpy.float64) - numpy.array(qmin, dtype=numpy.float64)
  233. scale = numpy.array(dr / dq)
  234. assert scale >= 0, "scale issue"
  235. if scale < numpy.finfo(rmax.dtype).tiny:
  236. scale = numpy.array(1.0, dtype=rmax.dtype)
  237. zero_point = numpy.array(0, dtype=qmin.dtype)
  238. else:
  239. if symmetric:
  240. # When symmetric (i.e., rmax == -rmin), the zero_point formula reduces to round((qmax + qmin) / 2.0).
  241. # This simpler formula doesn't depend on scale and guarantees that the zero point values
  242. # for int8, uint8, int16, and uint16 are always 0, 128, 0, and 32768, respectively.
  243. # This is important for per-channel/symmetric QLinearConv on CPU EP, which requires all channels to have
  244. # the exact same zero_point values.
  245. zero_point = numpy.array(
  246. numpy.round((qmin + qmax) / numpy.array(2.0, dtype=numpy.float64)), dtype=qmin.dtype
  247. )
  248. else:
  249. zero_point = numpy.array(numpy.round(qmin - rmin / scale), dtype=qmin.dtype)
  250. scale = scale.astype(rmax.dtype)
  251. return [zero_point, scale]
  252. def compute_scale_zp_float8(element_type, std):
  253. """Calculate the scale s for a float8 type (E4M3FN).
  254. The function assumes the coefficient distribution and the float 8
  255. distribution are similar to two gaussian laws.
  256. :return: zero and scale [z, s]
  257. More details in notebook `quantization_fp8.ipynb
  258. <https://github.com/microsoft/onnxruntime/blob/main/docs/python/notebooks/quantization_fp8.ipynb>`_.
  259. """
  260. zp_dtype = None
  261. if element_type not in FLOAT8_DISTRIBUTIONS:
  262. if element_type == TensorProto.FLOAT8E4M3FN:
  263. from ml_dtypes import float8_e4m3fn # noqa: PLC0415
  264. zp_dtype = float8_e4m3fn
  265. all_values = [float(i) for i in range(256)]
  266. values = numpy.array(
  267. [f for f in all_values if not numpy.isnan(f) and not numpy.isinf(f)], dtype=numpy.float32
  268. )
  269. else:
  270. raise ValueError(f"Quantization to element_type={element_type} not implemented.")
  271. FLOAT8_DISTRIBUTIONS[element_type] = values
  272. elif element_type == TensorProto.FLOAT8E4M3FN:
  273. from ml_dtypes import float8_e4m3fn # noqa: PLC0415
  274. zp_dtype = float8_e4m3fn
  275. if zp_dtype is None:
  276. raise TypeError(f"Unexpected element_type {element_type}.")
  277. std_f8 = numpy.std(FLOAT8_DISTRIBUTIONS[element_type])
  278. zero = numpy.array(0, dtype=zp_dtype)
  279. scale = numpy.array(std / std_f8, dtype=std.dtype)
  280. return [zero, scale]
  281. def compute_data_quant_params(
  282. data: numpy.ndarray,
  283. quant_type: onnx.TensorProto.DataType,
  284. symmetric: bool,
  285. reduce_range: bool = False,
  286. min_real_range: float | None = None,
  287. rmin_override: float | None = None,
  288. rmax_override: float | None = None,
  289. ) -> tuple[numpy.ndarray, numpy.ndarray]:
  290. """
  291. Returns the zero_point and scale for the given data.
  292. :param data: The data for which to compute quantization parameters.
  293. :param quant_type: The quantization data type.
  294. :param symmetric: whether symmetric quantization is used or not.
  295. :parameter reduce_range: True if the quantization range should be reduced. Defaults to False.
  296. :parameter min_real_range: Minimum floating-point range (i.e., rmax - rmin) to enforce. Defaults to None.
  297. :parameter rmin_override: The value of rmin to use if not None. Otherwise, uses min(data).
  298. :parameter rmax_override: The value of rmax to use if not None. Otherwise, uses max(data).
  299. :return: zero point and scale
  300. """
  301. if not isinstance(data, numpy.ndarray):
  302. raise TypeError(f"Weight must be given as an array not {type(data)}.")
  303. if rmin_override is not None:
  304. rmin = rmin_override
  305. else:
  306. rmin = data.min() if len(data) else 0.0
  307. if rmax_override is not None:
  308. rmax = rmax_override
  309. else:
  310. rmax = data.max() if len(data) else 0.0
  311. rmin = numpy.array(rmin, dtype=data.dtype)
  312. rmax = numpy.array(rmax, dtype=data.dtype)
  313. scale = numpy.array(1.0, dtype=data.dtype)
  314. if quant_type == TensorProto.FLOAT8E4M3FN:
  315. if reduce_range:
  316. raise RuntimeError("Unsupported option reduce_range=True for float 8.")
  317. std = numpy.std(data)
  318. zero_point, scale = compute_scale_zp_float8(quant_type, std)
  319. return _check_type(zero_point, scale, zero_point_index=0)
  320. if quant_type in (
  321. TensorProto.INT8,
  322. TensorProto.UINT8,
  323. TensorProto.INT16,
  324. TensorProto.UINT16,
  325. TensorProto.INT4,
  326. TensorProto.UINT4,
  327. ):
  328. qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range, symmetric=symmetric)
  329. if len(data):
  330. zero_point, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, min_real_range)
  331. else:
  332. zero_point = numpy.array(0, dtype=qmin.dtype)
  333. return _check_type(zero_point, scale, zero_point_index=0)
  334. raise ValueError(f"Unexpected value for quant_type={quant_type}.")
  335. def quantize_data(
  336. data, qType, symmetric, reduce_range=False, min_real_range=None, rmin_override=None, rmax_override=None
  337. ) -> tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]:
  338. """
  339. :param data: data to quantize
  340. :param qType: data type to quantize to.
  341. :param symmetric: whether symmetric quantization is used or not.
  342. :parameter reduce_range: True if the quantization range should be reduced. Defaults to False.
  343. :parameter min_real_range: Minimum floating-point range (i.e., rmax - rmin) to enforce. Defaults to None.
  344. :parameter rmin_override: The value of rmin to use if not None. Otherwise, uses min(data).
  345. :parameter rmax_override: The value of rmax to use if not None. Otherwise, uses max(data).
  346. :return: minimum, maximum, zero point, scale, and quantized weights
  347. To pack weights, we compute a linear transformation
  348. - when data `type == uint8` mode, from `[rmin, rmax]` -> :math:`[0, 2^{b-1}]` and
  349. - when data `type == int8`, from `[-m , m]` -> :math:`[-(2^{b-1}-1), 2^{b-1}-1]` where
  350. `m = max(abs(rmin), abs(rmax))`
  351. and add necessary intermediate nodes to transform quantized weight to full weight using the equation
  352. :math:`r = S(q-z)`, where
  353. - *r*: real original value
  354. - *q*: quantized value
  355. - *S*: scale
  356. - *z*: zero point
  357. """
  358. zero_point, scale = compute_data_quant_params(
  359. data,
  360. qType,
  361. symmetric,
  362. reduce_range,
  363. min_real_range,
  364. rmin_override,
  365. rmax_override,
  366. )
  367. if qType == TensorProto.FLOAT8E4M3FN:
  368. quantized_data = quantize_nparray(qType, data, scale, zero_point)
  369. if any((quantized_data.view(numpy.uint8).ravel() & 127) == 127):
  370. np_data = numpy.asarray(data)
  371. raise RuntimeError(
  372. f"One of the quantized value is NaN data in [{np_data.min()}, {np_data.max()}], "
  373. f"quantized_data in [{quantized_data.min()}, {quantized_data.max()}]."
  374. )
  375. return zero_point, scale, quantized_data
  376. if qType in (
  377. TensorProto.INT8,
  378. TensorProto.UINT8,
  379. TensorProto.INT16,
  380. TensorProto.UINT16,
  381. TensorProto.INT4,
  382. TensorProto.UINT4,
  383. ):
  384. quantized_data = quantize_nparray(qType, data, scale, zero_point)
  385. return zero_point, scale, quantized_data
  386. raise ValueError(f"Unexpected value for qType={qType}.")
  387. def quantize_onnx_initializer(
  388. weight: onnx.TensorProto,
  389. quant_type: onnx.TensorProto.DataType,
  390. zero_point: numpy.ndarray,
  391. scale: numpy.ndarray,
  392. axis: int | None = None,
  393. quant_weight_name: str | None = None,
  394. ) -> onnx.TensorProto:
  395. """
  396. Returns a quantized version of the given ONNX initializer.
  397. :param weight: The ONNX initializer to quantize.
  398. :param quant_type: The final quantized data type.
  399. :param zero_point: The zero-point value to use for quantization.
  400. :param scale: The scale value to use for quantization.
  401. :param axis: The quantization axis if quantizing per-channel. Defaults to None.
  402. :param quant_weight_name: The name of the quantized initializer.
  403. If not specified, the quantized name is generated.
  404. :return: The quantized ONNX initializer.
  405. """
  406. weight_data = tensor_proto_to_array(weight)
  407. q_weight_data: numpy.ndarray | None = None
  408. if axis is None: # Per-tensor quantization
  409. q_weight_data = quantize_nparray(quant_type, weight_data.ravel(), scale, zero_point)
  410. else: # Per-channel quantization
  411. channel_count = weight_data.shape[axis]
  412. channel_dims = list(weight_data.shape) # deep copy
  413. channel_dims[axis] = 1 # only one per channel for reshape
  414. quantized_channel_data_list = []
  415. for i in range(channel_count):
  416. channel_data = weight_data.take(i, axis)
  417. channel_scale = scale[i]
  418. channel_zero_point = zero_point[i]
  419. quantized_channel_data = quantize_nparray(
  420. quant_type, channel_data.ravel(), channel_scale, channel_zero_point
  421. )
  422. quantized_channel_data_list.append(numpy.asarray(quantized_channel_data).reshape(channel_dims))
  423. q_weight_data = numpy.concatenate(quantized_channel_data_list, axis)
  424. q_weight_name = quant_weight_name if quant_weight_name else f"{weight.name}{TENSOR_NAME_QUANT_SUFFIX}"
  425. if quant_type == onnx.TensorProto.FLOAT8E4M3FN:
  426. q_weight_initializer = onnx.TensorProto()
  427. q_weight_initializer.data_type = quant_type
  428. q_weight_initializer.dims.extend(weight.dims)
  429. q_weight_initializer.name = q_weight_name
  430. # Do not remove .flatten().copy() numpy is not clear about data persistence.
  431. q_weight_initializer.raw_data = q_weight_data.flatten().copy().tobytes()
  432. if to_array_extended is not None:
  433. # This test should not be needed but it helped catch some issues
  434. # with data persistence and tobytes.
  435. check = to_array_extended(q_weight_initializer)
  436. if check.shape != weight_data.shape or check.tobytes() != q_weight_data.tobytes():
  437. raise RuntimeError(
  438. f"The initializer of shape {weight_data.shape} could not be created, expecting "
  439. f"{q_weight_data.tobytes()[:10]}, got {check.tobytes()[:10]} and shape={weight.shape}"
  440. f"\nraw={str(q_weight_initializer)[:200]}."
  441. )
  442. elif quant_type in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4):
  443. if q_weight_data.dtype not in (int4, uint4):
  444. raise RuntimeError(f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values.")
  445. # We do not use onnx.helper.pack_float32_to_4bit() due to performance.
  446. # This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes.
  447. packed_data = bytes(pack_bytes_to_4bit(q_weight_data.tobytes()))
  448. # We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161
  449. q_weight_initializer = onnx.helper.make_tensor(q_weight_name, quant_type, weight.dims, packed_data, raw=True)
  450. else:
  451. quant_np_dtype = onnx.helper.tensor_dtype_to_np_dtype(quant_type)
  452. q_weight_data = numpy.asarray(q_weight_data, dtype=quant_np_dtype).reshape(weight.dims)
  453. q_weight_initializer = onnx.numpy_helper.from_array(q_weight_data, q_weight_name)
  454. return q_weight_initializer
  455. def get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=False): # noqa: N802
  456. """
  457. Return qmin and qmax, the minimum and maximum value representable by the given qType
  458. :parameter qType: onnx.onnx_pb.TensorProto.UINT8 or onnx.onnx_pb.TensorProto.UINT8
  459. :return: qmin, qmax
  460. """
  461. if qType == onnx_proto.TensorProto.FLOAT8E4M3FN:
  462. raise NotImplementedError("This function is not implemented for float 8 as not needed.")
  463. qrange = None
  464. if reduce_range:
  465. qrange = ONNX_INT_TYPE_REDUCED_RANGE.get(qType)
  466. elif symmetric and qType in ONNX_INT_TYPE_SYMMETRIC_RANGE:
  467. qrange = ONNX_INT_TYPE_SYMMETRIC_RANGE[qType]
  468. else:
  469. qrange = ONNX_INT_TYPE_RANGE.get(qType)
  470. if not qrange:
  471. raise ValueError(f"Unexpected data type {qType} requested. Only INT8, UINT8, INT16, and UINT16 are supported.")
  472. qmin, qmax = qrange
  473. if qmin > 0 or qmax < 0:
  474. raise ValueError(
  475. f"qmin and qmax must meet requirement: qmin <= 0 <= qmax while "
  476. f"qmin:{qmin}, qmmax:{qmax}, dtype={qmin.dtype}, reduce_range={reduce_range}, "
  477. f"symmetric={symmetric}, qType={qType}"
  478. )
  479. return qrange
  480. def get_qrange_for_qType(qType, reduce_range=False, symmetric=False): # noqa: N802
  481. """
  482. Helper function to get the quantization range for a type.
  483. parameter qType: quantization type.
  484. return: quantization range.
  485. """
  486. qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range, symmetric=symmetric)
  487. return qmax - qmin
  488. def normalize_axis(axis: int, rank: int) -> tuple[bool, int]:
  489. """
  490. Helper function that tries to return a normalized axis in the range [0, rank - 1].
  491. :parameter axis: The axis to normalize.
  492. :parameter rank: The tensor rank (number of dimensions).
  493. :return (is_valid, axis_norm)
  494. """
  495. axis_norm = axis + rank if axis < 0 else axis
  496. is_valid = axis_norm >= 0 and axis_norm < rank
  497. return is_valid, axis_norm
  498. def pack_bytes_to_4bit(src_8bit: bytes) -> bytearray:
  499. """
  500. Copies a source array of 8-bit values into a destination bytearray of packed 4-bit values.
  501. Assumes that the source values are already in the appropriate int4 range.
  502. :parameter src_8bit: The 8-bit element values to pack.
  503. :return A bytearray with every two 8-bit src elements packed into a single byte.
  504. """
  505. num_elems = len(src_8bit)
  506. if num_elems == 0:
  507. return bytearray()
  508. dst_size = (num_elems + 1) // 2 # Ex: 5 8-bit elems packed into 3 bytes
  509. dst = bytearray(dst_size)
  510. src_i: int = 0
  511. dst_i: int = 0
  512. # Pack two 8-bit elements into a single byte in each iteration.
  513. while src_i < num_elems - 1:
  514. dst[dst_i] = ((src_8bit[src_i + 1] & 0xF) << 4) | (src_8bit[src_i] & 0xF)
  515. dst_i += 1
  516. src_i += 2
  517. if src_i < num_elems:
  518. # Odd number of elements.
  519. dst[dst_i] = src_8bit[src_i] & 0xF
  520. return dst
  521. class QuantizedInitializer:
  522. """
  523. Represents a linearly quantized weight input from ONNX operators
  524. """
  525. def __init__(
  526. self,
  527. name,
  528. initializer,
  529. rmins,
  530. rmaxs,
  531. zero_points,
  532. scales,
  533. data=[], # noqa: B006
  534. quantized_data=[], # noqa: B006
  535. axis=None,
  536. ):
  537. self.name = name
  538. self.initializer = initializer # TensorProto initializer in ONNX graph
  539. self.rmins = rmins # List of minimum range for each axis
  540. self.rmaxs = rmaxs # List of maximum range for each axis
  541. # 1D tensor of zero points computed for each axis. scalar if axis is empty
  542. self.zero_points = zero_points
  543. self.scales = scales # 1D tensor of scales computed for each axis. scalar if axis is empty
  544. self.data = data # original data from initializer TensorProto
  545. self.quantized_data = quantized_data # weight-packed data from data
  546. # Scalar to specify which dimension in the initializer to weight pack.
  547. self.axis = axis
  548. # If empty, single zero point and scales computed from a single rmin and rmax
  549. class QuantizedValue:
  550. """
  551. Represents a linearly quantized value (input\\output\\intializer)
  552. """
  553. def __init__(
  554. self,
  555. name,
  556. new_quantized_name,
  557. scale_name,
  558. zero_point_name,
  559. quantized_value_type,
  560. axis=None,
  561. node_type=None,
  562. node_qtype=None,
  563. scale_type=None,
  564. ):
  565. self.original_name = name
  566. self.q_name = new_quantized_name
  567. self.scale_name = scale_name
  568. self.zp_name = zero_point_name
  569. self.value_type = quantized_value_type
  570. self.axis = axis
  571. self.node_type = node_type
  572. self.node_qtype = node_qtype
  573. self.scale_type = scale_type
  574. class BiasToQuantize:
  575. """
  576. Represents a bias to be quantized
  577. """
  578. def __init__(self, bias_name, input_name, weight_name):
  579. self.bias_name = bias_name
  580. self.input_name = input_name
  581. self.weight_name = weight_name
  582. def attribute_to_kwarg(attribute):
  583. """
  584. Convert attribute to kwarg format for use with onnx.helper.make_node.
  585. :parameter attribute: attribute in AttributeProto format.
  586. :return: attribute in {key: value} format.
  587. """
  588. if attribute.type == 0:
  589. raise ValueError(f"attribute {attribute.name} does not have type specified.")
  590. # Based on attribute type definitions from AttributeProto
  591. # definition in https://github.com/onnx/onnx/blob/main/onnx/onnx.proto
  592. if attribute.type == 1:
  593. value = attribute.f
  594. elif attribute.type == 2:
  595. value = attribute.i
  596. elif attribute.type == 3:
  597. value = attribute.s
  598. elif attribute.type == 4:
  599. value = attribute.t
  600. elif attribute.type == 5:
  601. value = attribute.g
  602. elif attribute.type == 6:
  603. value = attribute.floats
  604. elif attribute.type == 7:
  605. value = attribute.ints
  606. elif attribute.type == 8:
  607. value = attribute.strings
  608. elif attribute.type == 9:
  609. value = attribute.tensors
  610. elif attribute.type == 10:
  611. value = attribute.graphs
  612. else:
  613. raise ValueError(f"attribute {attribute.name} has unsupported type {attribute.type}.")
  614. return {attribute.name: value}
  615. def find_by_name(item_name, item_list):
  616. """
  617. Helper function to find item by name in a list.
  618. parameter item_name: name of the item.
  619. parameter item_list: list of items.
  620. return: item if found. None otherwise.
  621. """
  622. items = [item for item in item_list if item.name == item_name]
  623. return items[0] if len(items) > 0 else None
  624. def get_elem_index(elem_name, elem_list):
  625. """
  626. Helper function to return index of an item in a node list
  627. """
  628. elem_idx = -1
  629. for i in range(len(elem_list)):
  630. if elem_list[i] == elem_name:
  631. elem_idx = i
  632. return elem_idx
  633. def get_mul_node(inputs, output, name):
  634. """
  635. Helper function to create a Mul node.
  636. parameter inputs: list of input names.
  637. parameter output: output name.
  638. parameter name: name of the node.
  639. return: Mul node in NodeProto format.
  640. """
  641. return onnx.helper.make_node("Mul", inputs, [output], name)
  642. def generate_identified_filename(filename: Path, identifier: str) -> Path:
  643. """
  644. Helper function to generate a identifiable filepath by concatenating the given identifier as a suffix.
  645. """
  646. return filename.parent.joinpath(filename.stem + identifier + filename.suffix)
  647. def apply_plot(hist, hist_edges):
  648. import sys # noqa: PLC0415
  649. import matplotlib.pyplot as plt # noqa: PLC0415
  650. import numpy # noqa: PLC0415
  651. numpy.set_printoptions(threshold=sys.maxsize)
  652. print("Histogram:")
  653. print(hist)
  654. print("Histogram Edges:")
  655. print(hist_edges)
  656. plt.stairs(hist, hist_edges, fill=True)
  657. plt.xlabel("Tensor value")
  658. plt.ylabel("Counts")
  659. plt.title("Tensor value V.S. Counts")
  660. plt.show()
  661. def write_calibration_table(calibration_cache, dir="."):
  662. """
  663. Helper function to write calibration table to files.
  664. """
  665. import json # noqa: PLC0415
  666. import flatbuffers # noqa: PLC0415
  667. import numpy as np # noqa: PLC0415
  668. import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue # noqa: PLC0415
  669. import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable # noqa: PLC0415
  670. from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData # noqa: PLC0415
  671. logging.info(f"calibration cache: {calibration_cache}")
  672. class MyEncoder(json.JSONEncoder):
  673. def default(self, obj):
  674. if isinstance(obj, (TensorData, TensorsData)):
  675. return obj.to_dict()
  676. if isinstance(obj, np.ndarray):
  677. return {"data": obj.tolist(), "dtype": str(obj.dtype), "CLS": "numpy.array"}
  678. if isinstance(obj, CalibrationMethod):
  679. return {"CLS": obj.__class__.__name__, "value": str(obj)}
  680. return json.JSONEncoder.default(self, obj)
  681. json_data = json.dumps(calibration_cache, cls=MyEncoder)
  682. with open(os.path.join(dir, "calibration.json"), "w") as file:
  683. file.write(json_data) # use `json.loads` to do the reverse
  684. # Serialize data using FlatBuffers
  685. zero = np.array(0)
  686. builder = flatbuffers.Builder(1024)
  687. key_value_list = []
  688. for key in sorted(calibration_cache.keys()):
  689. values = calibration_cache[key]
  690. d_values = values.to_dict()
  691. floats = [
  692. float(d_values.get("highest", zero).item()),
  693. float(d_values.get("lowest", zero).item()),
  694. ]
  695. value = str(max(floats))
  696. flat_key = builder.CreateString(key)
  697. flat_value = builder.CreateString(value)
  698. KeyValue.KeyValueStart(builder)
  699. KeyValue.KeyValueAddKey(builder, flat_key)
  700. KeyValue.KeyValueAddValue(builder, flat_value)
  701. key_value = KeyValue.KeyValueEnd(builder)
  702. key_value_list.append(key_value)
  703. TrtTable.TrtTableStartDictVector(builder, len(key_value_list))
  704. for key_value in key_value_list:
  705. builder.PrependUOffsetTRelative(key_value)
  706. main_dict = builder.EndVector()
  707. TrtTable.TrtTableStart(builder)
  708. TrtTable.TrtTableAddDict(builder, main_dict)
  709. cal_table = TrtTable.TrtTableEnd(builder)
  710. builder.Finish(cal_table)
  711. buf = builder.Output()
  712. with open(os.path.join(dir, "calibration.flatbuffers"), "wb") as file:
  713. file.write(buf)
  714. # Deserialize data (for validation)
  715. if os.environ.get("QUANTIZATION_DEBUG", "0") in (1, "1"):
  716. cal_table = TrtTable.TrtTable.GetRootAsTrtTable(buf, 0)
  717. dict_len = cal_table.DictLength()
  718. for i in range(dict_len):
  719. key_value = cal_table.Dict(i)
  720. logging.info(key_value.Key())
  721. logging.info(key_value.Value())
  722. # write plain text
  723. with open(os.path.join(dir, "calibration.cache"), "w") as file:
  724. for key in sorted(calibration_cache.keys()):
  725. values = calibration_cache[key]
  726. d_values = values.to_dict()
  727. floats = [
  728. float(d_values.get("highest", zero).item()),
  729. float(d_values.get("lowest", zero).item()),
  730. ]
  731. value = key + " " + str(max(floats))
  732. file.write(value)
  733. file.write("\n")
  734. def smooth_distribution(p, eps=0.0001):
  735. """Given a discrete distribution (may have not been normalized to 1),
  736. smooth it by replacing zeros with eps multiplied by a scaling factor
  737. and taking the corresponding amount off the non-zero values.
  738. Ref: http://web.engr.illinois.edu/~hanj/cs412/bk3/KL-divergence.pdf
  739. https://github.com//apache/incubator-mxnet/blob/master/python/mxnet/contrib/quantization.py
  740. """
  741. is_zeros = (p == 0).astype(numpy.float32)
  742. is_nonzeros = (p != 0).astype(numpy.float32)
  743. n_zeros = is_zeros.sum()
  744. n_nonzeros = p.size - n_zeros
  745. if not n_nonzeros:
  746. # raise ValueError('The discrete probability distribution is malformed. All entries are 0.')
  747. return None
  748. eps1 = eps * float(n_zeros) / float(n_nonzeros)
  749. assert eps1 < 1.0, f"n_zeros={n_zeros}, n_nonzeros={n_nonzeros}, eps1={eps1}"
  750. hist = p.astype(numpy.float32)
  751. hist += eps * is_zeros + (-eps1) * is_nonzeros
  752. assert (hist <= 0).sum() == 0
  753. return hist
  754. def model_has_external_data(model_path: Path):
  755. model = onnx.load(model_path.as_posix(), load_external_data=False)
  756. return any(external_data_helper.uses_external_data(intializer) for intializer in model.graph.initializer)
  757. def optimize_model(model_path: Path, opt_model_path: Path):
  758. """
  759. Generate model that applies graph optimization (constant folding, etc.)
  760. parameter model_path: path to the original onnx model
  761. parameter opt_model_path: path to the optimized onnx model
  762. :return: optimized onnx model
  763. """
  764. sess_option = SessionOptions()
  765. sess_option.optimized_model_filepath = opt_model_path.as_posix()
  766. sess_option.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
  767. kwargs = {}
  768. # This will rename constant initializer names, disable it to make test pass.
  769. kwargs["disabled_optimizers"] = ["ConstantSharing"]
  770. _ = InferenceSession(model_path.as_posix(), sess_option, providers=["CPUExecutionProvider"], **kwargs)
  771. def add_pre_process_metadata(model: ModelProto):
  772. """Tag the model that it went through quantization pre-processing"""
  773. metadata_props = {"onnx.quant.pre_process": "onnxruntime.quant"}
  774. if model.metadata_props:
  775. for prop in model.metadata_props:
  776. metadata_props.update({prop.key: prop.value})
  777. onnx.helper.set_model_props(model, metadata_props)
  778. def model_has_pre_process_metadata(model: ModelProto) -> bool:
  779. """Check the model whether it went through quantization pre-processing"""
  780. if model.metadata_props:
  781. for prop in model.metadata_props:
  782. if prop.key == "onnx.quant.pre_process" and prop.value == "onnxruntime.quant":
  783. return True
  784. return False
  785. def add_infer_metadata(model: ModelProto):
  786. metadata_props = {"onnx.infer": "onnxruntime.quant"}
  787. if model.metadata_props:
  788. for p in model.metadata_props:
  789. metadata_props.update({p.key: p.value})
  790. onnx.helper.set_model_props(model, metadata_props)
  791. def model_has_infer_metadata(model: ModelProto) -> bool:
  792. if model.metadata_props:
  793. for p in model.metadata_props:
  794. if p.key == "onnx.infer" and p.value == "onnxruntime.quant":
  795. return True
  796. return False
  797. def get_opset_version(model: ModelProto) -> int:
  798. ai_onnx_domain = [opset for opset in model.opset_import if not opset.domain or opset.domain == "ai.onnx"]
  799. if len(ai_onnx_domain) != 1:
  800. raise ValueError("Failed to find proper ai.onnx domain")
  801. opset_version = ai_onnx_domain[0].version
  802. return opset_version
  803. def update_opset_version(model: ModelProto, weight_type: QuantType) -> ModelProto:
  804. opset_version = get_opset_version(model)
  805. target_opset_version = opset_version
  806. weight_quant_type = getattr(weight_type, "tensor_type", weight_type)
  807. if opset_version < 19 and weight_quant_type == onnx.TensorProto.FLOAT8E4M3FN:
  808. logging.warning(
  809. f"The original model opset version is {opset_version}, which does not support quantization to float 8. "
  810. "Please update the model to opset >= 19. Automatically update the model to opset 19. "
  811. "Please verify the quantized model."
  812. )
  813. target_opset_version = 19
  814. elif opset_version == 10:
  815. logging.warning(
  816. f"The original model opset version is {opset_version}, which does not support node fusions. "
  817. "Please update the model to opset >= 11 for better performance."
  818. )
  819. elif opset_version < 10:
  820. logging.warning(
  821. f"The original model opset version is {opset_version}, which does not support quantization. "
  822. "Please update the model to opset >= 11. Automatically update the model to opset 11. "
  823. "Please verify the quantized model."
  824. )
  825. target_opset_version = 11
  826. if target_opset_version != opset_version:
  827. model = onnx.version_converter.convert_version(model, target_opset_version)
  828. # Additional nodes may be added to the model during the opset version conversion. Run shape inference
  829. # to ensure all nodes are included in model.graph.value_info.
  830. model = save_and_reload_model_with_shape_infer(model)
  831. return model
  832. def load_model_with_shape_infer(model_path: Path) -> ModelProto:
  833. inferred_model_path = generate_identified_filename(model_path, "-inferred")
  834. onnx.shape_inference.infer_shapes_path(str(model_path), str(inferred_model_path))
  835. model = onnx.load(inferred_model_path.as_posix())
  836. add_infer_metadata(model)
  837. inferred_model_path.unlink()
  838. return model
  839. def save_and_reload_model_with_shape_infer(model: ModelProto) -> ModelProto:
  840. with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir:
  841. model_copy = copy.deepcopy(model)
  842. model_path = Path(quant_tmp_dir).joinpath("model.onnx")
  843. onnx.save_model(model_copy, model_path.as_posix(), save_as_external_data=True)
  844. return load_model_with_shape_infer(model_path)
  845. def tensor_proto_to_array(initializer: TensorProto) -> numpy.ndarray:
  846. if initializer.data_type in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16):
  847. return onnx.numpy_helper.to_array(initializer)
  848. raise ValueError(
  849. f"Only float type is supported. Weights {initializer.name} is {type_to_name[initializer.data_type]}"
  850. )
  851. def add_quant_suffix(tensor_name: str) -> str:
  852. return tensor_name + "_QuantizeLinear"
  853. def add_quant_input_suffix(tensor_name: str) -> str:
  854. return tensor_name + QUANT_INPUT_SUFFIX
  855. def add_quant_output_suffix(tensor_name) -> str:
  856. return tensor_name + "_QuantizeLinear_Output"
  857. def add_dequant_suffix(tensor_name) -> str:
  858. return tensor_name + "_DequantizeLinear"
  859. def add_dequant_input_suffix(tensor_name) -> str:
  860. return tensor_name + "_DequantizeLinear_Input"
  861. def add_dequant_output_suffix(tensor_name) -> str:
  862. return tensor_name + DEQUANT_OUTPUT_SUFFIX