quantization_mappings.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. import copy
  2. from collections.abc import Callable
  3. from typing import Any
  4. import torch
  5. import torch.ao.nn as ao_nn
  6. import torch.ao.nn.intrinsic as nni
  7. import torch.ao.nn.intrinsic.qat as nniqat
  8. import torch.ao.nn.intrinsic.quantized as nniq
  9. import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
  10. import torch.ao.nn.qat as nnqat
  11. import torch.ao.nn.qat.dynamic as nnqatd
  12. import torch.ao.nn.quantized as nnq
  13. import torch.ao.nn.quantized.dynamic as nnqd
  14. import torch.ao.nn.quantized.reference as nnqr
  15. # Because `torch.ao.nn` uses lazy imports, we need to make
  16. # sure we import the contents explicitly here.
  17. import torch.ao.nn.sparse
  18. import torch.nn.functional as F
  19. from torch import nn
  20. from torch.ao.quantization.fake_quantize import (
  21. default_fixed_qparams_range_0to1_fake_quant,
  22. default_fixed_qparams_range_neg1to1_fake_quant,
  23. )
  24. from torch.ao.quantization.stubs import DeQuantStub, QuantStub
  25. from torch.ao.quantization.utils import get_combined_dict
  26. from torch.nn.utils.parametrize import type_before_parametrizations
