tensor_quant_overrides.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. from __future__ import annotations
  7. import json
  8. from collections.abc import MutableMapping
  9. from dataclasses import dataclass
  10. from typing import Any
  11. import onnx
  12. from .quant_utils import QuantType
  13. @dataclass
  14. class QuantTypeInfo: # noqa: PLW1641
  15. """
  16. The quantization type information for a tensor override.
  17. """
  18. quant_type: QuantType
  19. symmetric: bool | None = None # If None, assumes default is used.
  20. reduce_range: bool | None = None # If None, assumes default is used.
  21. axis: int | None = None # If None, assumes per-tensor quantization
  22. def __eq__(self, other: object):
  23. if isinstance(other, QuantTypeInfo):
  24. return (
  25. self.quant_type == other.quant_type
  26. and (self.symmetric is None or other.symmetric is None or self.symmetric == other.symmetric)
  27. and (self.reduce_range is None or other.reduce_range is None or self.reduce_range == other.reduce_range)
  28. and (self.axis == other.axis)
  29. )
  30. return NotImplemented
  31. @staticmethod
  32. def load_from_dict(
  33. raw_dict: dict[str, Any],
  34. default_qtype: QuantType | None = None,
  35. default_symmetric: bool | None = None,
  36. default_reduce_range: bool | None = None,
  37. ) -> QuantTypeInfo:
  38. return QuantTypeInfo(
  39. raw_dict.get("quant_type", default_qtype),
  40. raw_dict.get("symmetric", default_symmetric),
  41. raw_dict.get("reduce_range", default_reduce_range),
  42. raw_dict.get("axis"),
  43. )
  44. def save_to_dict(self, raw_dict: dict[str, Any]):
  45. raw_dict["quant_type"] = self.quant_type
  46. if self.symmetric is not None:
  47. raw_dict["symmetric"] = self.symmetric
  48. if self.reduce_range is not None:
  49. raw_dict["reduce_range"] = self.reduce_range
  50. if self.axis is not None:
  51. raw_dict["axis"] = self.axis
  52. class TensorQuantOverridesHelper(MutableMapping):
  53. """
  54. Utility wrapper over the tensor quantization overrides passed via extra_options.
  55. """
  56. def __init__(self, raw_overrides: dict[str, list[dict[str, Any]]]):
  57. self.overrides = raw_overrides
  58. self.quant_types = None
  59. self.keys_unsupported_with_scale_zp = {"symmetric", "reduce_range", "rmax", "rmin"}
  60. def has_per_tensor_overrides(self, tensor_name: str) -> bool:
  61. overrides_list = self.overrides.get(tensor_name)
  62. return overrides_list and "axis" not in overrides_list[0]
  63. def has_per_channel_overrides(self, tensor_name: str) -> bool:
  64. overrides_list = self.overrides.get(tensor_name)
  65. return overrides_list and "axis" in overrides_list[0]
  66. def overrides_scale_zp(self, tensor_name: str) -> bool:
  67. overrides_list = self.overrides.get(tensor_name)
  68. return overrides_list and ("scale" in overrides_list[0]) and ("zero_point" in overrides_list[0])
  69. def get_per_tensor_overrides(
  70. self,
  71. tensor_name: str,
  72. default_val: dict[str, Any] | None = None,
  73. ) -> dict[str, Any] | None:
  74. default_list_val = [default_val] if default_val is not None else None
  75. overrides_list = self.overrides.get(tensor_name, default_list_val)
  76. if overrides_list and "axis" in overrides_list[0]:
  77. raise ValueError(
  78. f"Expected tensor '{tensor_name}' to use per-tensor quantization overrides, "
  79. f"but found per-channel overrides."
  80. )
  81. return overrides_list[0] if overrides_list else None
  82. def get_per_channel_overrides(
  83. self,
  84. tensor_name: str,
  85. default_val: list[dict[str, Any]] | None = None,
  86. ) -> list[dict[str, Any]] | None:
  87. overrides_list = self.overrides.get(tensor_name, default_val)
  88. if not overrides_list:
  89. return None
  90. if "axis" not in overrides_list[0]:
  91. raise ValueError(
  92. f"Expected tensor '{tensor_name}' to have per-channel quantization overrides (axis value is missing).",
  93. )
  94. return overrides_list
  95. def get_quant_types(self) -> set[QuantType]:
  96. if self.quant_types is not None:
  97. return self.quant_types
  98. self.quant_types = set()
  99. if self.overrides:
  100. for quant_overrides_list in self.overrides.values():
  101. for quant_overrides in quant_overrides_list:
  102. if "quant_type" in quant_overrides:
  103. self.quant_types.add(quant_overrides["quant_type"])
  104. if "convert" in quant_overrides and "quant_type" in quant_overrides["convert"]:
  105. self.quant_types.add(quant_overrides["convert"]["quant_type"])
  106. return self.quant_types
  107. def _is_valid_per_tensor(
  108. self,
  109. initializers,
  110. default_activation_qtype,
  111. tensor_name: str,
  112. quant_overrides: dict[str, Any],
  113. ) -> tuple[bool, str | None]:
  114. if not isinstance(quant_overrides, dict):
  115. return (
  116. False,
  117. f"Tensor quantization overrides for '{tensor_name}' are not in a dict",
  118. )
  119. is_initializer = tensor_name in initializers
  120. quant_type = quant_overrides.get("quant_type")
  121. if quant_type:
  122. self.quant_types.add(quant_type)
  123. has_scale = "scale" in quant_overrides
  124. has_zero_point = "zero_point" in quant_overrides
  125. if (has_scale and not has_zero_point) or (has_zero_point and not has_scale):
  126. return (
  127. False,
  128. "Must provide both 'scale' and 'zero_point' if one of the overrides is provided",
  129. )
  130. if has_scale:
  131. keys = self.keys_unsupported_with_scale_zp.intersection(set(quant_overrides))
  132. if keys:
  133. return (
  134. False,
  135. f"Tensor override option(s) [{', '.join(keys)}] are invalid with 'scale' and 'zero_point'",
  136. )
  137. if "reduce_range" in quant_overrides and not is_initializer:
  138. return (
  139. False,
  140. f"Option 'reduce_range' is only supported for initializers, not for activation {tensor_name}",
  141. )
  142. if "convert" in quant_overrides:
  143. if is_initializer:
  144. return False, "Cannot use 'convert' override for initializers"
  145. if "quant_type" not in quant_overrides["convert"]:
  146. return False, f"'convert' options (tensor '{tensor_name}') must specify a 'quant_type'"
  147. if "reduce_range" in quant_overrides["convert"]:
  148. return (
  149. False,
  150. f"Option 'reduce_range' is only supported for initializers, not for activation {tensor_name}",
  151. )
  152. convert_quant_type = quant_overrides["convert"]["quant_type"]
  153. original_quant_type = quant_type if quant_type is not None else default_activation_qtype
  154. if convert_quant_type == original_quant_type:
  155. return (
  156. False,
  157. f"'convert' quant_type must differ from original quant_type (tensor '{tensor_name}')",
  158. )
  159. convert_has_scale = "scale" in quant_overrides["convert"]
  160. convert_has_zero_point = "zero_point" in quant_overrides["convert"]
  161. if (convert_has_scale and not convert_has_zero_point) or (convert_has_zero_point and not convert_has_scale):
  162. return (
  163. False,
  164. f"Must provide both 'scale' and 'zero_point' if one of the overrides is provided (tensor '{tensor_name}')",
  165. )
  166. if convert_has_scale:
  167. keys = self.keys_unsupported_with_scale_zp.intersection(set(quant_overrides["convert"]))
  168. if keys:
  169. return (
  170. False,
  171. f"Tensor override option(s) [{', '.join(keys)}] are invalid with 'scale' and 'zero_point' "
  172. f"(tensor '{tensor_name}')",
  173. )
  174. self.quant_types.add(convert_quant_type)
  175. return True, None
  176. def _is_valid_per_channel(
  177. self,
  178. initializers,
  179. tensor_name: str,
  180. quant_overrides_list: list[dict[str, Any]],
  181. ) -> tuple[bool, str | None]:
  182. is_initializer = tensor_name in initializers
  183. if not is_initializer:
  184. return (
  185. False,
  186. f"Tensor '{tensor_name}' has per-channel overrides, but is not an initializer",
  187. )
  188. axis = quant_overrides_list[0].get("axis")
  189. if axis is None:
  190. return (
  191. False,
  192. f"Per-channel overrides for tensor {tensor_name} is missing an 'axis' value in "
  193. "the first channel dictionary.",
  194. )
  195. weight_shape = list(initializers[tensor_name].dims)
  196. weight_rank = len(weight_shape)
  197. norm_axis = axis
  198. if norm_axis < 0:
  199. norm_axis += weight_rank
  200. if norm_axis < 0 or norm_axis >= len(weight_shape):
  201. return (
  202. False,
  203. f"Axis override value is out-of-bounds for tensor {tensor_name} (rank {len(weight_shape)})",
  204. )
  205. if len(quant_overrides_list) > 1 and len(quant_overrides_list) != weight_shape[norm_axis]:
  206. return (
  207. False,
  208. f"Incorrect number of channel overrides for tensor {tensor_name} (axis {axis}), "
  209. f"expected {weight_shape[axis]}, but found {len(quant_overrides_list)}.",
  210. )
  211. if "convert" in quant_overrides_list[0]:
  212. return False, f"Cannot use 'convert' override for initializers, such as {tensor_name}."
  213. quant_type = quant_overrides_list[0].get("quant_type")
  214. if quant_type:
  215. self.quant_types.add(quant_type)
  216. symmetric = quant_overrides_list[0].get("symmetric")
  217. reduce_range = quant_overrides_list[0].get("reduce_range")
  218. has_scale = "scale" in quant_overrides_list[0]
  219. has_zero_point = "zero_point" in quant_overrides_list[0]
  220. has_scale_zp = has_scale and has_zero_point
  221. if (has_scale and not has_zero_point) or (has_zero_point and not has_scale):
  222. return (
  223. False,
  224. "Must provide both 'scale' and 'zero_point' if one of the overrides is provided",
  225. )
  226. if has_scale_zp:
  227. keys = self.keys_unsupported_with_scale_zp.intersection(set(quant_overrides_list[0]))
  228. if keys:
  229. return (
  230. False,
  231. f"Tensor override option(s) [{', '.join(keys)}] are invalid with 'scale' and 'zero_point'",
  232. )
  233. has_rmin = "rmin" in quant_overrides_list[0]
  234. has_rmax = "rmax" in quant_overrides_list[0]
  235. has_rmin_rmax = has_rmin and has_rmax
  236. if (has_rmin and not has_rmax) or (not has_rmin and has_rmax):
  237. return (
  238. False,
  239. "Must provide both 'rmin' and 'rmax' if one is provided",
  240. )
  241. for index, quant_overrides in enumerate(quant_overrides_list[1:]):
  242. if not isinstance(quant_overrides, dict):
  243. return (
  244. False,
  245. f"Tensor quantization overrides at index {index} for '{tensor_name}' are not in a dict",
  246. )
  247. if "convert" in quant_overrides:
  248. return False, f"Cannot use 'convert' override for initializers, such as {tensor_name}."
  249. # For per-channel quantization, all channels must use the same quantization type, axis, symmetric
  250. # and reduce_range values. And, if specified, they must be present in the first channel dict
  251. # (i.e., quant_overrides_list[0]).
  252. if "quant_type" in quant_overrides and quant_type != quant_overrides["quant_type"]:
  253. return (
  254. False,
  255. "Channel quantization types for tensor '{tensor_name}' do not match at index {index}.",
  256. )
  257. if "axis" in quant_overrides and axis != quant_overrides["axis"] and norm_axis != quant_overrides["axis"]:
  258. return (
  259. False,
  260. "Channel axis for tensor '{tensor_name}' does not match at index {index}.",
  261. )
  262. if "symmetric" in quant_overrides and symmetric != quant_overrides["symmetric"]:
  263. return (
  264. False,
  265. "Channel symmetric value for tensor '{tensor_name}' does not match at index {index}.",
  266. )
  267. if "reduce_range" in quant_overrides and reduce_range != quant_overrides["reduce_range"]:
  268. return (
  269. False,
  270. "Channel reduce_range value for tensor '{tensor_name}' does not match at index {index}.",
  271. )
  272. # If override scale/zp, must do so for all channels.
  273. chan_has_scale_zp = "scale" in quant_overrides and "zero_point" in quant_overrides
  274. if has_scale_zp and not chan_has_scale_zp:
  275. return (
  276. False,
  277. "Per-channel overrides that specify scale/zero_point must do so for all channels, "
  278. f"but tensor '{tensor_name}' is missing them at index {index}.",
  279. )
  280. if chan_has_scale_zp:
  281. keys = self.keys_unsupported_with_scale_zp.intersection(set(quant_overrides))
  282. if keys:
  283. return (
  284. False,
  285. f"Tensor override option(s) [{', '.join(keys)}] are invalid with 'scale' and 'zero_point'",
  286. )
  287. # If override rmin/rmax, must do so for all channels.
  288. chan_has_rmin_rmax = "rmin" in quant_overrides and "rmax" in quant_overrides
  289. if has_rmin_rmax and not chan_has_rmin_rmax:
  290. return (
  291. False,
  292. "Per-channel overrides that specify rmin/rmax must do so for all channels, "
  293. f"but tensor '{tensor_name}' is missing them at index {index}.",
  294. )
  295. return True, None
  296. def is_valid(
  297. self,
  298. initializers: dict[str, onnx.TensorProto],
  299. activation_names: set[str],
  300. default_activation_qtype,
  301. ) -> tuple[bool, str | None]:
  302. self.quant_types = set()
  303. # Validate that compatible/valid overrides are provided.
  304. if self.overrides:
  305. for tensor_name, quant_overrides_list in self.overrides.items():
  306. if tensor_name not in initializers and tensor_name not in activation_names:
  307. return False, f"Tensor '{tensor_name}' in TensorQuantOverrides is not present in the model"
  308. if not isinstance(quant_overrides_list, list):
  309. return False, f"Tensor quantization overrides for '{tensor_name}' are not in a list"
  310. if not quant_overrides_list:
  311. continue
  312. if not isinstance(quant_overrides_list[0], dict):
  313. return False, f"Tensor quantization overrides at index 0 for '{tensor_name}' are not in a dict"
  314. if not quant_overrides_list[0]:
  315. continue
  316. axis = quant_overrides_list[0].get("axis")
  317. is_per_channel = len(quant_overrides_list) > 1 or axis is not None
  318. if is_per_channel:
  319. return self._is_valid_per_channel(initializers, tensor_name, quant_overrides_list)
  320. return self._is_valid_per_tensor(
  321. initializers, default_activation_qtype, tensor_name, quant_overrides_list[0]
  322. )
  323. return True, None
  324. def update_tensor_overrides(
  325. self,
  326. tensor_name: str,
  327. new_vals: dict[str, Any],
  328. channels: list[int] | None = None,
  329. overwrite: bool = True,
  330. ) -> bool:
  331. if not new_vals:
  332. return False
  333. channels = set(channels) if channels is not None else None
  334. have_overrides = self.overrides.get(tensor_name)
  335. # If `overwrite` is False, check if we would overwrite anything.
  336. do_update = True
  337. if not overwrite and have_overrides:
  338. for channel, overrides in enumerate(self.overrides[tensor_name]):
  339. if channels is not None and channel not in channels:
  340. continue
  341. if set(new_vals).intersection(set(overrides)):
  342. do_update = False
  343. break
  344. # Do the update if `overwrite` is True or if nothing is overwritten (do not want partial overwrites).
  345. if do_update:
  346. if not have_overrides:
  347. self.overrides[tensor_name] = [{}]
  348. for channel, overrides in enumerate(self.overrides[tensor_name]):
  349. if channels is not None and channel not in channels:
  350. continue
  351. overrides.update(new_vals)
  352. return do_update
  353. def get_node_output_qtype_info(
  354. self,
  355. output_name: str,
  356. default_qtype: QuantType | None,
  357. default_symmetric: bool | None = None,
  358. ) -> QuantTypeInfo:
  359. # Outputs are activations, which do not support 'reduce_range' or 'axis'
  360. if output_name not in self.overrides:
  361. return QuantTypeInfo(default_qtype, default_symmetric)
  362. tensor_overrides = self.overrides[output_name][0]
  363. return QuantTypeInfo(
  364. tensor_overrides.get("quant_type", default_qtype),
  365. tensor_overrides.get("symmetric", default_symmetric),
  366. )
  367. def get_node_input_qtype_info(
  368. self,
  369. input_name: str,
  370. node_name: str,
  371. default_qtype: QuantType | None,
  372. default_symmetric: bool | None = None,
  373. default_reduce_range: bool | None = None,
  374. ) -> QuantTypeInfo:
  375. if input_name not in self.overrides or not self.overrides[input_name]:
  376. return QuantTypeInfo(default_qtype, default_symmetric, default_reduce_range)
  377. # Get the first overrides dict in the list. This works for both per-tensor and per-channel
  378. # quantization because all channels must use the same quant type.
  379. tensor_overrides = self.overrides[input_name][0]
  380. producer_type = tensor_overrides.get("quant_type", default_qtype)
  381. if "convert" not in tensor_overrides:
  382. return QuantTypeInfo(
  383. producer_type,
  384. tensor_overrides.get("symmetric", default_symmetric),
  385. tensor_overrides.get("reduce_range", default_reduce_range),
  386. tensor_overrides.get("axis"),
  387. )
  388. # This tensor is converted. Check if the node gets the original qtype or the converted qtype.
  389. convert_dict = tensor_overrides["convert"]
  390. qtype_info = QuantTypeInfo(
  391. producer_type,
  392. convert_dict.get("symmetric", default_symmetric),
  393. # Converted tensors are not initializers, so do not have 'axis' or 'reduce_range'.
  394. )
  395. # Check if all nodes receive the converted type (i.e., recv_nodes is None) or this node
  396. # is in the list of consumers (recv_nodes).
  397. if ("recv_nodes" not in convert_dict) or (node_name in convert_dict["recv_nodes"]):
  398. qtype_info.quant_type = convert_dict["quant_type"]
  399. return qtype_info
  400. def pprint_str(self, indent=None) -> str:
  401. return json.dumps(self.overrides, default=str, indent=indent)
  402. def empty(self) -> bool:
  403. return not self.overrides
  404. def get_dict(self) -> dict[str, list[dict[str, Any]]]:
  405. return self.overrides
  406. # Required implementations of abstract methods in collections.abc.MutableMapping
  407. # so that this class can be used like a dict.
  408. def __setitem__(self, key: str, value: list[dict]):
  409. self.overrides[key] = value
  410. def __getitem__(self, key: str) -> list[dict]:
  411. return self.overrides[key]
  412. def __delitem__(self, key: str):
  413. del self.overrides[key]
  414. def __iter__(self):
  415. return iter(self.overrides)
  416. def __len__(self):
  417. return len(self.overrides)
  418. def __str__(self) -> str:
  419. return str(self.overrides)
  420. def __repr__(self) -> str:
  421. return f"{super().__repr__()}, TensorQuantOverridesHelper({self.overrides})"