# onnx_quantizer.py
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. import logging
  7. import numpy as np
  8. import onnx
  9. import onnx.numpy_helper
  10. from onnx import onnx_pb as onnx_proto
  11. from .base_quantizer import BaseQuantizer, QuantizationParams
  12. from .calibrate import TensorData
  13. from .onnx_model import ONNXModel
  14. from .quant_utils import (
  15. TENSOR_NAME_QUANT_SUFFIX,
  16. QuantizationMode,
  17. QuantizedValue,
  18. QuantizedValueType,
  19. __producer__,
  20. __version__,
  21. add_infer_metadata,
  22. attribute_to_kwarg,
  23. compute_scale_zp,
  24. compute_scale_zp_float8,
  25. find_by_name,
  26. get_qmin_qmax_for_qType,
  27. get_qrange_for_qType,
  28. ms_domain,
  29. quantize_onnx_initializer,
  30. save_and_reload_model_with_shape_infer,
  31. tensor_proto_to_array,
  32. )
  33. from .registry import CreateOpQuantizer
  34. class ONNXQuantizer(BaseQuantizer):
  def __init__(
      self,
      model,
      per_channel,
      reduce_range,
      mode,
      static,
      weight_qType,
      activation_qType,
      tensors_range,
      nodes_to_quantize,
      nodes_to_exclude,
      op_types_to_quantize,
      extra_options=None,
  ):
      """
      Initialise the quantizer state on top of BaseQuantizer.

      parameter mode: must be a member of QuantizationMode, otherwise ValueError is raised.
      parameter static: True to use static quantization for inputs, False for dynamic.
      All other parameters are forwarded unchanged to BaseQuantizer.__init__.
      """
      BaseQuantizer.__init__(
          self,
          model,
          per_channel,
          reduce_range,
          weight_qType,
          activation_qType,
          tensors_range,
          nodes_to_quantize,
          nodes_to_exclude,
          op_types_to_quantize,
          extra_options,
      )

      if not static:
          self.model.replace_gemm_with_matmul()
          # The graph changed, so reload it with shape inference and rebuild value_infos.
          # `model` is deliberately rebound: the tensor-name scan below reads the reloaded graph.
          model = save_and_reload_model_with_shape_infer(self.model.model)
          refreshed_infos = {}
          for group in (model.graph.value_info, model.graph.output, model.graph.input):
              for entry in group:
                  refreshed_infos[entry.name] = entry
          self.value_infos = refreshed_infos
          self.model = ONNXModel(model)

      self.mode = mode  # QuantizationMode.Value
      self.static = static  # use static quantization for inputs.
      self.fuse_dynamic_quant = self.opset_version > 10
      self.q_matmul_const_b_only = self.extra_options.get("MatMulConstBOnly", False)
      self.new_nodes = []
      self.graph_scope = "/"  # for human readable debug information

      # Every tensor name seen in the graph, in case shape inference is not totally working.
      known_tensors = {}
      for graph_io in (model.graph.output, model.graph.input):
          for value in graph_io:
              known_tensors[value.name] = 1
      for graph_node in self.model.model.graph.node:
          for produced in graph_node.output:
              known_tensors[produced] = 1
      self.tensor_names = known_tensors

      if self.mode not in QuantizationMode:
          raise ValueError(f"unsupported quantization mode {self.mode}")

      self.quantization_params = self.calculate_quantization_params()

      # Names of the shared helper tensors used for scale / zero point computation
      # when static is False.
      self.fixed_qrange_uint8_name = "fixed_quantization_range_uint8"
      self.fixed_qrange_int8_name = "fixed_quantization_range_int8"
      # uint8: the zero point is computed by subtracting rmin from 0 (the fixed_zero_name tensor).
      self.fixed_zero_name = "fixed_zero"
      # int8: the zero point is always zero (the fixed_zero_zp_name tensor).
      self.fixed_zero_zp_name = "fixed_zero_zp"

      # Map of all original value names to quantized value names.
      self.quantized_value_map = {}
      # Outputs of generated nodes are treated as already existing, so that no
      # dequantize is applied to them later.
      self.generated_value_names = self.model.get_non_initializer_inputs()
  98. # routines for subgraph support
  def quantize_subgraph(self, subgraph, graph_key):
      """
      Quantize a subgraph by wrapping it in a temporary model.

      The subgraph is turned into a standalone model (sharing this model's opset imports),
      quantized by a child ONNXQuantizer configured like this one, and the quantized graph
      is returned so the caller can set it back on the node.

      parameter subgraph: GraphProto to quantize.
      parameter graph_key: label appended to graph_scope for debug information.
      return: the quantized GraphProto.
      """
      wrapped_model = onnx.helper.make_model(
          subgraph,
          producer_name="onnx-quantizer",
          opset_imports=self.model.model.opset_import,
      )
      add_infer_metadata(wrapped_model)

      child = ONNXQuantizer(
          model=wrapped_model,
          per_channel=self.per_channel,
          reduce_range=self.reduce_range,
          mode=self.mode,
          static=self.static,
          weight_qType=self.weight_qType,
          activation_qType=self.activation_qType,
          tensors_range=self.tensors_range,
          nodes_to_quantize=self.nodes_to_quantize,
          nodes_to_exclude=self.nodes_to_exclude,
          op_types_to_quantize=self.op_types_to_quantize,
          extra_options=self.extra_options,
      )
      # Link the child to this quantizer so lookups can walk up the chain.
      child.parent = self
      child.graph_scope = f"{self.graph_scope}{graph_key}/"
      child.quantize_model()
      return child.model.model.graph
  def quantize_node_with_sub_graph(self, node):
      """
      Quantize every subgraph held in the node's attributes.

      return: the node itself when it has no GRAPH/GRAPHS attribute, otherwise a new node
      with the same op type, inputs, outputs and name whose graph attributes are quantized.
      """
      graph_kinds = (onnx.AttributeProto.GRAPH, onnx.AttributeProto.GRAPHS)
      if not any(attr.type in graph_kinds for attr in node.attribute):
          return node

      # Unnamed nodes get a synthetic label; it is only used to build the subgraph scope.
      node_name = node.name or f"{node.op_type}_node_count_{len(self.new_nodes)}"

      kwargs = {}
      for attr in node.attribute:
          if attr.type == onnx.AttributeProto.GRAPH:
              kwargs[attr.name] = self.quantize_subgraph(attr.g, f"{node_name}:{attr.name}")
          elif attr.type == onnx.AttributeProto.GRAPHS:
              quantized_graphs = []
              for position, subgraph in enumerate(attr.graphs):
                  quantized_graphs.append(
                      self.quantize_subgraph(subgraph, f"{node_name}:{attr.name}:{position}")
                  )
              kwargs[attr.name] = quantized_graphs
          else:
              kwargs.update(attribute_to_kwarg(attr))

      return onnx.helper.make_node(node.op_type, node.input, node.output, name=node.name, **kwargs)
  def has_QDQ_nodes(self):  # noqa: N802
      """
      Tell whether the model already contains a QuantizeLinear or DequantizeLinear node.
      """
      qdq_op_types = ("QuantizeLinear", "DequantizeLinear")
      for candidate in self.model.nodes():
          if candidate.op_type in qdq_op_types:
              return True
      return False
  def find_initializer_in_path(self, initializer_name):
      """
      Look for an initializer in this model, then in the parent quantizers.

      return: True when the initializer is found here or anywhere up the parent chain.
      """
      found_here = find_by_name(initializer_name, self.model.initializer()) is not None
      if found_here:
          return True
      if self.parent is None:
          return False
      return self.parent.find_initializer_in_path(initializer_name)
  def add_new_nodes(self, nodes):
      """
      Append nodes to new_nodes and record all of their outputs as generated value names.
      """
      self.new_nodes.extend(nodes)
      for output_name in (name for created in nodes for name in created.output):
          self.generated_value_names.add(output_name)
  def quantize_model(self):
      """
      Quantize every node of the model and return the resulting ModelProto.

      raises RuntimeError: when, for the top level graph, cleaning the initializers
      reports initializers/tensors that could not be found.
      """
      if self.has_QDQ_nodes():
          logging.warning(
              "Please check if the model is already quantized. "
              "Note you don't need to quantize a QAT model. OnnxRuntime support to run QAT model directly."
          )

      for original_node in self.model.nodes():
          node = original_node
          # Quantize subgraphs first, if the node has any.
          if self.enable_subgraph_quantization:
              node = self.quantize_node_with_sub_graph(node)

          first_new_index = len(self.new_nodes)
          CreateOpQuantizer(self, node).quantize()

          # Everything the op quantizer appended produces generated values.
          for created in self.new_nodes[first_new_index:]:
              for output_name in created.output:
                  self.generated_value_names.add(output_name)

      self._dequantize_outputs()

      # extend is used to append to the list for a protobuf fields
      # https://developers.google.com/protocol-buffers/docs/reference/python-generated?csw=1#fields
      self.model.graph().ClearField("node")
      self.model.graph().node.extend(self.new_nodes)

      # Remove unused initializers from the graph, starting from the top level graph.
      if self.parent is None:
          _, initializers_not_found = self.model.clean_initializers()
          if len(initializers_not_found) > 0:
              raise RuntimeError("Invalid model with unknown initializers/tensors." + str(initializers_not_found))

      self.model.model.producer_name = __producer__
      self.model.model.producer_version = __version__

      # Add the ms domain when a generated node needs it and it is not imported yet.
      has_ms_opset = any(opset.domain == ms_domain for opset in self.model.model.opset_import)
      if not has_ms_opset and any(created.domain == "com.microsoft" for created in self.new_nodes):
          ms_import = self.model.model.opset_import.add()
          ms_import.version = 1
          ms_import.domain = ms_domain

      return self.model.model
  217. def _get_default_tensor_type(self, tensor_name):
  218. if "DefaultTensorType" in self.extra_options:
  219. logging.info(
  220. "get_tensor_type returns DefaultTensorType for tensor name %r, use %d",
  221. tensor_name,
  222. self.extra_options["DefaultTensorType"],
  223. )
  224. return self.extra_options["DefaultTensorType"]
  225. raise RuntimeError(
  226. f"Unable to find data type for weight_name={tensor_name!r}. "
  227. f"shape_inference failed to return a type probably this node is "
  228. f"from a different domain or using an input produced by such an operator. "
  229. f"This may happen if you quantize a model already quantized. "
  230. f"You may use extra_options `DefaultTensorType` to indicate "
  231. f"the default weight type, usually `onnx.TensorProto.FLOAT`."
  232. )
  def get_tensor_type(self, tensor_name, mandatory=False):
      """
      Find the element type of a tensor.

      Lookup order: initializers of this model, value_infos, then (when subgraph
      quantization is enabled and a parent exists) the parent quantizer.

      parameter mandatory: when True, an unknown type falls back to
      _get_default_tensor_type (which may raise) instead of returning None.
      return: the element type, or None when it is unknown and not mandatory.
      """

      def fallback():
          return self._get_default_tensor_type(tensor_name) if mandatory else None

      initializer = find_by_name(tensor_name, self.model.initializer())
      if initializer is not None:
          return initializer.data_type

      if tensor_name in self.value_infos:
          type_proto = self.value_infos[tensor_name].type
          if type_proto.HasField("tensor_type"):
              elem_type = type_proto.tensor_type.elem_type
              # elem_type 0 is treated as "not inferred" when a type is mandatory.
              if mandatory and elem_type == 0:
                  return self._get_default_tensor_type(tensor_name)
              return elem_type

      if not (self.enable_subgraph_quantization and self.parent is not None):
          return fallback()

      # NOTE(review): whatever parent.is_valid_quantize_weight returns is handed back as the
      # element type whenever it is not None - confirm that helper returns a type, not a bool.
      otype = self.parent.is_valid_quantize_weight(tensor_name)
      if otype is not None:
          return otype

      if self.enable_subgraph_quantization and self.parent:
          inherited = self.parent.get_tensor_type(tensor_name)
          if inherited is not None:
              return inherited

      return fallback()
  def is_float_tensor(self, tensor_name):
      """
      Tell whether a tensor is FLOAT or FLOAT16.

      Initializers are delegated to is_valid_quantize_weight; other tensors are looked up
      in value_infos, then in the parent quantizer when subgraph quantization is enabled.
      A warning is logged whenever the answer is False because the type is unsupported
      or unknown.
      """
      if self.is_input_a_initializer(tensor_name):
          return self.is_valid_quantize_weight(tensor_name)

      if tensor_name not in self.value_infos:
          if self.enable_subgraph_quantization and self.parent:
              return self.parent.is_float_tensor(tensor_name)
          logging.warning(
              f"Failed to infer data type of tensor: {tensor_name!r}. Please add data type info for this tensor "
              f"if your model has customized operators."
          )
          return False

      vi = self.value_infos[tensor_name]
      float_types = (
          onnx_proto.TensorProto.FLOAT,
          onnx_proto.TensorProto.FLOAT16,
      )
      if vi.type.HasField("tensor_type") and vi.type.tensor_type.elem_type in float_types:
          return True

      logging.warning(
          f"Inference failed or unsupported type to quantize for tensor {tensor_name!r}, type is {vi.type}."
      )
      return False
  def _get_dynamic_input_quantization_params(self, input_name, nodes_list, qType, initial_type):
      """
      Create the nodes that dynamically quantize an input and append them to nodes_list.

      parameter input_name: Name of the input.
      parameter nodes_list: new nodes are appended to this list.
      parameter qType: type to quantize to (INT8 or UINT8).
      parameter initial_type: type to quantize from.
      return: scale_name, zero_point_name, scale_shape, zero_point_shape.
      raises ValueError: when qType is neither INT8 nor UINT8.
      """
      if qType == onnx_proto.TensorProto.UINT8:
          build_params = self._get_dynamic_input_quantization_params_uint8
      elif qType == onnx_proto.TensorProto.INT8:
          build_params = self._get_dynamic_input_quantization_params_int8
      else:
          raise ValueError(f"Unexpected value for qType={qType}.")
      return build_params(input_name, nodes_list, initial_type)
  def _get_dynamic_input_quantization_params_int8(self, input_name, nodes_list, initial_type):
      """
      Create the nodes that dynamically quantize an input to int8 and append them to nodes_list.

      The scale is max(abs(rmin), abs(rmax)) / (quantize_range / 2.0); the zero point is the
      shared constant 0 tensor.

      parameter input_name: Name of the input.
      parameter nodes_list: new nodes are appended to this list.
      parameter initial_type: initial weight type (FLOAT or FLOAT16)
      return: scale_name, zero_point_name, scale_shape, zero_point_shape.
      """
      qType = onnx_proto.TensorProto.INT8  # noqa: N806
      input_scale_name = input_name + "_scale"

      def emit(op_type, inputs, node_name, outputs=None, **attrs):
          # Build one node, queue it, and hand back the name of its first output.
          # By default the single output is named "<node_name>:0".
          node_outputs = [node_name + ":0"] if outputs is None else outputs
          created = onnx.helper.make_node(op_type, inputs, node_outputs, node_name, **attrs)
          nodes_list.append(created)
          return created.output[0]

      # Reduce min and Reduce max
      rmin = emit("ReduceMin", [input_name], input_name + "_ReduceMin", keepdims=0)
      rmax = emit("ReduceMax", [input_name], input_name + "_ReduceMax", keepdims=0)

      # Compute scale: max(abs(rmin), abs(rmax)) ...
      abs_rmin = emit("Abs", [rmin], input_name + "_ReduceMin" + "_Abs")
      abs_rmax = emit("Abs", [rmax], input_name + "_ReduceMax" + "_Abs")
      abs_max = emit("Max", [abs_rmin, abs_rmax], input_name + "_Abs_Max")

      # ... divided by (quantize_range / 2.0), i.e. max(...) * 2.0 / quantize_range
      half_range_init = onnx.helper.make_tensor(
          self.fixed_qrange_int8_name,
          initial_type,
          [],
          [get_qrange_for_qType(qType) / 2.0],
      )
      self.model.add_initializer(half_range_init)
      # NOTE(review): this node name has no underscore before "scale_Div", unlike the uint8
      # variant ("_scale_Div"); kept as-is because it is part of the emitted graph.
      emit(
          "Div",
          [abs_max, self.fixed_qrange_int8_name],
          input_name + "scale_Div",
          outputs=[input_scale_name],
      )

      # Zero point
      zero_point_init = onnx.helper.make_tensor(self.fixed_zero_zp_name, qType, [], [0])
      self.model.add_initializer(zero_point_init)

      return input_scale_name, self.fixed_zero_zp_name, [], []
  def _get_dynamic_input_quantization_params_uint8(self, input_name, nodes_list, initial_type):
      """
      Create the nodes that dynamically quantize an input to uint8 and append them to nodes_list.

      The scale is (rmax - rmin) / quantize_range; the zero point is
      cast(floor((0 - rmin) / scale)).

      parameter input_name: Name of the input.
      parameter nodes_list: new nodes are appended to this list.
      parameter initial_type: initial weight type (FLOAT or FLOAT16)
      return: scale_name, zero_point_name, scale_shape, zero_point_shape.
      """
      qType = onnx_proto.TensorProto.UINT8  # noqa: N806
      input_scale_name = input_name + "_scale"
      input_zp_name = input_name + "_zero_point"

      def emit(op_type, inputs, node_name, outputs=None, **attrs):
          # Build one node, queue it, and hand back the name of its first output.
          # By default the single output is named "<node_name>:0".
          node_outputs = [node_name + ":0"] if outputs is None else outputs
          created = onnx.helper.make_node(op_type, inputs, node_outputs, node_name, **attrs)
          nodes_list.append(created)
          return created.output[0]

      # Reduce min and Reduce max
      rmin = emit("ReduceMin", [input_name], input_name + "_ReduceMin", keepdims=0)
      rmax = emit("ReduceMax", [input_name], input_name + "_ReduceMax", keepdims=0)

      # Add tensors for quantize range and zero value.
      qrange_init = onnx.helper.make_tensor(
          self.fixed_qrange_uint8_name,
          initial_type,
          [],
          [get_qrange_for_qType(qType)],
      )
      self.model.add_initializer(qrange_init)
      zero_init = onnx.helper.make_tensor(self.fixed_zero_name, initial_type, [], [0.0])
      self.model.add_initializer(zero_init)

      # Compute scale: (rmax - rmin) / quantize range
      value_span = emit("Sub", [rmax, rmin], input_name + "_scale_Sub")
      emit(
          "Div",
          [value_span, self.fixed_qrange_uint8_name],
          input_name + "_scale_Div",
          outputs=[input_scale_name],
      )

      # Compute zero point: (0 - rmin) / scale, floored, then cast to integer
      zero_offset = emit("Sub", [self.fixed_zero_name, rmin], input_name + "_zero_point_Sub")
      zp_float = emit("Div", [zero_offset, input_scale_name], input_name + "_zero_point_Div")
      zp_floored = emit("Floor", [zp_float], input_name + "_zero_point_Floor")
      emit(
          "Cast",
          [zp_floored],
          input_name + "_zero_point_Cast",
          outputs=[input_zp_name],
          to=qType,
      )

      return input_scale_name, input_zp_name, [], []
  def _get_quantization_params(self, param_name, use_scale=None, use_zeropoint=None):
      """
      Create the zero point and scale initializers for a tensor.

      When both use_scale and use_zeropoint are given they are used (the scale is cast to the
      dtype of the recorded scale when one exists); otherwise the values are read from
      self.quantization_params.

      parameter param_name: Name of the quantization parameter.
      parameter use_scale: optional scale overriding the recorded one.
      parameter use_zeropoint: optional zero point overriding the recorded one.
      return: result, scale_name, zero_point_name, scale_shape, zero_point_shape.
          result is False (with empty strings) when no parameters are recorded for param_name.
      raises TypeError: when the recorded parameters are not a QuantizationParams.
      raises ValueError: when the recorded parameters are malformed or the scale is not
          float32/float16.
      """
      zero_point_type = self.activation_qType
      overrides_given = use_scale is not None and use_zeropoint is not None

      if overrides_given:
          zero_point_values = np.array([use_zeropoint])
          scale_values = np.array([use_scale])
          params = self.quantization_params[param_name]
          if "scale" in params:
              # Keep the override in the same precision as the recorded scale.
              scale_values = scale_values.astype(params["scale"].dtype)
          assert scale_values.dtype != np.float64
      else:
          if self.quantization_params is None or param_name not in self.quantization_params:
              logging.info(f'Quantization parameters for tensor:"{param_name}" not specified')
              return False, "", "", "", ""

          params = self.quantization_params[param_name]
          if not isinstance(params, QuantizationParams):
              raise TypeError(f"Unexpected type {type(params)} for {param_name!r}.")
          if params is None or len(params) != 3:
              raise ValueError(
                  "Quantization parameters should contain zero point, scale, quant type. "
                  f"Specified values for output {param_name}: {params}"
              )

          zero_point_values = np.array([params["zero_point"]])
          scale_is_half_or_single = hasattr(params["scale"], "dtype") and params["scale"].dtype in (
              np.float32,
              np.float16,
          )
          if not scale_is_half_or_single:
              raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}")
          scale_values = np.array([params["scale"]])
          assert scale_values.dtype != np.float64
          zero_point_type = params["quant_type"]

      # Both tensors are stored as scalars (empty shape).
      zero_point_shape = []
      scale_shape = []
      zero_point_name = param_name + "_zero_point"
      scale_name = param_name + "_scale"

      # Add initializers
      self.model.add_initializer(
          onnx.helper.make_tensor(
              zero_point_name, zero_point_type, zero_point_shape, zero_point_values.ravel().tolist()
          )
      )

      if scale_values.dtype == np.float16:
          scale_type = onnx_proto.TensorProto.FLOAT16
      elif scale_values.dtype == np.float32:
          scale_type = onnx_proto.TensorProto.FLOAT
      else:
          raise ValueError(f"Unexpected dtype={scale_values.dtype} for param_name={param_name!r}")
      self.model.add_initializer(
          onnx.helper.make_tensor(scale_name, scale_type, scale_shape, scale_values.reshape((-1,)).tolist())
      )

      return True, scale_name, zero_point_name, scale_shape, zero_point_shape
  def _get_quantize_input_nodes(
      self, node, input_index, qType, given_scale_name=None, given_zp_name=None, initial_type=None
  ):
      """
      Build the nodes that quantize one non-initializer input of a node.

      Scale and zero point come from, in order: the given tensor names, the recorded
      quantization parameters, or (dynamic mode only) nodes computing them at run time.
      The input is also registered in quantized_value_map.

      :param node: node being quantized in NodeProto format.
      :param input_index: index of input in node.input.
      :param qType: type to quantize to.
      :param given_scale_name: scale tensor to use, if the input must be quantized with it.
      :param given_zp_name: zero point tensor to use, if the input must be quantized with it.
      :param initial_type: type of the weight to quantize
      :return: List of newly created nodes in NodeProto format, or None in static mode
          when no quantization parameters are available.
      """
      input_name = node.input[input_index]
      assert input_name != "", "Cannot access undefined variable in graph."
      output_name = input_name + TENSOR_NAME_QUANT_SUFFIX
      ql_node_name = input_name + "_QuantizeLinear"

      if given_scale_name is None or given_zp_name is None:
          data_found, scale_name, zp_name, _, _ = self._get_quantization_params(input_name)
      else:
          data_found, scale_name, zp_name = True, given_scale_name, given_zp_name

      helper_nodes = []
      use_fused_dynamic = False
      if not data_found:
          if self.static:
              return None
          # Dynamic mode: scale and zero point are not available, so they are computed at run time.
          if self.fuse_dynamic_quant and qType == onnx_proto.TensorProto.UINT8:
              use_fused_dynamic = True
              scale_name = input_name + "_scale"
              zp_name = input_name + "_zero_point"
          else:
              assert initial_type is not None, (
                  f"Cannot quantize input without knowing the initial type, "
                  f"input_name={input_name!r}, input_index={input_index}, qType={qType}, node={node}"
              )
              scale_name, zp_name, _, _ = self._get_dynamic_input_quantization_params(
                  input_name, helper_nodes, qType, initial_type=initial_type
              )

      if use_fused_dynamic:
          # A single node yields the quantized tensor together with its scale and zero point.
          quantize_node = onnx.helper.make_node(
              "DynamicQuantizeLinear",
              [input_name],
              [output_name, scale_name, zp_name],
              ql_node_name,
          )
      else:
          quantize_node = onnx.helper.make_node(
              "QuantizeLinear",
              [input_name, scale_name, zp_name],
              [output_name],
              ql_node_name,
          )

      self.quantized_value_map[input_name] = QuantizedValue(input_name, output_name, scale_name, zp_name, qType)
      return helper_nodes + [quantize_node]
  def find_quantized_value(self, input_name):
      """
      Return the quantized value recorded for input_name, searching the parent
      quantizers when this one has no entry; None when nobody has it.
      """
      local_map = self.quantized_value_map
      if input_name not in local_map:
          return None if self.parent is None else self.parent.find_quantized_value(input_name)
      return local_map[input_name]
  def adjust_single_weight_scale_if_needed(
      self,
      bias_val,
      input_scale,
      weight_scale,
      weight_scale_dtype,
      weight_name,
      bias_name,
      qrange,
      multiplicative_epsilon,
      idx=None,
  ):
      """
      Adjust a single weight scale to ensure the int32 bias does not overflow.

      The bias scale is input_scale * weight_scale. When it is positive but smaller than
      multiplicative_epsilon * 2 * abs(bias_val) / qrange, the weight scale is increased by
      the ratio between the two.

      parameter idx: channel index, only used to pick the log message and result conversion.
      return: (True, new scale converted to weight_scale_dtype) when adjusted,
          (False, weight_scale) otherwise.
      """
      smallest_valid_bias_scale = multiplicative_epsilon * (2.0 * np.abs(bias_val)) / qrange

      # The comparison is done in float64 to avoid precision loss on very small scales.
      input_scale_f64 = np.array(input_scale.item(), dtype=np.float64)
      weight_scale_f64 = np.array(weight_scale.item(), dtype=np.float64)
      candidate_bias_scale = input_scale_f64 * weight_scale_f64

      too_small = (candidate_bias_scale < smallest_valid_bias_scale) and (candidate_bias_scale > 0.0)
      if not too_small:
          return False, weight_scale

      ratio = smallest_valid_bias_scale / candidate_bias_scale
      enlarged_scale = weight_scale_f64 * ratio

      if idx is not None:
          logging.info(
              f"Increased scale[{idx}] for weight `{weight_name}` by ratio {ratio} "
              f"to ensure bias `{bias_name}` has a valid scale."
          )
          return True, enlarged_scale.astype(weight_scale_dtype)

      logging.info(
          f"Increasing scale for weight `{weight_name}` by the ratio {ratio} to "
          f"ensure bias `{bias_name}` has a valid scale."
      )
      return True, np.array(enlarged_scale, dtype=weight_scale_dtype)
  def _adjust_weight_scale_for_int32_bias(
      self,
      input_scale: np.ndarray,
      weight_scale: np.ndarray,
      weight_name: str,
      bias_tp: onnx.TensorProto,
      is_per_channel: bool,
  ) -> tuple[bool, np.ndarray | None]:
      """
      Checks if the bias scale is too small and increases the weight scale if needed.

      The bias is quantized to int32 with scale input_scale * weight_scale; each weight scale
      is checked with adjust_single_weight_scale_if_needed against the largest absolute bias
      value (per tensor) or the matching bias element (per channel).

      The weight_scale array passed in is never modified: per-channel adjustments are
      written to a copy.

      parameter input_scale: scale of the input (a single value).
      parameter weight_scale: scale(s) of the weight; per channel it must be 1-D.
      parameter weight_name: name of the weight, used in log messages.
      parameter bias_tp: bias initializer holding the float bias values.
      parameter is_per_channel: True when weight_scale holds one scale per channel.
      return: (False, None) when weight_scale is empty, otherwise
          (updated, weight scale array) where updated tells whether any scale was increased.
      """
      if not weight_scale.size:
          return False, None

      bias_float_data = tensor_proto_to_array(bias_tp)

      int32_info = np.iinfo(np.int32)
      # Small safety margin so the adjusted scale lands strictly inside the valid range.
      multiplicative_epsilon = 1.0001
      # Symmetric int32 range: max - (min + 1).
      qrange = np.array(int32_info.max, dtype=np.float64) - np.array(int32_info.min + 1, dtype=np.float64)
      weight_scale_dtype = weight_scale.dtype
      updated = False

      if not is_per_channel:
          # The range always includes zero, so absmax is the largest bias magnitude.
          rmin = np.minimum(bias_float_data.min(), np.array(0, dtype=np.float64))
          rmax = np.maximum(bias_float_data.max(), np.array(0, dtype=np.float64))
          absmax = np.maximum(np.abs(rmin), np.abs(rmax))
          changed, new_scale = self.adjust_single_weight_scale_if_needed(
              absmax,
              input_scale,
              weight_scale,
              weight_scale_dtype,
              weight_name,
              bias_tp.name,
              qrange,
              multiplicative_epsilon,
          )
          if changed:
              weight_scale = new_scale
              updated = True
      elif weight_scale.shape and len(weight_scale.shape) == 1:
          # Copy lazily on the first change: writing into the caller's array would alter its
          # data behind its back and fails outright on read-only arrays (e.g. arrays that
          # are views over a tensor's raw bytes).
          adjusted_scale = None
          for i in range(weight_scale.shape[0]):
              changed, new_scale = self.adjust_single_weight_scale_if_needed(
                  bias_float_data[i],
                  input_scale,
                  weight_scale[i],
                  weight_scale_dtype,
                  weight_name,
                  bias_tp.name,
                  qrange,
                  multiplicative_epsilon,
                  idx=i,
              )
              if changed:
                  if adjusted_scale is None:
                      adjusted_scale = weight_scale.copy()
                  adjusted_scale[i] = new_scale
          if adjusted_scale is not None:
              weight_scale = adjusted_scale
              updated = True

      return updated, weight_scale
  def _requantize_weight(self, weight_name: str, new_scale: np.ndarray) -> None:
      """
      Re-quantizes the given weight initializer using the provided scale.

      Nothing happens when the weight has no entry in quantized_value_map or when any of
      its float weight / scale / zero point / quantized weight initializers is missing.
      Otherwise the scale and quantized weight initializers are replaced; the zero point
      is reused as-is.
      """
      if weight_name not in self.quantized_value_map:
          return
      qv = self.quantized_value_map[weight_name]

      def lookup(initializer_name):
          return find_by_name(initializer_name, self.model.initializer())

      weight_tp = lookup(weight_name)
      old_scale_init = lookup(qv.scale_name)
      zp_init = lookup(qv.zp_name)
      old_q_weight_init = lookup(qv.q_name)
      if any(found is None for found in (weight_tp, old_scale_init, zp_init, old_q_weight_init)):
          return

      # Drop the stale scale and quantized weight before adding their replacements.
      self.model.remove_initializer(old_scale_init)
      self.model.remove_initializer(old_q_weight_init)

      weight_zero_point = onnx.numpy_helper.to_array(zp_init)
      axis = qv.axis

      # Add new scale initializer, in the float type of the weight and the shape of the old scale.
      scale_dtype = onnx.helper.tensor_dtype_to_np_dtype(weight_tp.data_type)
      scale_np = np.asarray(new_scale, dtype=scale_dtype)
      self.model.add_initializer(
          onnx.numpy_helper.from_array(scale_np.reshape(old_scale_init.dims), qv.scale_name)
      )

      # Add new quantized weight initializer
      self.model.add_initializer(
          quantize_onnx_initializer(
              weight_tp,
              self.weight_qType,
              weight_zero_point,
              scale_np,
              axis,
              quant_weight_name=qv.q_name,
          )
      )
  def quantize_bias_static(self, bias_name, input_name, weight_name, beta=1.0):
      """
      Quantize a bias initializer for static quantization and register the result.

      The input scale and weight scale are read from the model and handed, together with
      beta, to quantize_bias_static_impl (original note: zero point == 0 and
      scale == input_scale * weight_scale). Before that, for INT8 weights whose zero point
      is all zeros, the weight scale may be adjusted and the weight re-quantized so the
      bias fits into int32.

      parameter bias_name: name of the bias initializer.
      parameter input_name: name of the node input whose scale is used.
      parameter weight_name: name of the weight; it must already be in quantized_value_map.
      parameter beta: forwarded unchanged to quantize_bias_static_impl.
      return: name of the quantized bias.
      raises ValueError: if no scale can be located for input_name.
      """
      # A bias that was quantized earlier is simply reused.
      if bias_name in self.quantized_value_map:
          return self.quantized_value_map[bias_name].q_name

      # Scale of the (already quantized) weight.
      weight_qv = self.quantized_value_map[weight_name]
      weight_scale = tensor_proto_to_array(find_by_name(weight_qv.scale_name, self.model.initializer()))

      # Scale of the input: either recorded on a quantized value or derived from
      # the quantization parameters.
      if input_name in self.quantized_value_map:
          input_scale_name = self.quantized_value_map[input_name].scale_name
      elif input_name not in self.quantization_params:
          raise ValueError(f"Expected {input_name} to be in quantized value map for static quantization")
      else:
          _, input_scale_name, _, _, _ = self._get_quantization_params(input_name)
      input_scale = tensor_proto_to_array(find_by_name(input_scale_name, self.model.initializer()))

      # Adjust weight scale if quantizing to int32 may overflow due to a small scale.
      weight_zp_init = find_by_name(weight_qv.zp_name, self.model.initializer())
      weight_zero_point = None
      if weight_zp_init is not None:
          weight_zero_point = onnx.numpy_helper.to_array(weight_zp_init)

      zero_zp_int8_weight = (
          weight_zero_point is not None
          and weight_zero_point.size
          and not weight_zero_point.any()
          and self.weight_qType == onnx_proto.TensorProto.INT8
      )
      if zero_zp_int8_weight:
          did_update, new_weight_scale = self._adjust_weight_scale_for_int32_bias(
              input_scale,
              weight_scale,
              weight_name,
              find_by_name(bias_name, self.model.initializer()),
              self.per_channel,
          )
          if did_update:
              self._requantize_weight(weight_name, new_weight_scale)
              weight_scale = new_weight_scale

      impl_result = self.quantize_bias_static_impl(bias_name, input_scale, weight_scale, beta)
      q_bias_name, q_bias_scale_name, q_bias_zp_name, bias_scale_data, node_type, node_qtype = impl_result

      assert bias_name not in self.quantized_value_map
      # A scale with more than one element is recorded with axis 0; otherwise no axis.
      bias_axis = 0 if bias_scale_data.size > 1 else None
      self.quantized_value_map[bias_name] = QuantizedValue(
          bias_name,
          q_bias_name,
          q_bias_scale_name,
          q_bias_zp_name,
          QuantizedValueType.Initializer,
          bias_axis,
          node_type=node_type,
          node_qtype=node_qtype,
      )
      return q_bias_name
  def contains_tensor(self, tensor_name):
      """
      Tell whether tensor_name is known as a value info, a recorded tensor name or a
      newly generated value name. Initializers are not consulted here; they are
      checked separately.
      """
      if tensor_name in self.value_infos:
          return True
      if tensor_name in self.tensor_names:
          return True
      return tensor_name in self.generated_value_names
  def quantize_activation(self, node, indices, from_subgraph=False):
      """
      Quantize the given inputs of node as activations.

      Thin wrapper over __quantize_inputs with fixed settings: initializers use the
      activation quantization type, no reduced range, no per-channel, axis -1.

      parameter node: node whose inputs are quantized.
      parameter indices: input indices to quantize.
      parameter from_subgraph: forwarded unchanged to __quantize_inputs.
      return: whatever __quantize_inputs returns.
      """
      activation_settings = {
          "initializer_use_weight_qType": False,
          "reduce_range": False,
          "op_level_per_channel": False,
          "axis": -1,
      }
      return self.__quantize_inputs(node=node, indices=indices, from_subgraph=from_subgraph, **activation_settings)
  # In some cases a weight is not an initializer. For example, when neither input A nor input B of a
  # MatMul is an initializer, B can still be treated as the weight.
  def quantize_weight(
      self,
      node,
      indices,
      reduce_range=False,
      op_level_per_channel=False,
      axis=-1,
      from_subgraph=False,
  ):
      """
      Quantize the given inputs of node as weights.

      Thin wrapper over __quantize_inputs that always asks for initializers to be
      quantized with the weight quantization type; every other argument is forwarded
      unchanged.

      parameter node: node whose inputs are quantized.
      parameter indices: input indices to quantize.
      parameter reduce_range: forwarded to __quantize_inputs.
      parameter op_level_per_channel: forwarded to __quantize_inputs.
      parameter axis: forwarded to __quantize_inputs.
      parameter from_subgraph: forwarded to __quantize_inputs.
      return: whatever __quantize_inputs returns.
      """
      forwarded = {
          "reduce_range": reduce_range,
          "op_level_per_channel": op_level_per_channel,
          "axis": axis,
          "from_subgraph": from_subgraph,
      }
      return self.__quantize_inputs(node=node, indices=indices, initializer_use_weight_qType=True, **forwarded)
  794. def __quantize_inputs(
  795. self,
  796. node,
  797. indices,
  798. initializer_use_weight_qType=True,
  799. reduce_range=False,
  800. op_level_per_channel=False,
  801. axis=-1,
  802. from_subgraph=False,
  803. ):
  804. """
  805. Given a node, this function quantizes the inputs as follows:
  806. - If input is an initializer, quantize the initializer data, replace old initializer
  807. with new initializer
  808. - Else, add QuantizeLinear nodes to perform quantization
  809. parameter node: node being quantized in NodeProto format.
  810. parameter indices: input indices to quantize.
  811. return: (List of quantized input names,
  812. List of zero point names used for input quantization,
  813. List of scale names used for input quantization,
  814. List of new QuantizeLinear nodes created)
  815. """
  816. scale_names = []
  817. zero_point_names = []
  818. quantized_input_names = []
  819. nodes = []
  820. for input_index in indices:
  821. node_input = node.input[input_index]
  822. # Find if this input is already quantized
  823. if node_input in self.quantized_value_map:
  824. quantized_value = self.quantized_value_map[node_input]
  825. scale_names.append(quantized_value.scale_name)
  826. zero_point_names.append(quantized_value.zp_name)
  827. quantized_input_names.append(quantized_value.q_name)
  828. continue
  829. # adding this for case embed_layernorm.py has optional segment_embedding
  830. if not node_input:
  831. quantized_input_names.append("")
  832. scale_names.append("")
  833. zero_point_names.append("")
  834. continue
  835. # Quantize the input
  836. initializer = find_by_name(node_input, self.model.initializer())
  837. if initializer is not None:
  838. if self.per_channel and op_level_per_channel:
  839. (
  840. q_weight_name,
  841. zp_name,
  842. scale_name,
  843. ) = self.quantize_weight_per_channel(
  844. initializer.name,
  845. self.weight_qType if initializer_use_weight_qType else self.activation_qType,
  846. axis,
  847. reduce_range,
  848. )
  849. else:
  850. q_weight_name, zp_name, scale_name = self.quantize_initializer(
  851. initializer,
  852. self.weight_qType if initializer_use_weight_qType else self.activation_qType,
  853. reduce_range,
  854. )
  855. quantized_input_names.append(q_weight_name)
  856. zero_point_names.append(zp_name)
  857. scale_names.append(scale_name)
  858. elif self.contains_tensor(node_input):
  859. # Add QuantizeLinear node.
  860. qlinear_node = self.model.find_node_by_name(
  861. node_input + "_QuantizeLinear", self.new_nodes, self.model.graph()
  862. )
  863. if qlinear_node is None:
  864. input_name = node.input[input_index]
  865. if input_name in self.value_infos:
  866. value_info = self.value_infos[input_name]
  867. assert value_info.HasField("type"), f"value_info={value_info} has no type."
  868. assert value_info.type.HasField("tensor_type"), f"value_info={value_info} is not a tensor."
  869. initial_type = value_info.type.tensor_type.elem_type
  870. else:
  871. # Shape inference failed. Fallback to self.tensor_names.
  872. assert input_name in self.tensor_names, (
  873. f"shape inference failed for {input_name!r} and "
  874. f"attribute 'tensor_names' does not have any value for "
  875. f"this tensor."
  876. )
  877. initial_type = self.tensor_names[input_name]
  878. quantize_input_nodes = self._get_quantize_input_nodes(
  879. node, input_index, self.activation_qType, initial_type=initial_type
  880. )
  881. if quantize_input_nodes is None:
  882. return (None, None, None, None)
  883. if from_subgraph:
  884. self.add_new_nodes(quantize_input_nodes)
  885. else:
  886. nodes.extend(quantize_input_nodes)
  887. qlinear_node = quantize_input_nodes[-1]
  888. if qlinear_node.op_type == "QuantizeLinear":
  889. quantized_input_names.extend(qlinear_node.output)
  890. scale_names.append(qlinear_node.input[1])
  891. zero_point_names.append(qlinear_node.input[2])
  892. else:
  893. quantized_input_names.append(qlinear_node.output[0])
  894. scale_names.append(qlinear_node.output[1])
  895. zero_point_names.append(qlinear_node.output[2])
  896. elif self.parent is not None:
  897. (
  898. parent_quantized_input_names,
  899. parent_zero_point_names,
  900. parent_scale_names,
  901. _,
  902. ) = self.parent.__quantize_inputs(
  903. node,
  904. [input_index],
  905. initializer_use_weight_qType=initializer_use_weight_qType,
  906. reduce_range=reduce_range,
  907. op_level_per_channel=op_level_per_channel,
  908. axis=axis,
  909. from_subgraph=True,
  910. )
  911. quantized_input_names.append(parent_quantized_input_names[0])
  912. scale_names.append(parent_scale_names[0])
  913. zero_point_names.append(parent_zero_point_names[0])
  914. # node should not be add this child level here
  915. else:
  916. raise ValueError(f"Invalid tensor name to quantize: {node_input} @graph scope{self.graph_scope}")
  917. return quantized_input_names, zero_point_names, scale_names, nodes
  def quantize_initializer(self, weight, qType, reduce_range=False, keep_float_weight=False):
      """
      Quantize an initializer once and remember the outcome.

      :param weight: TensorProto initializer
      :param qType: type to quantize to
      :param reduce_range: forwarded to quantize_initializer_impl.
      :param keep_float_weight: forwarded to quantize_initializer_impl. Per the original
          note, it selects whether the weight itself is quantized or only the scale and
          zero point are produced (confirm the exact behavior against the impl).
      :return: quantized weight name, zero point name, scale name
      """
      # Reuse the recorded names if this initializer was quantized before.
      if weight.name in self.quantized_value_map:
          known = self.quantized_value_map[weight.name]
          return known.q_name, known.zp_name, known.scale_name

      names = self.quantize_initializer_impl(weight, qType, reduce_range, keep_float_weight)
      q_weight_name, zp_name, scale_name = names

      # Log entry for this quantized weight (recorded without an axis).
      self.quantized_value_map[weight.name] = QuantizedValue(
          weight.name,
          q_weight_name,
          scale_name,
          zp_name,
          QuantizedValueType.Initializer,
          None,
      )
      return q_weight_name, zp_name, scale_name
  def quantize_weight_per_channel(
      self,
      weight_name,
      weight_qType,
      channel_axis,
      reduce_range=True,
      keep_float_weight=False,
  ):
      """
      Quantize a weight per channel once and remember the outcome.

      :param weight_name: name of the weight to quantize.
      :param weight_qType: type to quantize to.
      :param channel_axis: forwarded to quantize_weight_per_channel_impl.
      :param reduce_range: forwarded to quantize_weight_per_channel_impl.
      :param keep_float_weight: forwarded to quantize_weight_per_channel_impl.
      :return: quantized weight name, zero point name, scale name
      """
      # Reuse the recorded names if this weight was quantized before.
      if weight_name in self.quantized_value_map:
          known = self.quantized_value_map[weight_name]
          return known.q_name, known.zp_name, known.scale_name

      names = self.quantize_weight_per_channel_impl(
          weight_name, weight_qType, channel_axis, reduce_range, keep_float_weight
      )
      q_weight_name, zp_name, scale_name = names

      # NOTE(review): the entry is recorded with axis=None, not channel_axis; confirm
      # that readers of QuantizedValue.axis for this weight expect that.
      self.quantized_value_map[weight_name] = QuantizedValue(
          weight_name,
          q_weight_name,
          scale_name,
          zp_name,
          QuantizedValueType.Initializer,
          None,
      )
      return q_weight_name, zp_name, scale_name
  977. def _dequantize_value(self, value_name):
  978. """
  979. Given a value (input/output) which is quantized, add a DequantizeLinear node to dequantize
  980. it back to float32 or float16
  981. parameter value_name: value to dequantize
  982. parameter new_nodes_list: List of new nodes created before processing current node
  983. return: None if there is already a DequantizeLinear node that dequantizes it
  984. A DequantizeLinear node otherwise
  985. """
  986. if (value_name in self.quantized_value_map) and (value_name not in self.generated_value_names):
  987. quantized_value = self.quantized_value_map[value_name]
  988. # Add DequantizeLinear Node for this input
  989. scale_init = find_by_name(quantized_value.scale_name, self.model.initializer())
  990. # In case we are working with subgraphs, the graph `producer_name` is set to `"onnx-quantizer"` in the `quantize_subgraph` method. In this case, the scale initializer may be on the top level graph, so the check below can not be done.
  991. if self.model.model.producer_name != "onnx-quantizer" or (
  992. self.model.model.producer_name == "onnx-quantizer" and scale_init is not None
  993. ):
  994. # axis is not specified so scale_init must be a scalar.
  995. assert scale_init is None or onnx.numpy_helper.to_array(scale_init).size == 1
  996. dqlinear_name = value_name + "_DequantizeLinear"
  997. dqlinear_node = self.model.find_node_by_name(dqlinear_name, self.new_nodes, self.model.graph())
  998. if dqlinear_node is None:
  999. dqlinear_inputs = [
  1000. quantized_value.q_name,
  1001. quantized_value.scale_name,
  1002. quantized_value.zp_name,
  1003. ]
  1004. dequantize_node = onnx.helper.make_node(
  1005. "DequantizeLinear", dqlinear_inputs, [value_name], dqlinear_name
  1006. )
  1007. return dequantize_node
  1008. else:
  1009. # DQ op is already present, assert it's output matches the input of current node
  1010. assert value_name == dqlinear_node.output[0]
  1011. return None
  1012. def _dequantize_outputs(self):
  1013. """
  1014. Dequantize output if it is quantized
  1015. parameter new_nodes_list: List of new nodes created before processing current node
  1016. return: List of new nodes created
  1017. """
  1018. for output in self.model.graph().output:
  1019. dequantize_node = self._dequantize_value(output.name)
  1020. if dequantize_node is not None:
  1021. self.new_nodes.append(dequantize_node)
  def calculate_quantization_params(self):
      """
      Derive zero point and scale for every tensor in self.tensors_range.

      Per-tensor overrides from self.tensor_quant_overrides take precedence: a "quant_type"
      override replaces the activation type, and when both "scale" and "zero_point" are
      overridden they are used as given. Otherwise FLOAT8E4M3FN tensors are handled by
      compute_scale_zp_float8 and all other types by compute_scale_zp over the
      (possibly overridden) rmin/rmax range.

      return: None if self.tensors_range is None, else a dict mapping each tensor name
              to its QuantizationParams.
      raises TypeError: if an entry of self.tensors_range is not a TensorData.
      """
      if self.tensors_range is None:
          return None

      self.adjust_tensor_ranges()

      params_by_tensor = {}
      for tensor_name in self.tensors_range:
          td = self.tensors_range[tensor_name]
          if not isinstance(td, TensorData):
              raise TypeError(f"Unexpected type {type(td)} for {tensor_name!r}.")

          overrides = self.tensor_quant_overrides.get_per_tensor_overrides(tensor_name, default_val={})

          if "quant_type" in overrides:
              quant_type = overrides["quant_type"].tensor_type
          else:
              quant_type = self.activation_qType

          if "scale" in overrides and "zero_point" in overrides:
              # Fully specified by the user; nothing to compute.
              zero = overrides["zero_point"]
              scale = overrides["scale"]
          elif quant_type == onnx.TensorProto.FLOAT8E4M3FN:
              zero, scale = compute_scale_zp_float8(quant_type, td.avg_std[1])
          else:
              rmin = overrides.get("rmin", td.range_value[0])
              rmax = overrides.get("rmax", td.range_value[1])
              symmetric = overrides.get("symmetric", self.is_activation_symmetric)
              reduce_range = overrides.get("reduce_range", False)
              qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric)
              zero, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, self.min_real_range)

          params_by_tensor[tensor_name] = QuantizationParams(zero_point=zero, scale=scale, quant_type=quant_type)

      return params_by_tensor