# Public names exported by this module: the default mapping tables first,
# then the accessor/helper functions defined below.
__all__ = [
    "DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS",
    "DEFAULT_STATIC_QUANT_MODULE_MAPPINGS",
    "DEFAULT_QAT_MODULE_MAPPINGS",
    "DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS",
    "DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS",
    "DEFAULT_MODULE_TO_ACT_POST_PROCESS",
    "DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS",
    "DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS",
    "no_observer_set",
    "get_default_static_quant_module_mappings",
    "get_default_static_quant_reference_module_mappings",
    "get_embedding_static_quant_module_mappings",
    "get_default_static_sparse_quant_module_mappings",
    "get_static_quant_module_class",
    "get_dynamic_quant_module_class",
    "get_default_qat_module_mappings",
    "get_embedding_qat_module_mappings",
    "get_default_dynamic_quant_module_mappings",
    "get_default_dynamic_sparse_quant_module_mappings",
    "get_default_qconfig_propagation_list",
    "get_default_compare_output_module_list",
    "get_default_float_to_quantized_operator_mappings",
    "get_quantized_operator",
]
# Default map for swapping float modules to reference quantized modules.
# Keys are float module classes (plus the Quant/DeQuant stubs); values are the
# classes from `torch.ao.nn.quantized.reference` (nnqr), except for the stubs,
# which map to the regular `nnq.Quantize` / `nnq.DeQuantize`.
DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS: dict[Callable, Any] = {
    QuantStub: nnq.Quantize,
    DeQuantStub: nnq.DeQuantize,
    nn.Linear: nnqr.Linear,
    nn.Conv1d: nnqr.Conv1d,
    nn.Conv2d: nnqr.Conv2d,
    nn.Conv3d: nnqr.Conv3d,
    nn.ConvTranspose1d: nnqr.ConvTranspose1d,
    nn.ConvTranspose2d: nnqr.ConvTranspose2d,
    nn.ConvTranspose3d: nnqr.ConvTranspose3d,
    nn.Embedding: nnqr.Embedding,
    nn.EmbeddingBag: nnqr.EmbeddingBag,
    nn.GRUCell: nnqr.GRUCell,
    nn.LSTMCell: nnqr.LSTMCell,
    nn.RNNCell: nnqr.RNNCell,
    nn.LSTM: nnqr.LSTM,
}
# Default map for swapping float modules to statically quantized ones.
# Keys are plain float modules, fused intrinsic modules (nni), and QAT
# modules (nniqat / nnqat); values are classes from `torch.ao.nn.quantized`
# (nnq) or `torch.ao.nn.intrinsic.quantized` (nniq).
DEFAULT_STATIC_QUANT_MODULE_MAPPINGS: dict[Callable, Any] = {
    QuantStub: nnq.Quantize,
    DeQuantStub: nnq.DeQuantize,
    nn.BatchNorm2d: nnq.BatchNorm2d,
    nn.BatchNorm3d: nnq.BatchNorm3d,
    nn.Dropout: nnq.Dropout,
    nn.Conv1d: nnq.Conv1d,
    nn.Conv2d: nnq.Conv2d,
    nn.Conv3d: nnq.Conv3d,
    nn.ConvTranspose1d: nnq.ConvTranspose1d,
    nn.ConvTranspose2d: nnq.ConvTranspose2d,
    nn.ConvTranspose3d: nnq.ConvTranspose3d,
    nn.ELU: nnq.ELU,
    nn.Embedding: nnq.Embedding,
    nn.EmbeddingBag: nnq.EmbeddingBag,
    nn.GroupNorm: nnq.GroupNorm,
    nn.Hardswish: nnq.Hardswish,
    nn.InstanceNorm1d: nnq.InstanceNorm1d,
    nn.InstanceNorm2d: nnq.InstanceNorm2d,
    nn.InstanceNorm3d: nnq.InstanceNorm3d,
    nn.LayerNorm: nnq.LayerNorm,
    nn.LeakyReLU: nnq.LeakyReLU,
    # Both Linear variants are swapped to the same quantized Linear.
    nn.modules.linear.NonDynamicallyQuantizableLinear: nnq.Linear,
    nn.Linear: nnq.Linear,
    nn.ReLU6: nnq.ReLU6,
    nn.PReLU: nnq.PReLU,
    # Wrapper Modules:
    nnq.FloatFunctional: nnq.QFunctional,
    # Intrinsic modules:
    nni.BNReLU2d: nniq.BNReLU2d,
    nni.BNReLU3d: nniq.BNReLU3d,
    nni.ConvReLU1d: nniq.ConvReLU1d,
    nni.ConvReLU2d: nniq.ConvReLU2d,
    nni.ConvReLU3d: nniq.ConvReLU3d,
    nni.ConvAdd2d: nniq.ConvAdd2d,
    nni.ConvAddReLU2d: nniq.ConvAddReLU2d,
    nni.LinearReLU: nniq.LinearReLU,
    nni.LinearLeakyReLU: nniq.LinearLeakyReLU,
    nni.LinearTanh: nniq.LinearTanh,
    # Intrinsic QAT modules: the Bn variants map to quantized classes with no
    # Bn in their name (e.g. ConvBn2d -> Conv2d).
    nniqat.ConvBn1d: nnq.Conv1d,
    nniqat.ConvBn2d: nnq.Conv2d,
    nniqat.ConvBn3d: nnq.Conv3d,
    nniqat.ConvBnReLU1d: nniq.ConvReLU1d,
    nniqat.ConvBnReLU2d: nniq.ConvReLU2d,
    nniqat.ConvBnReLU3d: nniq.ConvReLU3d,
    nniqat.ConvReLU2d: nniq.ConvReLU2d,
    nniqat.ConvReLU3d: nniq.ConvReLU3d,
    nniqat.LinearReLU: nniq.LinearReLU,
    nniqat.LinearBn1d: nnq.Linear,
    # QAT modules:
    nnqat.Linear: nnq.Linear,
    nnqat.Conv2d: nnq.Conv2d,
    nnqat.Conv3d: nnq.Conv3d,
}
# Default map for swapping float modules to QAT modules.
# Plain float modules map to `torch.ao.nn.qat` (nnqat) classes; fused
# intrinsic modules (nni) map to their `torch.ao.nn.intrinsic.qat` (nniqat)
# counterparts of the same name.
DEFAULT_QAT_MODULE_MAPPINGS: dict[Callable, Any] = {
    nn.Conv2d: nnqat.Conv2d,
    nn.Conv3d: nnqat.Conv3d,
    nn.Linear: nnqat.Linear,
    nn.modules.linear.NonDynamicallyQuantizableLinear: nnqat.Linear,
    # Intrinsic modules:
    nni.ConvBn1d: nniqat.ConvBn1d,
    nni.ConvBn2d: nniqat.ConvBn2d,
    nni.ConvBn3d: nniqat.ConvBn3d,
    nni.ConvBnReLU1d: nniqat.ConvBnReLU1d,
    nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
    nni.ConvBnReLU3d: nniqat.ConvBnReLU3d,
    nni.ConvReLU2d: nniqat.ConvReLU2d,
    nni.ConvReLU3d: nniqat.ConvReLU3d,
    nni.LinearReLU: nniqat.LinearReLU,
    nni.LinearBn1d: nniqat.LinearBn1d,
}
# Default map for swapping float modules to dynamically quantized modules.
# Most values come from `torch.ao.nn.quantized.dynamic` (nnqd); the two
# embedding entries map to the `torch.ao.nn.quantized` (nnq) classes instead.
DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS: dict[Callable, Any] = {
    nn.GRUCell: nnqd.GRUCell,
    nn.Linear: nnqd.Linear,
    nnqatd.Linear: nnqd.Linear,
    nn.modules.linear.NonDynamicallyQuantizableLinear: nnqd.Linear,
    nn.LSTM: nnqd.LSTM,
    nn.GRU: nnqd.GRU,
    nn.LSTMCell: nnqd.LSTMCell,
    nn.RNNCell: nnqd.RNNCell,
    nni.LinearReLU: nniqd.LinearReLU,
    nn.EmbeddingBag: nnq.EmbeddingBag,
    nn.Embedding: nnq.Embedding,
    # Don't want to enable these by default because the numerical
    # accuracy is poor compared to other dynamic ops
    # nn.Conv1d: nnqd.Conv1d,
    # nn.Conv2d: nnqd.Conv2d,
    # nn.Conv3d: nnqd.Conv3d,
    # nn.ConvTranspose1d: nnqd.ConvTranspose1d,
    # nn.ConvTranspose2d: nnqd.ConvTranspose2d,
    # nn.ConvTranspose3d: nnqd.ConvTranspose3d,
}
# Allowlist for propagating the qconfig: extra module classes that are
# unioned into the sets built by get_default_qconfig_propagation_list() and
# get_default_compare_output_module_list() even though they are not keys of
# any mapping table above.
_INCLUDE_QCONFIG_PROPAGATE_LIST: set[Callable] = {
    nn.Sequential,
}
# Default mapping from floating point functions or torch ops to quantized ops.
# Keys here are `torch.nn.functional` callables; values are the matching
# operators under `torch.ops.quantized`.
# TODO: merge with default static mapping
DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS: dict[Callable | str, Callable] = {
    F.elu: torch.ops.quantized.elu,
    F.hardswish: torch.ops.quantized.hardswish,
    F.instance_norm: torch.ops.quantized.instance_norm,
    F.layer_norm: torch.ops.quantized.layer_norm,
    F.leaky_relu: torch.ops.quantized.leaky_relu,
    F.dropout: torch.ops.quantized.dropout,
}
# Mapping from module class to the output activation post process to use for
# it. The values are the fixed-qparams fake-quant objects imported from
# `torch.ao.quantization.fake_quantize`: the 0-to-1 range variant for
# Hardsigmoid/Sigmoid/Softmax and the -1-to-1 range variant for Tanh.
DEFAULT_MODULE_TO_ACT_POST_PROCESS: dict[Callable, Callable] = {
    nn.Hardsigmoid: default_fixed_qparams_range_0to1_fake_quant,
    nn.Sigmoid: default_fixed_qparams_range_0to1_fake_quant,
    nn.Softmax: default_fixed_qparams_range_0to1_fake_quant,
    nn.Tanh: default_fixed_qparams_range_neg1to1_fake_quant,
}
# Default map for swapping float modules to static sparse quantized ones.
# Only nn.Linear has a sparse quantized counterpart in this table.
DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS: dict[Callable, Any] = {
    nn.Linear: ao_nn.sparse.quantized.Linear
}
# Default map for swapping float modules to dynamic sparse quantized ones.
# Only nn.Linear has a dynamic sparse quantized counterpart in this table.
DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS: dict[Callable, Any] = {
    nn.Linear: ao_nn.sparse.quantized.dynamic.Linear
}
  194. def no_observer_set() -> set[Any]:
  195. r"""These modules cannot have observers inserted by default."""
  196. no_observers = {nn.quantizable.LSTM, nn.quantizable.MultiheadAttention}
  197. return no_observers
  198. def get_default_static_quant_module_mappings() -> dict[Callable, Any]:
  199. """Get module mapping for post training static quantization"""
  200. return copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS)
  201. def get_default_static_quant_reference_module_mappings() -> dict[Callable, Any]:
  202. """Get reference module mapping for post training static quantization"""
  203. return copy.deepcopy(DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS)
  204. def get_embedding_static_quant_module_mappings() -> dict[Callable, Any]:
  205. """Get module mapping, including mapping for embedding QAT"""
  206. mapping = copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS)
  207. mapping[nnqat.EmbeddingBag] = nnq.EmbeddingBag
  208. mapping[nnqat.Embedding] = nnq.Embedding
  209. return mapping
  210. def get_default_static_sparse_quant_module_mappings() -> dict[Callable, Any]:
  211. """Get module mapping for post training static sparse quantization"""
  212. return copy.deepcopy(DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS)
  213. def get_static_quant_module_class(
  214. float_module_class: Callable,
  215. additional_static_quant_mapping: dict[Callable, Any] | None = None,
  216. is_reference: bool = False,
  217. ) -> Any:
  218. r"""n Get the statically quantized module class corresponding to
  219. the floating point module class
  220. """
  221. if additional_static_quant_mapping is None:
  222. additional_static_quant_mapping = {}
  223. all_mappings = get_combined_dict(
  224. DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS
  225. if is_reference
  226. else DEFAULT_STATIC_QUANT_MODULE_MAPPINGS,
  227. additional_static_quant_mapping,
  228. )
  229. static_quant_module_class = all_mappings.get(float_module_class, None)
  230. if static_quant_module_class is None:
  231. raise AssertionError(
  232. f"Floating point module class {str(float_module_class)}"
  233. + " does not have a corresponding quantized module class"
  234. )
  235. return copy.deepcopy(static_quant_module_class)
  236. def get_dynamic_quant_module_class(
  237. float_module_class: Callable,
  238. additional_dynamic_quant_mapping: dict[Callable, Any] | None = None,
  239. ) -> Any:
  240. r"""n Get the dynamically quantized module class corresponding to
  241. the floating point module class
  242. """
  243. if additional_dynamic_quant_mapping is None:
  244. additional_dynamic_quant_mapping = {}
  245. all_mappings = get_combined_dict(
  246. DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, additional_dynamic_quant_mapping
  247. )
  248. dynamic_quant_module_class = all_mappings.get(float_module_class, None)
  249. if dynamic_quant_module_class is None:
  250. raise AssertionError(
  251. f"Floating point module class {str(float_module_class)}"
  252. + " does not have a corresponding quantized module class"
  253. )
  254. return copy.deepcopy(dynamic_quant_module_class)
  255. def get_default_qat_module_mappings() -> dict[Callable, Any]:
  256. """Get default module mapping for quantization aware training"""
  257. return copy.deepcopy(DEFAULT_QAT_MODULE_MAPPINGS)
  258. def get_embedding_qat_module_mappings() -> dict[Callable, Any]:
  259. """Get module mapping for quantization aware training
  260. This is includes default values in addition to
  261. enabling qat for embeddings.
  262. """
  263. mapping = copy.deepcopy(DEFAULT_QAT_MODULE_MAPPINGS)
  264. mapping[nn.EmbeddingBag] = nnqat.EmbeddingBag
  265. mapping[nn.Embedding] = nnqat.Embedding
  266. return mapping
  267. def get_default_dynamic_quant_module_mappings() -> dict[Callable, Any]:
  268. """Get module mapping for post training dynamic quantization"""
  269. return DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS
  270. def get_default_dynamic_sparse_quant_module_mappings() -> dict[Callable, Any]:
  271. """Get module mapping for post training dynamic sparse quantization"""
  272. return DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS
  273. def get_default_qconfig_propagation_list() -> set[Callable]:
  274. """Get the default list of module types that we'll attach qconfig
  275. attribute to in prepare
  276. """
  277. QCONFIG_PROPAGATE_MODULE_CLASS_LIST = (
  278. set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys())
  279. | set(DEFAULT_QAT_MODULE_MAPPINGS.keys())
  280. | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys())
  281. | _INCLUDE_QCONFIG_PROPAGATE_LIST
  282. )
  283. return copy.deepcopy(QCONFIG_PROPAGATE_MODULE_CLASS_LIST)
  284. def get_default_compare_output_module_list() -> set[Callable]:
  285. """Get list of module class types that we will record output
  286. in numeric suite
  287. """
  288. NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST = (
  289. set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.values())
  290. | set(DEFAULT_QAT_MODULE_MAPPINGS.values())
  291. | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.values())
  292. | set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys())
  293. | set(DEFAULT_QAT_MODULE_MAPPINGS.keys())
  294. | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys())
  295. | _INCLUDE_QCONFIG_PROPAGATE_LIST
  296. )
  297. return copy.deepcopy(NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST)
  298. def get_default_float_to_quantized_operator_mappings() -> dict[
  299. Callable | str, Callable
  300. ]:
  301. return copy.deepcopy(DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS)
  302. # TODO: merge with get_static_quant_module_class
  303. def get_quantized_operator(float_op: Callable | str) -> Callable:
  304. """Get the quantized operator corresponding to the float operator"""
  305. quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op)
  306. if quantized_op is None:
  307. raise AssertionError(
  308. f"Operator {str(float_op)} does not have corresponding quantized op"
  309. )
  310. return quantized_op
  311. def _get_special_act_post_process(module: torch.nn.Module) -> Callable | None:
  312. r"""Get the special activation post process for `module`, this has
  313. higher priority than the activation post process in `qconfig`
  314. e.g.
  315. input: torch.nn.Sigmoid
  316. output: default_affine_fixed_qparam_fake_quant
  317. """
  318. return DEFAULT_MODULE_TO_ACT_POST_PROCESS.get(type_before_parametrizations(module))
  319. def _has_special_act_post_process(module: torch.nn.Module) -> bool:
  320. return module.training and type(module) in DEFAULT_MODULE_TO_ACT_POST_PROCESS