qdq_loss_debug.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. # --------------------------------------------------------------------------
  2. # Copyright (c) Microsoft, Intel Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
"""Utilities to run a given ONNX model, while saving input/output tensors of
eligible operator nodes.

A use case is to debug quantization-induced accuracy drop. An AI engineer can
run the original float32 model and the quantized model with the same inputs,
then compare the corresponding activations between the two models to find
where the divergence is.

Example Usage:

```python
    class ExampleDataReader(CalibrationDataReader):
        def __init__(self):
            ...

        def get_next(self):
            ...

    input_data_reader = ExampleDataReader()
    augmented_model_path = str(Path(self._tmp_model_dir.name).joinpath("augmented_model.onnx"))
    modify_model_output_intermediate_tensors(path_to_onnx_model, augmented_model_path)
    tensor_dict = collect_activations(augmented_model_path, input_data_reader)
```

`tensor_dict` points to a dictionary where the keys are tensor names and each value
is a list of tensors, one from each model run.
"""
  27. import logging
  28. import math
  29. import time
  30. from collections.abc import Callable, Sequence
  31. from pathlib import Path
  32. import numpy
  33. import onnx
  34. from onnx import helper, numpy_helper
  35. import onnxruntime
  36. from .calibrate import CalibraterBase, CalibrationDataReader
  37. from .onnx_model import ONNXModel
  38. from .quant_utils import (
  39. DEQUANT_OP_NAME,
  40. DEQUANT_OUTPUT_SUFFIX,
  41. QUANT_INPUT_SUFFIX,
  42. TENSOR_NAME_QUANT_SUFFIX,
  43. find_by_name,
  44. load_model_with_shape_infer,
  45. )
  46. _TENSOR_SAVE_POSTFIX = "_ReshapedSavedOutput"
  47. _TENSOR_SAVE_POSTFIX_LEN = len(_TENSOR_SAVE_POSTFIX)
  48. def modify_model_output_intermediate_tensors(
  49. input_model_path: str | Path,
  50. output_model_path: str | Path,
  51. op_types_for_saving: Sequence[str] | None = None,
  52. save_as_external_data: bool = False,
  53. ) -> None:
  54. """Augment a given ONNX model to save node input/output tensors.
  55. Add all input/output tensors of operator nodes to model outputs
  56. so that their values can be retrieved for debugging purposes.
  57. Args:
  58. input_model: the path to load the model.
  59. op_types_for_saving: Operator types for which the
  60. input/output should be saved. By default, saving all the
  61. float32/float16 tensors.
  62. Returns:
  63. The augmented ONNX model
  64. """
  65. if op_types_for_saving is None:
  66. op_types_for_saving = []
  67. saver = CalibraterBase(input_model_path, op_types_to_calibrate=op_types_for_saving)
  68. model_to_augment = saver.model
  69. tensors, value_infos = saver.select_tensors_to_calibrate(model_to_augment)
  70. reshape_shape_name = "LinearReshape_" + str(time.time())
  71. reshape_shape = numpy_helper.from_array(numpy.array([-1], dtype=numpy.int64), reshape_shape_name)
  72. model_to_augment.graph.initializer.append(reshape_shape)
  73. for tensor_name in tensors:
  74. reshape_output = tensor_name + _TENSOR_SAVE_POSTFIX
  75. reshape_node = onnx.helper.make_node(
  76. "Reshape",
  77. inputs=[tensor_name, reshape_shape_name],
  78. outputs=[reshape_output],
  79. name=reshape_output,
  80. )
  81. model_to_augment.graph.node.append(reshape_node)
  82. reshape_output_value_info = helper.make_tensor_value_info(
  83. reshape_output, value_infos[tensor_name].type.tensor_type.elem_type, [-1]
  84. )
  85. model_to_augment.graph.output.append(reshape_output_value_info)
  86. onnx.save(
  87. model_to_augment,
  88. output_model_path,
  89. save_as_external_data=save_as_external_data,
  90. )
  91. def collect_activations(
  92. augmented_model: str,
  93. input_reader: CalibrationDataReader,
  94. session_options=None,
  95. execution_providers: Sequence[str] | None = None,
  96. ) -> dict[str, list[numpy.ndarray]]:
  97. """Run augmented model and collect activations tensors.
  98. Args:
  99. augmented_model: Path to augmented model created by modify_model_output_intermediate_tensors ()
  100. input_reader: Logic for reading input for the model, augmented model have the same
  101. input with the original model.
  102. session_options: Optional OnnxRuntime session options for controlling model run.
  103. By default graph optimization is turned off
  104. execution_providers: Collection of execution providers for running the model.
  105. Only CPU EP is used by default.
  106. Returns:
  107. A dictionary where the key is tensor name and values are list of tensors from each batch
  108. """
  109. if session_options is None:
  110. session_options = onnxruntime.SessionOptions()
  111. session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
  112. if execution_providers is None:
  113. execution_providers = ["CPUExecutionProvider"]
  114. inference_session = onnxruntime.InferenceSession(
  115. augmented_model,
  116. sess_options=session_options,
  117. providers=execution_providers,
  118. )
  119. intermediate_outputs = []
  120. for input_d in input_reader:
  121. intermediate_outputs.append(inference_session.run(None, input_d))
  122. if not intermediate_outputs:
  123. raise RuntimeError("No data is collected while running augmented model!")
  124. output_dict = {}
  125. output_info = inference_session.get_outputs()
  126. for batch in intermediate_outputs:
  127. for output, output_data in zip(output_info, batch, strict=False):
  128. if output.name.endswith(_TENSOR_SAVE_POSTFIX):
  129. output_name = output.name[:-_TENSOR_SAVE_POSTFIX_LEN]
  130. output_dict.setdefault(output_name, []).append(output_data)
  131. return output_dict
  132. _POST_QDQ_POSTFIX1 = DEQUANT_OUTPUT_SUFFIX + "_1"
  133. def _add_pre_post_qdq_pair(
  134. qdq_cmp: dict[str, dict[str, Sequence[numpy.ndarray]]],
  135. activation_name: str,
  136. pre_qdq_tensors: Sequence[numpy.ndarray] | None,
  137. post_qdq_tensors: Sequence[numpy.ndarray] | None,
  138. ) -> None:
  139. if post_qdq_tensors is not None and pre_qdq_tensors is not None:
  140. qdq_cmp[activation_name] = {}
  141. qdq_cmp[activation_name]["pre_qdq"] = pre_qdq_tensors
  142. qdq_cmp[activation_name]["post_qdq"] = post_qdq_tensors
  143. def create_activation_matching(
  144. qdq_activations: dict[str, Sequence[numpy.ndarray]],
  145. float_activations: dict[str, Sequence[numpy.ndarray]] | None = None,
  146. ) -> dict[str, dict[str, Sequence[numpy.ndarray]]]:
  147. """Comparing activation values to help debugging accuracy loss due to quantization.
  148. This functions takes saved activations from the QDQ model and (optionally) the
  149. float point model, and provides a data structure for comparing:
  150. * from the qdq model, activation values before and after QDQ operation
  151. * across both models, activations from the orignal model vs the corresponding
  152. activations in the QDQ model
  153. Arg:
  154. qdq_activations: Output of `collect_activations`. This must be from a quantized
  155. model with QDQ format.
  156. float_activations: Output of `collect_activations`. This must be from the float
  157. point model.
  158. Returns:
  159. Dict for comparing pre and post quantized activation tensors. E.g.
  160. ```
  161. qdq_cmp = cmp_qdq_input_output(qdq_activations)
  162. print(qdq_cmp['activation1']['pre_qdq'][0])
  163. print(qdq_cmp['activation1'][`post_qdq'][0])
  164. qdq_cmp = cmp_qdq_input_output(qdq_activations, float_activations)
  165. print(qdq_cmp['activation1']['float'][0])
  166. print(qdq_cmp['activation1']['pre_qdq'][0])
  167. print(qdq_cmp['activation1'][`post_qdq'][0])
  168. ```
  169. """
  170. qdq_cmp: dict[str, dict[str, Sequence[numpy.ndarray]]] = {}
  171. for tensor_name, tensors in qdq_activations.items():
  172. if tensor_name.endswith(QUANT_INPUT_SUFFIX):
  173. pre_name = tensor_name[: -len(QUANT_INPUT_SUFFIX)]
  174. post_qdq_tensors = qdq_activations.get(pre_name)
  175. pre_qdq_tensors = tensors
  176. _add_pre_post_qdq_pair(qdq_cmp, pre_name, pre_qdq_tensors, post_qdq_tensors)
  177. elif tensor_name.endswith(DEQUANT_OUTPUT_SUFFIX):
  178. pre_name = tensor_name[: -len(DEQUANT_OUTPUT_SUFFIX)]
  179. pre_qdq_tensors = qdq_activations.get(pre_name)
  180. post_qdq_tensors = tensors
  181. _add_pre_post_qdq_pair(qdq_cmp, pre_name, pre_qdq_tensors, post_qdq_tensors)
  182. elif tensor_name.endswith(_POST_QDQ_POSTFIX1):
  183. pre_name = tensor_name[: -len(_POST_QDQ_POSTFIX1)]
  184. pre_qdq_tensors = qdq_activations.get(pre_name)
  185. post_qdq_tensors = tensors
  186. _add_pre_post_qdq_pair(qdq_cmp, pre_name, pre_qdq_tensors, post_qdq_tensors)
  187. if not float_activations:
  188. return qdq_cmp
  189. for act_name, act_values in qdq_cmp.items():
  190. float_acts = float_activations.get(act_name)
  191. if float_acts is not None:
  192. act_values["float"] = float_acts
  193. return qdq_cmp
  194. def _run_dequantize_linear(
  195. weight_tensor: numpy.ndarray, weight_scale: numpy.ndarray, weight_zp: numpy.ndarray, channel_axis: int
  196. ) -> numpy.ndarray | None:
  197. assert weight_scale.shape == weight_zp.shape
  198. if weight_zp.size == 1:
  199. return (weight_tensor - weight_zp) * weight_scale
  200. assert weight_zp.ndim == 1
  201. reshape_dims = list(weight_tensor.shape) # deep copy
  202. reshape_dims[channel_axis] = 1 # only one per channel for reshape
  203. channel_count = weight_tensor.shape[channel_axis]
  204. dequantized_weights = None
  205. for i in range(channel_count):
  206. per_channel_data = weight_tensor.take(i, channel_axis)
  207. dequantized_per_channel_data = (per_channel_data - weight_zp[i]) * weight_scale[i]
  208. if i == 0:
  209. dequantized_weights = numpy.asarray(dequantized_per_channel_data).reshape(reshape_dims)
  210. else:
  211. channel_weights = numpy.asarray(dequantized_per_channel_data).reshape(reshape_dims)
  212. dequantized_weights = numpy.concatenate((dequantized_weights, channel_weights), channel_axis)
  213. if dequantized_weights is None:
  214. return None
  215. dequantized_weights.reshape(weight_tensor.shape)
  216. return dequantized_weights
  217. def create_weight_matching(float_model_path: str, qdq_model_path: str) -> dict[str, dict[str, numpy.ndarray]]:
  218. """Comparing weight values to help debugging accuracy loss due to quantization.
  219. This functions takes the float model and the qdq model, and provides a data structure for comparing
  220. their corresponding weights to locate quantization errors
  221. Arg:
  222. float_model_path: Path points to the float point model.
  223. qdq_model_path: Path points to the qdq model.
  224. Returns:
  225. Dict for comparing weight tensors. E.g.
  226. ```
  227. qdq_weight_cmp = create_weight_matching(float_model, qdq_model)
  228. print(qdq_weight_cmp['activation1']['float'])
  229. print(qdq_weight_cmp['activation1']['dequantized'])
  230. ```
  231. """
  232. float_onnx_model = ONNXModel(load_model_with_shape_infer(Path(float_model_path)))
  233. qdq_onnx_model = ONNXModel(load_model_with_shape_infer(Path(qdq_model_path)))
  234. matched_weights: dict[str, dict[str, numpy.ndarray]] = {}
  235. initializers = qdq_onnx_model.initializer()
  236. for node in qdq_onnx_model.nodes():
  237. if node.op_type != DEQUANT_OP_NAME:
  238. continue # Only care about DQ node
  239. weight_name: str = node.input[0]
  240. weight_values = find_by_name(weight_name, initializers)
  241. if not weight_values:
  242. continue # Only care about DQ node with const inputs
  243. if not weight_name.endswith(TENSOR_NAME_QUANT_SUFFIX):
  244. logging.error(f"Model Error in '{qdq_model_path}': Dequantized tensor name '{weight_name}' not recognized!")
  245. continue
  246. axis = -1
  247. for attr in node.attribute:
  248. if attr.name == "axis":
  249. axis = attr.i
  250. weight_tensor = numpy_helper.to_array(weight_values)
  251. weight_scale = numpy_helper.to_array(find_by_name(node.input[1], initializers))
  252. if len(node.input) > 2:
  253. weight_zp = numpy_helper.to_array(find_by_name(node.input[2], initializers))
  254. else:
  255. weight_zp = numpy.zeros(weight_scale.shape, dtype=numpy.int32)
  256. # Perform dequantization:
  257. if weight_scale.size == weight_zp.size == 1:
  258. # Avoids the confusion between a scaler and a tensor of one element.
  259. weight_scale = weight_scale.reshape(())
  260. weight_zp = weight_zp.reshape(())
  261. if weight_scale.shape != weight_zp.shape:
  262. raise RuntimeError(
  263. f"scale and zero_point must have the same shape but {weight_scale.shape} != {weight_zp.shape}"
  264. )
  265. weight_quant = _run_dequantize_linear(weight_tensor, weight_scale, weight_zp, channel_axis=axis)
  266. weight_name = weight_name[: -len(TENSOR_NAME_QUANT_SUFFIX)]
  267. if weight_quant is None:
  268. logging.error(f"Model Error in '{qdq_model_path}': '{weight_name}' per-channel quantization on 0 channel")
  269. continue
  270. float_values = find_by_name(weight_name, float_onnx_model.initializer())
  271. if not float_values:
  272. logging.error(f"Model Error in '{float_model_path}': weight tensor '{weight_name}' not found!")
  273. continue
  274. weight_float = numpy_helper.to_array(float_values)
  275. matched_weights[weight_name] = {"float": weight_float, "dequantized": weight_quant}
  276. return matched_weights
  277. def compute_signal_to_quantization_noice_ratio(
  278. x: Sequence[numpy.ndarray] | numpy.ndarray, y: Sequence[numpy.ndarray] | numpy.ndarray
  279. ) -> float:
  280. if isinstance(x, numpy.ndarray):
  281. xlist = [x]
  282. else:
  283. xlist = x
  284. if isinstance(y, numpy.ndarray):
  285. ylist = [y]
  286. else:
  287. ylist = y
  288. if len(xlist) != len(ylist):
  289. raise RuntimeError("Unequal number of tensors to compare!")
  290. left = numpy.concatenate(xlist).flatten()
  291. right = numpy.concatenate(ylist).flatten()
  292. epsilon = numpy.finfo("float").eps
  293. tensor_norm = max(numpy.linalg.norm(left), epsilon)
  294. diff_norm = max(numpy.linalg.norm(left - right), epsilon)
  295. res = tensor_norm / diff_norm
  296. return 20 * math.log10(res)
  297. def compute_weight_error(
  298. weights_match: dict[str, dict[str, numpy.ndarray]],
  299. err_func: Callable[[numpy.ndarray, numpy.ndarray], float] = compute_signal_to_quantization_noice_ratio,
  300. ) -> dict[str, float]:
  301. result: dict[str, float] = {}
  302. for weight_name, weight_match in weights_match.items():
  303. result[weight_name] = err_func(weight_match["float"], weight_match["dequantized"])
  304. return result
  305. def compute_activation_error(
  306. activations_match: dict[str, dict[str, Sequence[numpy.ndarray]]],
  307. err_func: Callable[
  308. [Sequence[numpy.ndarray], Sequence[numpy.ndarray]], float
  309. ] = compute_signal_to_quantization_noice_ratio,
  310. ) -> dict[str, dict[str, float]]:
  311. result: dict[str, dict[str, float]] = {}
  312. for name, match in activations_match.items():
  313. err_result: dict[str, float] = {}
  314. err_result["qdq_err"] = err_func(match["pre_qdq"], match["post_qdq"])
  315. float_activation = match["float"]
  316. if float_activation:
  317. err_result["xmodel_err"] = err_func(float_activation, match["post_qdq"])
  318. result[name] = err_result
  319. return result