base_quantizer.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. import logging
  7. from typing import Any
  8. import numpy as np
  9. import onnx
  10. import onnx.numpy_helper
  11. try:
  12. from onnx.reference.op_run import to_array_extended
  13. except ImportError:
  14. # old version of onnx.
  15. to_array_extended = None
  16. from .calibrate import TensorData
  17. from .onnx_model import ONNXModel
  18. from .quant_utils import (
  19. DEQUANT_OP_NAME,
  20. ONNX_TYPE_TO_NP_TYPE,
  21. QUANT_OP_NAME,
  22. TENSOR_NAME_QUANT_SUFFIX,
  23. find_by_name,
  24. get_opset_version,
  25. model_has_infer_metadata,
  26. normalize_axis,
  27. pack_bytes_to_4bit,
  28. quantize_data,
  29. quantize_nparray,
  30. save_and_reload_model_with_shape_infer,
  31. tensor_proto_to_array,
  32. )
  33. from .tensor_quant_overrides import TensorQuantOverridesHelper
  34. class QuantizationParams:
  35. def __init__(self, **data: dict[str, Any]):
  36. self.data = {}
  37. for k, v in data.items():
  38. if not isinstance(k, str):
  39. raise TypeError(f"Keys must be strings not {type(k)} for k={k!r}.")
  40. if k != "axis" and not isinstance(v, (int, str, np.ndarray, float)):
  41. raise TypeError(f"Values must be numpy arrays, int, float, str not {type(v)} for k={k!r}.")
  42. if k == "axis" and not isinstance(v, int) and v is not None:
  43. raise TypeError(f"Axis value must be an int or None, not {type(v)}.")
  44. if k == "scale" and v.dtype not in (np.float32, np.float16):
  45. raise ValueError(f"scale must a float32 or float16 numpy element but is {v.dtype} for k={k!r}")
  46. self.data[k] = v
  47. def get(self, key, default_value=None):
  48. return self.data.get(key, default_value)
  49. def __iter__(self):
  50. yield from self.data
  51. def __getitem__(self, key):
  52. return self.data[key]
  53. def __setitem__(self, key, value):
  54. self.data[key] = value
  55. def __len__(self):
  56. return len(self.data)
  57. class BaseQuantizer:
  58. def __init__(
  59. self,
  60. model,
  61. per_channel,
  62. reduce_range,
  63. weight_qType,
  64. activation_qType,
  65. tensors_range,
  66. nodes_to_quantize,
  67. nodes_to_exclude,
  68. op_types_to_quantize,
  69. extra_options=None,
  70. ):
  71. if not model_has_infer_metadata(model):
  72. model = save_and_reload_model_with_shape_infer(model)
  73. self.value_infos = {vi.name: vi for vi in model.graph.value_info}
  74. self.value_infos.update({ot.name: ot for ot in model.graph.output})
  75. self.value_infos.update({it.name: it for it in model.graph.input})
  76. self.model = ONNXModel(model)
  77. self.opset_version = get_opset_version(model)
  78. self.per_channel = per_channel # weight-pack per channel
  79. self.reduce_range = reduce_range
  80. self.extra_options = extra_options if extra_options else {}
  81. self.enable_subgraph_quantization = (
  82. "EnableSubgraph" in self.extra_options and self.extra_options["EnableSubgraph"]
  83. )
  84. self.parent = None
  85. self.force_quantize_no_input_check = (
  86. "ForceQuantizeNoInputCheck" in self.extra_options and self.extra_options["ForceQuantizeNoInputCheck"]
  87. )
  88. # If user does not explicitly set "WeightSymmetric", then the weight's quantization type determines
  89. # the symmetry (i.e., signed integer types will use symmetric quantization). See `def is_weight_symmetric()`
  90. self._is_weight_symmetric: bool | None = self.extra_options.get("WeightSymmetric", None)
  91. self.is_activation_symmetric = self.extra_options.get("ActivationSymmetric", False)
  92. self.min_real_range = self.extra_options.get("MinimumRealRange")
  93. self.activation_qType = getattr(activation_qType, "tensor_type", activation_qType)
  94. self.weight_qType = getattr(weight_qType, "tensor_type", weight_qType)
  95. """
  96. Dictionary specifying the min and max values for tensors. It has following format:
  97. {
  98. "param_name": [min, max]
  99. }
  100. example:
  101. {
  102. 'Conv_3:0': [np.float32(0), np.float32(0.5)],
  103. 'Conv_4:0': [np.float32(1), np.float32(3.5)]
  104. }
  105. """
  106. if tensors_range is not None and any(not isinstance(t, TensorData) for t in tensors_range.values()):
  107. raise TypeError(
  108. f"tensors_range contains unexpected types { {type(v) for v in tensors_range.values()} }, not TensorData."
  109. )
  110. self.tensors_range = tensors_range
  111. self.nodes_to_quantize = nodes_to_quantize # specific nodes to quantize
  112. self.nodes_to_exclude = nodes_to_exclude # specific nodes to exclude
  113. self.op_types_to_quantize = op_types_to_quantize
  114. # Get tensor-level quantization overrides and ensure they are valid.
  115. self.tensor_quant_overrides = TensorQuantOverridesHelper(self.extra_options.get("TensorQuantOverrides", {}))
  116. self.initializers = {initzer.name: initzer for initzer in self.model.initializer()}
  117. overrides_valid, overrides_err = self.tensor_quant_overrides.is_valid(
  118. self.initializers, self.value_infos.keys(), activation_qType
  119. )
  120. if not overrides_valid:
  121. raise ValueError(overrides_err)
  122. self.tensor_quant_override_qtypes = self.tensor_quant_overrides.get_quant_types()
  123. def is_weight_symmetric(self, weight_quant_type: onnx.TensorProto.DataType) -> bool:
  124. if self._is_weight_symmetric is not None:
  125. return self._is_weight_symmetric # Return value explicitly set by user.
  126. return weight_quant_type in (
  127. onnx.TensorProto.INT4,
  128. onnx.TensorProto.INT8,
  129. onnx.TensorProto.INT16,
  130. onnx.TensorProto.FLOAT8E4M3FN,
  131. )
  132. def quantize_model(self):
  133. raise NotImplementedError
  134. def is_input_a_initializer(self, input_name):
  135. initializer = find_by_name(input_name, self.model.initializer())
  136. return initializer is not None
  137. def is_per_channel(self):
  138. return self.per_channel
  139. def is_valid_quantize_weight(self, weight_name):
  140. weight = find_by_name(weight_name, self.model.initializer())
  141. if weight is not None:
  142. return weight.data_type in (onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT16)
  143. if (not self.enable_subgraph_quantization) or (self.parent is None):
  144. return False
  145. return self.parent.is_valid_quantize_weight(weight_name)
  146. def should_quantize_node(self, node):
  147. if (
  148. self.nodes_to_quantize is not None
  149. and len(self.nodes_to_quantize) != 0
  150. and node.name not in self.nodes_to_quantize
  151. ):
  152. return False
  153. if node.op_type not in self.op_types_to_quantize:
  154. return False
  155. if node.op_type in (DEQUANT_OP_NAME, QUANT_OP_NAME):
  156. return False
  157. if self.nodes_to_exclude is not None and node.name in self.nodes_to_exclude:
  158. return False
  159. return True
  160. def quantize_bias_static_impl(self, bias_name, input_scale, weight_scale, beta=1.0):
  161. """
  162. Quantized the bias. Zero Point == 0 and Scale == Input_Scale * Weight_Scale
  163. """
  164. # get bias
  165. bias_initializer = find_by_name(bias_name, self.model.initializer())
  166. bias_data = tensor_proto_to_array(bias_initializer)
  167. quantized_bias_name = bias_name + TENSOR_NAME_QUANT_SUFFIX
  168. # quantize bias
  169. if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
  170. data = np.asarray(bias_data)
  171. if data.dtype == np.float16:
  172. node_qtype = onnx.TensorProto.FLOAT16
  173. elif data.dtype == np.float32:
  174. node_qtype = onnx.TensorProto.FLOAT
  175. else:
  176. raise TypeError(f"Only float16 or float32 are supported with float 8 but bias dtype is {data.dtype}.")
  177. quantized_data = data.astype(np.float32)
  178. bias_scale = np.array([1], dtype=quantized_data.dtype)
  179. bias_scale_data = bias_scale.reshape(-1)
  180. packed_bias_initializer = onnx.numpy_helper.from_array(quantized_data, quantized_bias_name)
  181. self.model.initializer_extend([packed_bias_initializer])
  182. node_type = "Cast"
  183. else:
  184. # calculate scale for bias
  185. # TODO: This formula should be explained including why the scale is not estimated for the bias as well.
  186. bias_scale = input_scale * weight_scale * beta
  187. # Quantize by dividing by bias_scale
  188. quantized_data = np.asarray(bias_data, dtype=np.float64) / np.asarray(bias_scale, dtype=np.float64)
  189. quantized_data = quantized_data.round()
  190. # Clip quantized data to the range of a int32
  191. int32_min = np.float64(np.iinfo(np.int32).min)
  192. int32_max = np.float64(np.iinfo(np.int32).max)
  193. if np.any(quantized_data < int32_min) or np.any(quantized_data > int32_max):
  194. logging.warning(
  195. f"Quantized bias `{bias_name}` exceeds the range of a int32. The bias scale is too small."
  196. )
  197. quantized_data = np.clip(quantized_data, int32_min, int32_max).astype(np.int32)
  198. # update bias initializer
  199. bias_np_data = np.asarray(quantized_data, dtype=np.int32).reshape(bias_initializer.dims)
  200. packed_bias_initializer = onnx.numpy_helper.from_array(bias_np_data, quantized_bias_name)
  201. self.model.initializer_extend([packed_bias_initializer])
  202. # Bias's scale dtype should match the original bias data's unquantized type (float32 or float16).
  203. bias_scale_data = np.asarray(bias_scale, dtype=bias_data.dtype).reshape(-1)
  204. node_type = "DequantizeLinear"
  205. node_qtype = self.weight_qType
  206. # update scale initializer
  207. quantized_bias_scale_name = quantized_bias_name + "_scale"
  208. packed_bias_scale_initializer = onnx.numpy_helper.from_array(bias_scale_data, quantized_bias_scale_name)
  209. self.model.initializer_extend([packed_bias_scale_initializer])
  210. # update zero initializer
  211. if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
  212. tensor_type = self.weight_qType
  213. else:
  214. tensor_type = onnx.TensorProto.INT32
  215. quantized_bias_zp_name = quantized_bias_name + "_zero_point"
  216. if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
  217. packed_bias_zp_initializer = onnx.helper.make_tensor(quantized_bias_zp_name, self.weight_qType, [1], [0.0])
  218. elif bias_scale.size > 1:
  219. bias_zp_data = np.zeros(bias_scale.shape, dtype=np.int32).reshape(-1)
  220. packed_bias_zp_initializer = onnx.numpy_helper.from_array(bias_zp_data, quantized_bias_zp_name)
  221. else:
  222. packed_bias_zp_initializer = onnx.helper.make_tensor(quantized_bias_zp_name, tensor_type, [], [0])
  223. self.model.initializer_extend([packed_bias_zp_initializer])
  224. return (
  225. quantized_bias_name,
  226. quantized_bias_scale_name,
  227. quantized_bias_zp_name,
  228. bias_scale_data,
  229. node_type,
  230. node_qtype,
  231. )
  232. def quantize_initializer_impl(self, weight, qType, reduce_range=False, keep_float_weight=False):
  233. """
  234. :param weight: TensorProto initializer
  235. :param qType: type to quantize to
  236. :param keep_float_weight: Whether to quantize the weight. In some cases, we only want to qunatize scale and zero point.
  237. If keep_float_weight is False, quantize the weight, or don't quantize the weight.
  238. :return: quantized weight name, zero point name, scale name
  239. """
  240. # TODO(adrianlizarraga): This function is now only used by onnx_quantizer.py, so move it there.
  241. q_weight_name = weight.name + TENSOR_NAME_QUANT_SUFFIX
  242. zp_name = weight.name + "_zero_point"
  243. scale_name = weight.name + "_scale"
  244. # Quantize weight data. Use quantization overrides if provided by the user.
  245. weight_data = tensor_proto_to_array(weight)
  246. quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(weight.name, default_val={})
  247. if "quant_type" in quant_overrides:
  248. qType = quant_overrides["quant_type"].tensor_type # noqa: N806
  249. if "scale" in quant_overrides and "zero_point" in quant_overrides:
  250. zero_point = np.array(quant_overrides["zero_point"], dtype=ONNX_TYPE_TO_NP_TYPE[qType])
  251. scale = np.array(quant_overrides["scale"])
  252. q_weight_data = quantize_nparray(qType, weight_data.flatten(), scale, zero_point)
  253. assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
  254. assert zero_point.dtype != np.float32 and zero_point.dtype != np.float16, (
  255. f"Unexpected dtype {zero_point.dtype}"
  256. )
  257. assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
  258. else:
  259. symmetric = self.is_weight_symmetric(qType) if qType == self.weight_qType else self.is_activation_symmetric
  260. zero_point, scale, q_weight_data = quantize_data(
  261. weight_data.flatten(),
  262. qType,
  263. quant_overrides.get("symmetric", symmetric),
  264. reduce_range=quant_overrides.get("reduce_range", self.reduce_range and reduce_range),
  265. min_real_range=self.min_real_range,
  266. rmin_override=quant_overrides.get("rmin"),
  267. rmax_override=quant_overrides.get("rmax"),
  268. )
  269. assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
  270. assert zero_point.dtype != np.float32 and zero_point.dtype != np.float16, (
  271. f"Unexpected dtype {zero_point.dtype}"
  272. )
  273. assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
  274. scale_dtype = weight.data_type
  275. scale_initializer = onnx.helper.make_tensor(scale_name, scale_dtype, [], scale.reshape((-1,)).tolist())
  276. zero_initializer = onnx.helper.make_tensor(zp_name, qType, [], zero_point.reshape((-1,)).tolist())
  277. self.model.initializer_extend([scale_initializer, zero_initializer])
  278. if not keep_float_weight:
  279. if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
  280. q_weight_initializer = onnx.TensorProto()
  281. q_weight_initializer.data_type = self.weight_qType
  282. q_weight_initializer.dims.extend(weight.dims)
  283. q_weight_initializer.name = q_weight_name
  284. # Do not remove .flatten().copy() numpy is not clear about data persistence.
  285. q_weight_initializer.raw_data = q_weight_data.flatten().copy().tobytes()
  286. if to_array_extended is not None:
  287. # This test should not be needed but it helped catch some issues
  288. # with data persistence and tobytes.
  289. check = to_array_extended(q_weight_initializer)
  290. if check.shape != weight_data.shape or check.tobytes() != q_weight_data.tobytes():
  291. raise RuntimeError(
  292. f"The initializer of shape {weight_data.shape} could not be created, expecting "
  293. f"{q_weight_data.tobytes()[:10]}, got {check.tobytes()[:10]} and shape={weight.shape}"
  294. f"\nraw={str(q_weight_initializer)[:200]}."
  295. )
  296. elif qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4):
  297. if q_weight_data.dtype not in (np.int8, np.uint8):
  298. raise RuntimeError(
  299. f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values."
  300. )
  301. # We do not use onnx.helper.pack_float32_to_4bit() due to performance.
  302. # This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes.
  303. packed_data = bytes(pack_bytes_to_4bit(q_weight_data.tobytes()))
  304. # We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161
  305. q_weight_initializer = onnx.helper.make_tensor(q_weight_name, qType, weight.dims, packed_data, raw=True)
  306. else:
  307. q_weight_data = np.asarray(q_weight_data, dtype=onnx.helper.tensor_dtype_to_np_dtype(qType)).reshape(
  308. weight.dims
  309. )
  310. q_weight_initializer = onnx.numpy_helper.from_array(q_weight_data, q_weight_name)
  311. self.model.initializer_extend([q_weight_initializer])
  312. return q_weight_name, zp_name, scale_name
  313. def quantize_weight_per_channel_impl(
  314. self,
  315. weight_name,
  316. weight_qType,
  317. channel_axis,
  318. reduce_range=True,
  319. keep_float_weight=False,
  320. ):
  321. # TODO(adrianlizarraga): This function is now only used by onnx_quantizer.py, so move it there.
  322. initializer = find_by_name(weight_name, self.model.initializer())
  323. if initializer is None:
  324. raise ValueError("{} is not an initializer", weight_name)
  325. weights = tensor_proto_to_array(initializer)
  326. weights_rank = len(weights.shape)
  327. is_axis_valid, axis_norm = normalize_axis(channel_axis, weights_rank)
  328. if not is_axis_valid:
  329. raise ValueError(
  330. f"Weight {weight_name} has a per-channel axis with value {channel_axis} that is "
  331. f"out-of-bounds for rank {weights_rank}"
  332. )
  333. channel_axis = axis_norm
  334. channel_count = weights.shape[channel_axis]
  335. quant_overrides_for_channels = self.tensor_quant_overrides.get_per_channel_overrides(
  336. weight_name, default_val=[{"axis": channel_axis}]
  337. )
  338. num_channel_overrides = len(quant_overrides_for_channels)
  339. if num_channel_overrides != 1 and num_channel_overrides != channel_count:
  340. raise ValueError(
  341. f"Per-channel tensor quantization overrides for {weight_name} must have "
  342. f"either 1 or {channel_count} elements in the list of dictionaries."
  343. )
  344. is_axis_override_valid, axis_override = normalize_axis(quant_overrides_for_channels[0]["axis"], weights_rank)
  345. if not is_axis_override_valid or axis_override != channel_axis:
  346. raise ValueError(
  347. f"Tensor quantization overrides for {weight_name} specify an unexpected axis. "
  348. f"Expected {channel_axis}, but got {quant_overrides_for_channels[0]['axis']}."
  349. )
  350. # If user provides per-channel quantization overrides, all channels must use the same quant_type,
  351. # axis, symmetric, and reduce_range values. So, just use the first channel's values.
  352. if "quant_type" in quant_overrides_for_channels[0]:
  353. weight_qType = quant_overrides_for_channels[0]["quant_type"].tensor_type # noqa: N806
  354. symmetric = quant_overrides_for_channels[0].get("symmetric", self.is_weight_symmetric(weight_qType))
  355. reduce_range = quant_overrides_for_channels[0].get("reduce_range", self.reduce_range and reduce_range)
  356. zero_point_list = []
  357. scale_list = []
  358. quantized_per_channel_data_list = []
  359. weights_shape = list(weights.shape)
  360. reshape_dims = list(weights_shape) # deep copy
  361. reshape_dims[channel_axis] = 1 # only one per channel for reshape
  362. for i in range(channel_count):
  363. per_channel_data = weights.take(i, channel_axis)
  364. channel_override_index = i if i < num_channel_overrides else 0
  365. channel_quant_overrides = quant_overrides_for_channels[channel_override_index]
  366. if "scale" in channel_quant_overrides and "zero_point" in channel_quant_overrides:
  367. zero_point = np.array(channel_quant_overrides["zero_point"], dtype=ONNX_TYPE_TO_NP_TYPE[weight_qType])
  368. scale = np.array(channel_quant_overrides["scale"])
  369. quantized_per_channel_data = quantize_nparray(
  370. weight_qType, per_channel_data.flatten(), scale, zero_point
  371. )
  372. assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
  373. assert zero_point.dtype != np.float32 and zero_point.dtype != np.float16, (
  374. f"Unexpected dtype {zero_point.dtype}"
  375. )
  376. assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
  377. assert isinstance(quantized_per_channel_data, np.ndarray), (
  378. f"Unexpected type {type(quantized_per_channel_data)}"
  379. )
  380. else:
  381. zero_point, scale, quantized_per_channel_data = quantize_data(
  382. per_channel_data.flatten(),
  383. weight_qType,
  384. symmetric,
  385. reduce_range=reduce_range,
  386. min_real_range=self.min_real_range,
  387. rmin_override=channel_quant_overrides.get("rmin"),
  388. rmax_override=channel_quant_overrides.get("rmax"),
  389. )
  390. assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
  391. assert zero_point.dtype != np.float32 and zero_point.dtype != np.float16, (
  392. f"Unexpected dtype {zero_point.dtype}"
  393. )
  394. assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
  395. assert isinstance(quantized_per_channel_data, np.ndarray), (
  396. f"Unexpected type {type(quantized_per_channel_data)}"
  397. )
  398. zero_point_list.append(zero_point)
  399. scale_list.append(scale)
  400. quantized_per_channel_data_list.append(np.asarray(quantized_per_channel_data).reshape(reshape_dims))
  401. # combine per_channel_data into one
  402. quantized_weights = np.concatenate(quantized_per_channel_data_list, channel_axis)
  403. q_weight_name = weight_name + TENSOR_NAME_QUANT_SUFFIX
  404. zp_name = weight_name + "_zero_point"
  405. scale_name = weight_name + "_scale"
  406. # Update packed weight, zero point, and scale initializers
  407. zero_scale_shape = [initializer.dims[channel_axis]]
  408. scale_initializer = onnx.helper.make_tensor(
  409. scale_name, initializer.data_type, zero_scale_shape, np.hstack(scale_list).tolist()
  410. )
  411. zero_initializer = onnx.helper.make_tensor(
  412. zp_name, weight_qType, zero_scale_shape, np.hstack(zero_point_list).tolist()
  413. )
  414. self.model.initializer_extend([scale_initializer, zero_initializer])
  415. if not keep_float_weight:
  416. if weight_qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4):
  417. if quantized_weights.dtype not in (np.int8, np.uint8):
  418. raise RuntimeError(
  419. f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values."
  420. )
  421. # We do not use onnx.helper.pack_float32_to_4bit() due to performance.
  422. # This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes.
  423. packed_data = bytes(pack_bytes_to_4bit(quantized_weights.tobytes()))
  424. # We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161
  425. q_weight_initializer = onnx.helper.make_tensor(
  426. q_weight_name, weight_qType, weights_shape, packed_data, raw=True
  427. )
  428. self.model.initializer_extend([q_weight_initializer])
  429. else:
  430. quantized_weights = np.asarray(
  431. quantized_weights,
  432. dtype=onnx.helper.tensor_dtype_to_np_dtype(weight_qType),
  433. ).reshape(initializer.dims)
  434. q_weight_initializer = onnx.numpy_helper.from_array(quantized_weights, q_weight_name)
  435. self.model.initializer_extend([q_weight_initializer])
  436. return q_weight_name, zp_name, scale_name
  437. def adjust_tensor_ranges(self):
  438. if self.tensors_range is None:
  439. return
  440. for node in self.model.nodes():
  441. # adjust tensor_ranges for input of Clip and Relu node
  442. if node.op_type in ["Clip", "Relu"]:
  443. if not self.should_quantize_node(node):
  444. continue
  445. if len(self.model.input_name_to_nodes()[node.input[0]]) != 1:
  446. continue
  447. if node.input[0] not in self.tensors_range or node.output[0] not in self.tensors_range:
  448. continue
  449. td = self.tensors_range[node.output[0]]
  450. if not isinstance(td, TensorData):
  451. raise TypeError(f"Unexpected type {type(td)} for {node.output[0]!r}.")
  452. self.tensors_range[node.input[0]] = td
  453. # Adjust Softmax to range from 0.0 to 1.0
  454. elif node.op_type == "Softmax":
  455. if not self.should_quantize_node(node):
  456. continue
  457. self.tensors_range[node.output[0]] = TensorData(lowest=np.float32(0.0), highest=np.float32(1.0))