quantize.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829
  1. # mypy: allow-untyped-defs
  2. import copy
  3. import inspect
  4. import itertools
  5. import typing_extensions
  6. import warnings
  7. import torch
  8. import torch.ao.nn.quantized as nnq
  9. import torch.nn as nn
  10. from torch.ao.nn.intrinsic import _FusedModule
  11. from torch.ao.quantization.observer import _is_activation_post_process
  12. from torch.ao.quantization.qconfig import (
  13. _activation_is_memoryless,
  14. _add_module_to_qconfig_obs_ctr,
  15. default_dynamic_qconfig,
  16. float16_dynamic_qconfig,
  17. float_qparams_weight_only_qconfig,
  18. float_qparams_weight_only_qconfig_4bit,
  19. )
  20. from torch.ao.quantization.quantization_mappings import (
  21. _get_special_act_post_process,
  22. _has_special_act_post_process,
  23. get_default_dynamic_quant_module_mappings,
  24. get_default_qat_module_mappings,
  25. get_default_qconfig_propagation_list,
  26. get_default_static_quant_module_mappings,
  27. get_default_static_quant_reference_module_mappings,
  28. no_observer_set,
  29. )
  30. from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper
  31. from torch.nn.utils.parametrize import type_before_parametrizations
  32. from .utils import (
  33. DEPRECATION_WARNING,
  34. get_qparam_dict,
  35. has_no_children_ignoring_parametrizations,
  36. )
  37. __all__ = [
  38. "get_default_custom_config_dict",
  39. "propagate_qconfig_",
  40. "add_quant_dequant",
  41. "prepare",
  42. "quantize",
  43. "quantize_dynamic",
  44. "prepare_qat",
  45. "quantize_qat",
  46. "convert",
  47. "swap_module",
  48. ]
  49. # TODO remove this once BC is no longer required to avoid a SEV
  50. is_activation_post_process = _is_activation_post_process
  51. _DEFAULT_CUSTOM_CONFIG_DICT = {
  52. "float_to_observed_custom_module_class": {
  53. nn.LSTM: nn.quantizable.LSTM,
  54. nn.MultiheadAttention: nn.quantizable.MultiheadAttention,
  55. },
  56. "observed_to_quantized_custom_module_class": {
  57. nn.quantizable.LSTM: nn.quantized.LSTM,
  58. nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention,
  59. },
  60. }
  61. def get_default_custom_config_dict():
  62. r"""Defines the default custom config dict."""
  63. return _DEFAULT_CUSTOM_CONFIG_DICT
  64. def _propagate_qconfig_helper(
  65. module,
  66. qconfig_dict,
  67. qconfig_parent=None,
  68. prefix="",
  69. prepare_custom_config_dict=None,
  70. ):
  71. r"""This is a helper function for `propagate_qconfig_`
  72. Args:
  73. module: input module
  74. qconfig_dict: dictionary that maps from name of submodule to quantization
  75. configuration
  76. qconfig_parent: quantization config of parent module, we will fallback to
  77. this config when there is no specified config for current
  78. module
  79. prefix: corresponding prefix of the current module, used as key in
  80. qconfig_dict
  81. prepare_custom_config_dict: dictionary for custom handling of modules
  82. see docs for :func:`~torch.ao.quantization.prepare_fx`
  83. Return:
  84. None, module is modified inplace with qconfig attached
  85. """
  86. module_qconfig = qconfig_dict.get(
  87. type_before_parametrizations(module), qconfig_parent
  88. )
  89. module_qconfig = qconfig_dict.get(prefix, module_qconfig)
  90. module_qconfig = getattr(module, "qconfig", module_qconfig)
  91. torch.ao.quantization.qconfig._assert_valid_qconfig(module_qconfig, module)
  92. qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(module_qconfig, module)
  93. module.qconfig = qconfig_with_device_check
  94. for name, child in module.named_children():
  95. module_prefix = prefix + "." + name if prefix else name
  96. # do no not propagate qconfig to child if child is non traceable
  97. if prepare_custom_config_dict is None or not (
  98. name in prepare_custom_config_dict.get("non_traceable_module_name", [])
  99. or type(child)
  100. in prepare_custom_config_dict.get("non_traceable_module_class", [])
  101. ):
  102. _propagate_qconfig_helper(
  103. child, qconfig_dict, qconfig_with_device_check, module_prefix
  104. )
  105. def propagate_qconfig_(module, qconfig_dict=None, prepare_custom_config_dict=None):
  106. r"""Propagate qconfig through the module hierarchy and assign `qconfig`
  107. attribute on each leaf module
  108. Args:
  109. module: input module
  110. qconfig_dict: dictionary that maps from name or type of submodule to
  111. quantization configuration, qconfig applies to all submodules of a
  112. given module unless qconfig for the submodules are specified (when
  113. the submodule already has qconfig attribute)
  114. prepare_custom_config_dict: dictionary for custom handling of modules
  115. see docs for :func:`~torch.ao.quantization.prepare_fx`
  116. Return:
  117. None, module is modified inplace with qconfig attached
  118. """
  119. if qconfig_dict is None:
  120. qconfig_dict = {}
  121. if prepare_custom_config_dict is None:
  122. prepare_custom_config_dict = {}
  123. _propagate_qconfig_helper(
  124. module, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict
  125. )
  126. def _observer_forward_hook(self, input, output):
  127. r"""Forward hook that calls observer on the output"""
  128. return self.activation_post_process(output)
  129. def _observer_forward_pre_hook(self, input):
  130. r"""Forward pre hook that calls observer on the output"""
  131. return self.activation_post_process(input[0])
  132. def _register_activation_post_process_hook(module, pre_hook=False):
  133. if not hasattr(module, "activation_post_process"):
  134. raise AssertionError(
  135. "Expect activation_post_process attribute already attached to the module"
  136. )
  137. if pre_hook:
  138. module.register_forward_pre_hook(_observer_forward_pre_hook, prepend=True)
  139. else:
  140. module.register_forward_hook(_observer_forward_hook, prepend=True)
  141. def _add_observer_(
  142. module,
  143. qconfig_propagation_list=None,
  144. non_leaf_module_list=None,
  145. device=None,
  146. custom_module_class_mapping=None,
  147. ):
  148. r"""Add observer for the leaf child of the module.
  149. This function insert observer module to all leaf child module that
  150. has a valid qconfig attribute.
  151. Args:
  152. module: input module with qconfig attributes for all the leaf modules that we want to quantize
  153. qconfig_propagation_list: a list of quantizable modules that will have observers added to them
  154. if they are leaf nodes
  155. device: parent device, if any
  156. non_leaf_module_list: list of non-leaf modules we want to add observer
  157. Return:
  158. None, module is modified inplace with added observer modules and forward_hooks
  159. """
  160. if qconfig_propagation_list is None:
  161. qconfig_propagation_list = get_default_qconfig_propagation_list()
  162. if custom_module_class_mapping is None:
  163. custom_module_class_mapping = {}
  164. # respect device affinity when adding observers
  165. if device is None:
  166. devices = _get_unique_devices_(module)
  167. if len(devices) > 1:
  168. raise AssertionError(
  169. f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}"
  170. )
  171. device = next(iter(devices)) if len(devices) > 0 else None
  172. def get_activation_post_process(qconfig, device, special_act_post_process=None):
  173. activation = (
  174. qconfig.activation()
  175. if special_act_post_process is None
  176. else special_act_post_process()
  177. )
  178. if device is not None:
  179. activation.to(device)
  180. return activation
  181. def needs_observation(m):
  182. return hasattr(m, "qconfig") and m.qconfig is not None
  183. def insert_activation_post_process(m, special_act_post_process=None):
  184. """Adds an activation post process module and register
  185. a pre or post hook that calls the module
  186. """
  187. # We don't insert observer/fake_quantize for DeQuantStub
  188. if needs_observation(m) and not isinstance(m, DeQuantStub):
  189. # observer and hook will be gone after we swap the module
  190. m.add_module(
  191. "activation_post_process",
  192. get_activation_post_process(
  193. m.qconfig, device, special_act_post_process
  194. ),
  195. )
  196. # Register observer as the first entry in the hook list
  197. # All post forward hooks are preserved and will be executed after the observer before convert
  198. _register_activation_post_process_hook(
  199. m, pre_hook=_activation_is_memoryless(m.qconfig)
  200. )
  201. for name, child in module.named_children():
  202. # TODO remove Dropout special after codebase stable
  203. if type_before_parametrizations(child) is nn.Dropout:
  204. continue
  205. elif issubclass(
  206. type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional)
  207. ):
  208. if needs_observation(child):
  209. if not hasattr(child, "activation_post_process"):
  210. raise AssertionError(
  211. f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`"
  212. )
  213. child.activation_post_process = get_activation_post_process(
  214. child.qconfig, device
  215. )
  216. elif isinstance(child, _FusedModule):
  217. # activation_post_process are now added directly to nn.Sequential/_FusedModule
  218. if needs_observation(child):
  219. insert_activation_post_process(child)
  220. elif (
  221. non_leaf_module_list is not None
  222. and type_before_parametrizations(child) in non_leaf_module_list
  223. ):
  224. if needs_observation(child):
  225. insert_activation_post_process(child)
  226. elif _has_special_act_post_process(child):
  227. special_act_post_process = _get_special_act_post_process(child)
  228. insert_activation_post_process(child, special_act_post_process)
  229. elif (
  230. needs_observation(child)
  231. and type_before_parametrizations(child) in custom_module_class_mapping
  232. ):
  233. observed_class = custom_module_class_mapping[
  234. type_before_parametrizations(child)
  235. ]
  236. observed_child = observed_class.from_float(child)
  237. setattr(module, name, observed_child)
  238. # TODO: These are the modules that cannot be observed
  239. # Once there are more, we should move them to a separate list
  240. if not issubclass(observed_class, tuple(no_observer_set())):
  241. insert_activation_post_process(observed_child)
  242. else:
  243. _add_observer_(
  244. child,
  245. qconfig_propagation_list,
  246. non_leaf_module_list,
  247. device,
  248. custom_module_class_mapping,
  249. )
  250. # Insert observers only for leaf nodes, note that this observer is for
  251. # the output of the module, for input QuantStub will observe them
  252. if (
  253. has_no_children_ignoring_parametrizations(module)
  254. and not isinstance(module, torch.nn.Sequential)
  255. and type_before_parametrizations(module) in qconfig_propagation_list
  256. ):
  257. insert_activation_post_process(module)
  258. # This is a special case for AdaRound eager mode
  259. # AdaRound contains weight_fake_quant to be propagated from API to convert
  260. # leaf node check with a number of children looks naive assumption that blocks
  261. # Adding an exception case for AdaRound
  262. if (
  263. hasattr(module, "weight_fake_quant")
  264. and not isinstance(module, torch.nn.Sequential)
  265. and type_before_parametrizations(module) in qconfig_propagation_list
  266. ):
  267. insert_activation_post_process(module)
  268. def _get_unique_devices_(module):
  269. return {p.device for p in module.parameters() if p.device.type != "meta"} | {
  270. p.device for p in module.buffers() if p.device.type != "meta"
  271. }
  272. def add_quant_dequant(module):
  273. r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig
  274. Note that this function will modify the children of module inplace and it
  275. can return a new module which wraps the input module as well.
  276. Args:
  277. module: input module with qconfig attributes for all the leaf modules
  278. that we want to quantize
  279. Return:
  280. Either the inplace modified module with submodules wrapped in
  281. `QuantWrapper` based on qconfig or a new `QuantWrapper` module which
  282. wraps the input module, the latter case only happens when the input
  283. module is a leaf module and we want to quantize it.
  284. """
  285. if (
  286. has_no_children_ignoring_parametrizations(module)
  287. and hasattr(module, "qconfig")
  288. and module.qconfig
  289. ):
  290. return QuantWrapper(module)
  291. for name, child in module.named_children():
  292. module._modules[name] = add_quant_dequant(child)
  293. return module
  294. @typing_extensions.deprecated(DEPRECATION_WARNING)
  295. def prepare(
  296. model,
  297. inplace=False,
  298. allow_list=None,
  299. observer_non_leaf_module_list=None,
  300. prepare_custom_config_dict=None,
  301. ):
  302. r"""Prepares a copy of the model for quantization calibration or quantization-aware training.
  303. Quantization configuration should be assigned preemptively
  304. to individual submodules in `.qconfig` attribute.
  305. The model will be attached with observer or fake quant modules, and qconfig
  306. will be propagated.
  307. Args:
  308. `model`: input model to be modified in-place
  309. `inplace`: carry out model transformations in-place, the original module is mutated
  310. `allow_list`: list of quantizable modules
  311. `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer
  312. `prepare_custom_config_dict`: customization configuration dictionary for prepare function
  313. .. code-block:: python
  314. # Example of prepare_custom_config_dict:
  315. prepare_custom_config_dict = {
  316. # user will manually define the corresponding observed
  317. # module class which has a from_float class method that converts
  318. # float custom module to observed custom module
  319. "float_to_observed_custom_module_class": {CustomModule: ObservedCustomModule}
  320. }
  321. """
  322. torch._C._log_api_usage_once("quantization_api.quantize.prepare")
  323. if prepare_custom_config_dict is None:
  324. prepare_custom_config_dict = get_default_custom_config_dict()
  325. custom_module_class_mapping = prepare_custom_config_dict.get(
  326. "float_to_observed_custom_module_class", {}
  327. )
  328. if not inplace:
  329. model = copy.deepcopy(model)
  330. # TODO: remove allow_list
  331. qconfig_propagation_list = allow_list
  332. if allow_list is None:
  333. qconfig_propagation_list = get_default_qconfig_propagation_list()
  334. propagate_qconfig_(model, qconfig_dict=None)
  335. # sanity check common API misusage
  336. if not any(hasattr(m, "qconfig") and m.qconfig for m in model.modules()):
  337. warnings.warn(
  338. "None of the submodule got qconfig applied. Make sure you "
  339. "passed correct configuration through `qconfig_dict` or "
  340. "by assigning the `.qconfig` attribute directly on submodules",
  341. stacklevel=2,
  342. )
  343. _add_observer_(
  344. model,
  345. qconfig_propagation_list,
  346. observer_non_leaf_module_list,
  347. custom_module_class_mapping=custom_module_class_mapping,
  348. )
  349. return model
  350. def _remove_activation_post_process(module):
  351. # TODO: maybe we should change activation_post_process to _activation_post_process
  352. # to prevent it from being used by user
  353. if hasattr(module, "activation_post_process") and _is_activation_post_process(
  354. module.activation_post_process
  355. ):
  356. delattr(module, "activation_post_process")
  357. # remove activation_post_process pre and post hooks
  358. def remove_hooks(pre_hook=False):
  359. hook_map = module._forward_pre_hooks if pre_hook else module._forward_hooks
  360. observer_hook = (
  361. _observer_forward_pre_hook if pre_hook else _observer_forward_hook
  362. )
  363. handle_ids_to_remove = set()
  364. for handle_id, hook_fn in hook_map.items():
  365. if hook_fn is observer_hook:
  366. handle_ids_to_remove.add(handle_id)
  367. for handle_id in handle_ids_to_remove:
  368. hook_map.pop(handle_id)
  369. remove_hooks(pre_hook=True)
  370. remove_hooks(pre_hook=False)
  371. # TODO: rename to something more general
  372. def _remove_qconfig(module):
  373. r"""Clean up the qconfig left in the module so that new qconfig can be
  374. propagated.
  375. Args:
  376. module: module to be cleaned up
  377. """
  378. for child in module.children():
  379. _remove_qconfig(child)
  380. if hasattr(module, "qconfig"):
  381. del module.qconfig
  382. _remove_activation_post_process(module)
  383. @typing_extensions.deprecated(DEPRECATION_WARNING)
  384. def quantize(model, run_fn, run_args, mapping=None, inplace=False):
  385. r"""Quantize the input float model with post training static quantization.
  386. First it will prepare the model for calibration, then it calls
  387. `run_fn` which will run the calibration step, after that we will
  388. convert the model to a quantized model.
  389. Args:
  390. model: input float model
  391. run_fn: a calibration function for calibrating the prepared model
  392. run_args: positional arguments for `run_fn`
  393. inplace: carry out model transformations in-place, the original module is mutated
  394. mapping: correspondence between original module types and quantized counterparts
  395. Return:
  396. Quantized model.
  397. """
  398. torch._C._log_api_usage_once("quantization_api.quantize.quantize")
  399. if mapping is None:
  400. mapping = get_default_static_quant_module_mappings()
  401. if not inplace:
  402. model = copy.deepcopy(model)
  403. model.eval()
  404. prepare(model, inplace=True)
  405. run_fn(model, *run_args)
  406. convert(model, mapping, inplace=True)
  407. return model
  408. @typing_extensions.deprecated(DEPRECATION_WARNING)
  409. def quantize_dynamic(
  410. model, qconfig_spec=None, dtype=torch.qint8, mapping=None, inplace=False
  411. ):
  412. r"""Converts a float model to dynamic (i.e. weights-only) quantized model.
  413. Replaces specified modules with dynamic weight-only quantized versions and output the quantized model.
  414. For simplest usage provide `dtype` argument that can be float16 or qint8. Weight-only quantization
  415. by default is performed for layers with large weights size - i.e. Linear and RNN variants.
  416. Fine grained control is possible with `qconfig` and `mapping` that act similarly to `quantize()`.
  417. If `qconfig` is provided, the `dtype` argument is ignored.
  418. Args:
  419. model: input model
  420. qconfig_spec: Either:
  421. - A dictionary that maps from name or type of submodule to quantization
  422. configuration, qconfig applies to all submodules of a given
  423. module unless qconfig for the submodules are specified (when the
  424. submodule already has qconfig attribute). Entries in the dictionary
  425. need to be QConfig instances.
  426. - A set of types and/or submodule names to apply dynamic quantization to,
  427. in which case the `dtype` argument is used to specify the bit-width
  428. inplace: carry out model transformations in-place, the original module is mutated
  429. mapping: maps type of a submodule to a type of corresponding dynamically quantized version
  430. with which the submodule needs to be replaced
  431. """
  432. torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic")
  433. if qconfig_spec is None:
  434. if dtype == torch.qint8:
  435. qconfig_spec = {
  436. nn.Linear: default_dynamic_qconfig,
  437. nn.LSTM: default_dynamic_qconfig,
  438. nn.GRU: default_dynamic_qconfig,
  439. nn.LSTMCell: default_dynamic_qconfig,
  440. nn.RNNCell: default_dynamic_qconfig,
  441. nn.GRUCell: default_dynamic_qconfig,
  442. }
  443. elif dtype == torch.float16:
  444. qconfig_spec = {
  445. nn.Linear: float16_dynamic_qconfig,
  446. nn.LSTM: float16_dynamic_qconfig,
  447. nn.GRU: float16_dynamic_qconfig,
  448. nn.LSTMCell: float16_dynamic_qconfig,
  449. nn.RNNCell: float16_dynamic_qconfig,
  450. nn.GRUCell: float16_dynamic_qconfig,
  451. }
  452. elif dtype == torch.quint8:
  453. qconfig_spec = {
  454. nn.EmbeddingBag: float_qparams_weight_only_qconfig,
  455. nn.Embedding: float_qparams_weight_only_qconfig,
  456. }
  457. elif dtype == torch.quint4x2:
  458. qconfig_spec = {
  459. nn.EmbeddingBag: float_qparams_weight_only_qconfig_4bit,
  460. }
  461. else:
  462. raise ValueError(
  463. f"Don't know how to quantize with default settings for {dtype}. Provide full qconfig please"
  464. )
  465. elif isinstance(qconfig_spec, set):
  466. if dtype is torch.qint8:
  467. default_qconfig = default_dynamic_qconfig
  468. elif dtype is torch.float16:
  469. default_qconfig = float16_dynamic_qconfig
  470. elif dtype is torch.quint8:
  471. default_qconfig = float_qparams_weight_only_qconfig
  472. elif dtype is torch.quint4x2:
  473. default_qconfig = float_qparams_weight_only_qconfig_4bit
  474. else:
  475. raise RuntimeError(
  476. "Unknown dtype specified for quantize_dynamic: ", str(dtype)
  477. )
  478. qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))
  479. if mapping is None:
  480. mapping = get_default_dynamic_quant_module_mappings()
  481. if not inplace:
  482. model = copy.deepcopy(model)
  483. model.eval()
  484. propagate_qconfig_(model, qconfig_spec)
  485. convert(model, mapping, inplace=True)
  486. return model
  487. @typing_extensions.deprecated(DEPRECATION_WARNING)
  488. def prepare_qat(model, mapping=None, inplace=False):
  489. r"""
  490. Prepares a copy of the model for quantization calibration or
  491. quantization-aware training and converts it to quantized version.
  492. Quantization configuration should be assigned preemptively
  493. to individual submodules in `.qconfig` attribute.
  494. Args:
  495. model: input model to be modified in-place
  496. mapping: dictionary that maps float modules to quantized modules to be
  497. replaced.
  498. inplace: carry out model transformations in-place, the original module
  499. is mutated
  500. """
  501. torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
  502. if not model.training:
  503. raise AssertionError("prepare_qat only works on models in training mode")
  504. if mapping is None:
  505. mapping = get_default_qat_module_mappings()
  506. if not inplace:
  507. model = copy.deepcopy(model)
  508. propagate_qconfig_(model, qconfig_dict=None)
  509. convert(model, mapping=mapping, inplace=True, remove_qconfig=False)
  510. prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True)
  511. return model
  512. @typing_extensions.deprecated(DEPRECATION_WARNING)
  513. def quantize_qat(model, run_fn, run_args, inplace=False):
  514. r"""Do quantization aware training and output a quantized model
  515. Args:
  516. model: input model
  517. run_fn: a function for evaluating the prepared model, can be a
  518. function that simply runs the prepared model or a training
  519. loop
  520. run_args: positional arguments for `run_fn`
  521. Return:
  522. Quantized model.
  523. """
  524. torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat")
  525. if not inplace:
  526. model = copy.deepcopy(model)
  527. model.train()
  528. prepare_qat(model, inplace=True)
  529. run_fn(model, *run_args)
  530. convert(model, inplace=True)
  531. return model
  532. @typing_extensions.deprecated(DEPRECATION_WARNING)
  533. def convert(
  534. module,
  535. mapping=None,
  536. inplace=False,
  537. remove_qconfig=True,
  538. is_reference=False,
  539. convert_custom_config_dict=None,
  540. use_precomputed_fake_quant=False,
  541. ):
  542. r"""Converts submodules in input module to a different module according to `mapping`
  543. by calling `from_float` method on the target module class. And remove qconfig at the
  544. end if remove_qconfig is set to True.
  545. Args:
  546. `module`: prepared and calibrated module
  547. `mapping`: a dictionary that maps from source module type to target
  548. module type, can be overwritten to allow swapping user defined
  549. Modules
  550. `inplace`: carry out model transformations in-place, the original module
  551. is mutated
  552. `convert_custom_config_dict`: custom configuration dictionary for convert function
  553. `use_precomputed_fake_quant`: a flag to enable use of precomputed fake quant
  554. .. code-block:: python
  555. # Example of convert_custom_config_dict:
  556. convert_custom_config_dict = {
  557. # user will manually define the corresponding quantized
  558. # module class which has a from_observed class method that converts
  559. # observed custom module to quantized custom module
  560. "observed_to_quantized_custom_module_class": {
  561. ObservedCustomModule: QuantizedCustomModule
  562. }
  563. }
  564. """
  565. torch._C._log_api_usage_once("quantization_api.quantize.convert")
  566. if not inplace:
  567. module = copy.deepcopy(module)
  568. _convert(
  569. module,
  570. mapping,
  571. inplace=True,
  572. is_reference=is_reference,
  573. convert_custom_config_dict=convert_custom_config_dict,
  574. use_precomputed_fake_quant=use_precomputed_fake_quant,
  575. )
  576. if remove_qconfig:
  577. _remove_qconfig(module)
  578. return module
  579. def _convert(
  580. module,
  581. mapping=None,
  582. inplace=False,
  583. is_reference=False,
  584. convert_custom_config_dict=None,
  585. use_precomputed_fake_quant=False,
  586. ):
  587. r"""Converts submodules in input module to a different module according to `mapping`
  588. by calling `from_float` method on the target module class
  589. Args:
  590. module: input module
  591. mapping: a dictionary that maps from source module type to target
  592. module type, can be overwritten to allow swapping user defined
  593. Modules
  594. inplace: carry out model transformations in-place, the original module
  595. is mutated
  596. is_reference: a flag to enable quantized reference module
  597. use_precomputed_fake_quant: a flag to enable use of precomputed fake quant
  598. """
  599. if mapping is None:
  600. mapping = (
  601. get_default_static_quant_reference_module_mappings()
  602. if is_reference
  603. else get_default_static_quant_module_mappings()
  604. )
  605. if convert_custom_config_dict is None:
  606. convert_custom_config_dict = get_default_custom_config_dict()
  607. custom_module_class_mapping = convert_custom_config_dict.get(
  608. "observed_to_quantized_custom_module_class", {}
  609. )
  610. if not inplace:
  611. module = copy.deepcopy(module)
  612. reassign = {}
  613. for name, mod in module.named_children():
  614. # both fused modules and observed custom modules are
  615. # swapped as one unit
  616. if (
  617. not isinstance(mod, _FusedModule)
  618. and type_before_parametrizations(mod) not in custom_module_class_mapping
  619. ):
  620. _convert(
  621. mod,
  622. mapping,
  623. True, # inplace
  624. is_reference,
  625. convert_custom_config_dict,
  626. use_precomputed_fake_quant=use_precomputed_fake_quant,
  627. )
  628. reassign[name] = swap_module(
  629. mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant
  630. )
  631. for key, value in reassign.items():
  632. module._modules[key] = value
  633. return module
  634. def swap_module(
  635. mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant=False
  636. ):
  637. r"""Swaps the module if it has a quantized counterpart and it has an
  638. `observer` attached.
  639. Args:
  640. mod: input module
  641. mapping: a dictionary that maps from nn module to nnq module
  642. Return:
  643. The corresponding quantized module of `mod`
  644. """
  645. new_mod = mod
  646. if hasattr(mod, "qconfig") and mod.qconfig is not None:
  647. swapped = False
  648. if type_before_parametrizations(mod) in custom_module_class_mapping:
  649. new_mod = custom_module_class_mapping[
  650. type_before_parametrizations(mod)
  651. ].from_observed(mod)
  652. swapped = True
  653. elif type_before_parametrizations(mod) in mapping:
  654. qmod = mapping[type_before_parametrizations(mod)]
  655. if hasattr(qmod, "_IS_REFERENCE") and qmod._IS_REFERENCE:
  656. if mod.qconfig is None:
  657. raise AssertionError(
  658. "module qconfig must not be None when swapping to reference module"
  659. )
  660. weight_post_process = mod.qconfig.weight()
  661. weight_post_process(mod.weight)
  662. weight_qparams = get_qparam_dict(weight_post_process)
  663. new_mod = qmod.from_float(mod, weight_qparams)
  664. else:
  665. sig = inspect.signature(qmod.from_float)
  666. if "use_precomputed_fake_quant" in sig.parameters:
  667. new_mod = qmod.from_float(
  668. mod, use_precomputed_fake_quant=use_precomputed_fake_quant
  669. )
  670. else:
  671. new_mod = qmod.from_float(mod)
  672. swapped = True
  673. if swapped:
  674. # Preserve module's pre forward hooks. They'll be called on quantized input
  675. for pre_hook_fn in mod._forward_pre_hooks.values():
  676. new_mod.register_forward_pre_hook(pre_hook_fn)
  677. # Preserve module's post forward hooks except _observer_forward_hook
  678. # After convert they'll work with quantized output
  679. for hook_fn in mod._forward_hooks.values():
  680. if hook_fn is not _observer_forward_hook:
  681. new_mod.register_forward_hook(hook_fn)
  682. # respect device affinity when swapping modules
  683. devices = _get_unique_devices_(mod)
  684. if not (
  685. len(devices) <= 1
  686. or (len(devices) == 2 and torch.device("meta") in devices)
  687. ):
  688. raise AssertionError(
  689. f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
  690. )
  691. device = next(iter(devices)) if len(devices) > 0 else None
  692. if device:
  693. new_mod.to(device)
  694. return new_mod
  695. def _get_observer_dict(mod, target_dict, prefix=""):
  696. r"""Traverse the modules and save all observers into dict.
  697. This is mainly used for quantization accuracy debug
  698. Args:
  699. mod: the top module we want to save all observers
  700. prefix: the prefix for the current module
  701. target_dict: the dictionary used to save all the observers
  702. """
  703. def get_prefix(prefix):
  704. return prefix if prefix == "" else prefix + "."
  705. if hasattr(mod, "activation_post_process"):
  706. target_dict[get_prefix(prefix) + "activation_post_process"] = (
  707. mod.activation_post_process
  708. )
  709. for name, child in mod.named_children():
  710. module_prefix = get_prefix(prefix) + name if prefix else name
  711. _get_observer_dict(child, target_dict, module_prefix)