quantize_fx.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759
  1. import copy
  2. import typing_extensions
  3. import warnings
  4. from typing import Any
  5. import torch
  6. from torch.fx import GraphModule
  7. from torch.fx.graph_module import _USER_PRESERVED_ATTRIBUTES_KEY
  8. from .backend_config import BackendConfig, get_tensorrt_backend_config # noqa: F401
  9. from .fx.convert import convert
  10. from .fx.custom_config import ConvertCustomConfig, FuseCustomConfig, PrepareCustomConfig
  11. from .fx.fuse import fuse # noqa: F401
  12. from .fx.graph_module import ObservedGraphModule # noqa: F401
  13. from .fx.prepare import prepare # noqa: F401
  14. from .fx.tracer import QuantizationTracer, Scope, ScopeContextManager # noqa: F401
  15. from .fx.utils import ( # noqa: F401
  16. get_custom_module_class_keys,
  17. get_skipped_module_name_and_classes,
  18. )
  19. from .qconfig_mapping import QConfigMapping
  20. from .utils import DEPRECATION_WARNING
  21. def attach_preserved_attrs_to_model(
  22. model: GraphModule | torch.nn.Module,
  23. preserved_attrs: dict[str, Any],
  24. ) -> None:
  25. """Store preserved attributes to the model.meta so that it can be preserved during deepcopy"""
  26. model.meta[_USER_PRESERVED_ATTRIBUTES_KEY] = copy.copy(preserved_attrs) # type: ignore[operator, index, assignment]
  27. # set the preserved attributes in the model so that user can call
  28. # model.attr as they do before calling fx graph mode quantization
  29. for attr_name, attr in model.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items(): # type: ignore[index, union-attr]
  30. setattr(model, attr_name, attr)
  31. def _check_is_graph_module(model: torch.nn.Module) -> None:
  32. if not isinstance(model, GraphModule):
  33. raise ValueError(
  34. "input model must be a GraphModule, "
  35. + "Got type:"
  36. + str(type(model))
  37. + " Please make "
  38. + "sure to follow the tutorials."
  39. )
  40. def _attach_meta_to_node_if_not_exist(model: GraphModule) -> None:
  41. """Attach meta field to all nodes of the graph if it does not exist,
  42. meta field is a field stores some meta information about the node, such
  43. as dtype and shape information for output of the node, this only exists
  44. if the program is captured by make_fx (used in quantize_pt2e flow), if
  45. the program is captured by torch.fx symbolic tracing, this field may not exist,
  46. so we add it here to avoid checking this all over the places
  47. """
  48. for node in model.graph.nodes:
  49. if not hasattr(node, "meta"):
  50. node.meta = {}
  51. def _swap_ff_with_fxff(model: torch.nn.Module) -> None:
  52. r"""Swap FloatFunctional with FXFloatFunctional"""
  53. modules_to_swap = []
  54. for name, module in model.named_children():
  55. if isinstance(module, torch.ao.nn.quantized.FloatFunctional):
  56. modules_to_swap.append(name)
  57. else:
  58. _swap_ff_with_fxff(module)
  59. for name in modules_to_swap:
  60. del model._modules[name]
  61. model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional()
  62. def _fuse_fx(
  63. model: GraphModule,
  64. is_qat: bool,
  65. fuse_custom_config: FuseCustomConfig | dict[str, Any] | None = None,
  66. backend_config: BackendConfig | dict[str, Any] | None = None,
  67. ) -> GraphModule:
  68. r"""Internal helper function to fuse modules in preparation for quantization
  69. Args:
  70. model: GraphModule object from symbolic tracing (torch.fx.symbolic_trace)
  71. """
  72. _check_is_graph_module(model)
  73. return fuse(model, is_qat, fuse_custom_config, backend_config) # type: ignore[operator]
  74. def _prepare_fx(
  75. model: torch.nn.Module,
  76. qconfig_mapping: QConfigMapping | dict[str, Any],
  77. is_qat: bool,
  78. example_inputs: tuple[Any, ...],
  79. prepare_custom_config: PrepareCustomConfig | dict[str, Any] | None = None,
  80. _equalization_config: QConfigMapping | dict[str, Any] | None = None,
  81. backend_config: BackendConfig | dict[str, Any] | None = None,
  82. is_standalone_module: bool = False,
  83. ) -> GraphModule:
  84. r"""Internal helper function for prepare_fx
  85. Args:
  86. `model`, `qconfig_mapping`, `prepare_custom_config`, `_equalization_config`:
  87. see docs for :func:`~torch.ao.quantization.prepare_fx`
  88. `is_standalone_module`: a boolean flag indicates whether we are
  89. quantizing a standalone module or not, a standalone module
  90. is a submodule of the parent module that is not inlined in the
  91. forward graph of the parent module,
  92. the way we quantize standalone module is described in:
  93. :func:`~torch.ao.quantization._prepare_standalone_module_fx`
  94. """
  95. if prepare_custom_config is None:
  96. prepare_custom_config = PrepareCustomConfig()
  97. if _equalization_config is None:
  98. _equalization_config = QConfigMapping()
  99. if isinstance(prepare_custom_config, dict):
  100. warnings.warn(
  101. "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported "
  102. "in a future version. Please pass in a PrepareCustomConfig instead.",
  103. FutureWarning,
  104. stacklevel=3,
  105. )
  106. prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config)
  107. # swap FloatFunctional with FXFloatFunctional
  108. _swap_ff_with_fxff(model)
  109. skipped_module_names, skipped_module_classes = get_skipped_module_name_and_classes(
  110. prepare_custom_config, is_standalone_module
  111. )
  112. preserved_attr_names = prepare_custom_config.preserved_attributes
  113. preserved_attrs = {
  114. attr: getattr(model, attr)
  115. for attr in preserved_attr_names
  116. if hasattr(model, attr)
  117. }
  118. # symbolically trace the model
  119. tracer = QuantizationTracer(skipped_module_names, skipped_module_classes) # type: ignore[arg-type]
  120. graph_module = GraphModule(model, tracer.trace(model))
  121. _attach_meta_to_node_if_not_exist(graph_module)
  122. fuse_custom_config = FuseCustomConfig().set_preserved_attributes(
  123. prepare_custom_config.preserved_attributes
  124. )
  125. graph_module = _fuse_fx(graph_module, is_qat, fuse_custom_config, backend_config)
  126. prepared = prepare(
  127. graph_module,
  128. qconfig_mapping,
  129. is_qat,
  130. tracer.node_name_to_scope,
  131. example_inputs=example_inputs,
  132. prepare_custom_config=prepare_custom_config,
  133. _equalization_config=_equalization_config,
  134. backend_config=backend_config,
  135. is_standalone_module=is_standalone_module,
  136. ) # type: ignore[operator]
  137. attach_preserved_attrs_to_model(prepared, preserved_attrs)
  138. return prepared
  139. def _prepare_standalone_module_fx(
  140. model: torch.nn.Module,
  141. qconfig_mapping: QConfigMapping | dict[str, Any],
  142. is_qat: bool,
  143. example_inputs: tuple[Any, ...],
  144. prepare_custom_config: PrepareCustomConfig | dict[str, Any] | None = None,
  145. backend_config: BackendConfig | dict[str, Any] | None = None,
  146. ) -> GraphModule:
  147. r"""[Internal use only] Prepare a standalone module, so that it can be used when quantizing the
  148. parent module.
  149. standalone_module means it a submodule that is not inlined in parent module,
  150. and will be quantized separately as one unit.
  151. How the standalone module is observed is specified by `input_quantized_idxs` and
  152. `output_quantized_idxs` in the prepare_custom_config for the standalone module
  153. Returns:
  154. * model(GraphModule): prepared standalone module. It has these attributes in
  155. model.meta:
  156. * `standalone_module_input_quantized_idxs(List[Int])`: a list of
  157. indexes for the graph input that is expected to be quantized,
  158. same as input_quantized_idxs configuration provided
  159. for the standalone module
  160. * `standalone_module_output_quantized_idxs(List[Int])`: a list of
  161. indices for the graph output that is quantized
  162. same as input_quantized_idxs configuration provided
  163. for the standalone module
  164. """
  165. return _prepare_fx(
  166. model,
  167. qconfig_mapping,
  168. is_qat,
  169. example_inputs,
  170. prepare_custom_config,
  171. backend_config=backend_config,
  172. is_standalone_module=True,
  173. )
  174. def fuse_fx(
  175. model: torch.nn.Module,
  176. fuse_custom_config: FuseCustomConfig | dict[str, Any] | None = None,
  177. backend_config: BackendConfig | dict[str, Any] | None = None,
  178. ) -> GraphModule:
  179. r"""Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode.
  180. Fusion rules are defined in torch.ao.quantization.fx.fusion_pattern.py
  181. Args:
  182. * `model` (torch.nn.Module): a torch.nn.Module model
  183. * `fuse_custom_config` (FuseCustomConfig): custom configurations for fuse_fx.
  184. See :class:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig` for more details
  185. Example::
  186. from torch.ao.quantization import fuse_fx
  187. m = Model().eval()
  188. m = fuse_fx(m)
  189. """
  190. if fuse_custom_config is None:
  191. fuse_custom_config = FuseCustomConfig()
  192. if isinstance(fuse_custom_config, dict):
  193. warnings.warn(
  194. "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported "
  195. "in a future version. Please pass in a FuseCustomConfig instead.",
  196. FutureWarning,
  197. stacklevel=2,
  198. )
  199. fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config)
  200. torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx")
  201. preserved_attr_names = fuse_custom_config.preserved_attributes
  202. preserved_attrs = {
  203. attr: getattr(model, attr)
  204. for attr in preserved_attr_names
  205. if hasattr(model, attr)
  206. }
  207. graph_module = torch.fx.symbolic_trace(model)
  208. _attach_meta_to_node_if_not_exist(graph_module)
  209. graph_module = _fuse_fx(graph_module, False, fuse_custom_config, backend_config)
  210. attach_preserved_attrs_to_model(graph_module, preserved_attrs)
  211. return graph_module
  212. @typing_extensions.deprecated(DEPRECATION_WARNING)
  213. def prepare_fx(
  214. model: torch.nn.Module,
  215. qconfig_mapping: QConfigMapping | dict[str, Any],
  216. example_inputs: tuple[Any, ...],
  217. prepare_custom_config: PrepareCustomConfig | dict[str, Any] | None = None,
  218. _equalization_config: QConfigMapping | dict[str, Any] | None = None,
  219. backend_config: BackendConfig | dict[str, Any] | None = None,
  220. ) -> GraphModule:
  221. r""" Prepare a model for post training quantization
  222. Args:
  223. * `model` (torch.nn.Module): torch.nn.Module model
  224. * `qconfig_mapping` (QConfigMapping): QConfigMapping object to configure how a model is
  225. quantized, see :class:`~torch.ao.quantization.qconfig_mapping.QConfigMapping`
  226. for more details
  227. * `example_inputs` (Tuple[Any, ...]): Example inputs for forward function of the model,
  228. Tuple of positional args (keyword args can be passed as positional args as well)
  229. * `prepare_custom_config` (PrepareCustomConfig): customization configuration for quantization tool.
  230. See :class:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig` for more details
  231. * `_equalization_config`: config for specifying how to perform equalization on the model
  232. * `backend_config` (BackendConfig): config that specifies how operators are quantized
  233. in a backend, this includes how the operators are observed,
  234. supported fusion patterns, how quantize/dequantize ops are
  235. inserted, supported dtypes etc. See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details
  236. Return:
  237. A GraphModule with observer (configured by qconfig_mapping), ready for calibration
  238. Example::
  239. import torch
  240. from torch.ao.quantization import get_default_qconfig_mapping
  241. from torch.ao.quantization.quantize_fx import prepare_fx
  242. class Submodule(torch.nn.Module):
  243. def __init__(self) -> None:
  244. super().__init__()
  245. self.linear = torch.nn.Linear(5, 5)
  246. def forward(self, x):
  247. x = self.linear(x)
  248. return x
  249. class M(torch.nn.Module):
  250. def __init__(self) -> None:
  251. super().__init__()
  252. self.linear = torch.nn.Linear(5, 5)
  253. self.sub = Submodule()
  254. def forward(self, x):
  255. x = self.linear(x)
  256. x = self.sub(x) + x
  257. return x
  258. # initialize a floating point model
  259. float_model = M().eval()
  260. # define calibration function
  261. def calibrate(model, data_loader):
  262. model.eval()
  263. with torch.no_grad():
  264. for image, target in data_loader:
  265. model(image)
  266. # qconfig is the configuration for how we insert observers for a particular
  267. # operator
  268. # qconfig = get_default_qconfig("fbgemm")
  269. # Example of customizing qconfig:
  270. # qconfig = torch.ao.quantization.QConfig(
  271. # activation=MinMaxObserver.with_args(dtype=torch.qint8),
  272. # weight=MinMaxObserver.with_args(dtype=torch.qint8))
  273. # `activation` and `weight` are constructors of observer module
  274. # qconfig_mapping is a collection of quantization configurations, user can
  275. # set the qconfig for each operator (torch op calls, functional calls, module calls)
  276. # in the model through qconfig_mapping
  277. # the following call will get the qconfig_mapping that works best for models
  278. # that target "fbgemm" backend
  279. qconfig_mapping = get_default_qconfig_mapping("fbgemm")
  280. # We can customize qconfig_mapping in different ways.
  281. # e.g. set the global qconfig, which means we will use the same qconfig for
  282. # all operators in the model, this can be overwritten by other settings
  283. # qconfig_mapping = QConfigMapping().set_global(qconfig)
  284. # e.g. quantize the linear submodule with a specific qconfig
  285. # qconfig_mapping = QConfigMapping().set_module_name("linear", qconfig)
  286. # e.g. quantize all nn.Linear modules with a specific qconfig
  287. # qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
  288. # for a more complete list, please see the docstring for :class:`torch.ao.quantization.QConfigMapping`
  289. # argument
  290. # example_inputs is a tuple of inputs, that is used to infer the type of the
  291. # outputs in the model
  292. # currently it's not used, but please make sure model(*example_inputs) runs
  293. example_inputs = (torch.randn(1, 3, 224, 224),)
  294. # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
  295. # e.g. backend_config = get_default_backend_config("fbgemm")
  296. # `prepare_fx` inserts observers in the model based on qconfig_mapping and
  297. # backend_config. If the configuration for an operator in qconfig_mapping
  298. # is supported in the backend_config (meaning it's supported by the target
  299. # hardware), we'll insert observer modules according to the qconfig_mapping
  300. # otherwise the configuration in qconfig_mapping will be ignored
  301. #
  302. # Example:
  303. # in qconfig_mapping, user sets linear module to be quantized with quint8 for
  304. # activation and qint8 for weight:
  305. # qconfig = torch.ao.quantization.QConfig(
  306. # observer=MinMaxObserver.with_args(dtype=torch.quint8),
  307. # weight=MinMaxObserver.with-args(dtype=torch.qint8))
  308. # Note: current qconfig api does not support setting output observer, but
  309. # we may extend this to support these more fine grained control in the
  310. # future
  311. #
  312. # qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
  313. # in backend config, linear module also supports in this configuration:
  314. # weighted_int8_dtype_config = DTypeConfig(
  315. # input_dtype=torch.quint8,
  316. # output_dtype=torch.quint8,
  317. # weight_dtype=torch.qint8,
  318. # bias_type=torch.float)
  319. # linear_pattern_config = BackendPatternConfig(torch.nn.Linear) \
  320. # .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
  321. # .add_dtype_config(weighted_int8_dtype_config) \
  322. # ...
  323. # backend_config = BackendConfig().set_backend_pattern_config(linear_pattern_config)
  324. # `prepare_fx` will check that the setting requested by suer in qconfig_mapping
  325. # is supported by the backend_config and insert observers and fake quant modules
  326. # in the model
  327. prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs)
  328. # Run calibration
  329. calibrate(prepared_model, sample_inference_data)
  330. """
  331. torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx")
  332. return _prepare_fx(
  333. model,
  334. qconfig_mapping,
  335. False, # is_qat
  336. example_inputs,
  337. prepare_custom_config,
  338. _equalization_config,
  339. backend_config,
  340. )
  341. @typing_extensions.deprecated(DEPRECATION_WARNING)
  342. def prepare_qat_fx(
  343. model: torch.nn.Module,
  344. qconfig_mapping: QConfigMapping | dict[str, Any],
  345. example_inputs: tuple[Any, ...],
  346. prepare_custom_config: PrepareCustomConfig | dict[str, Any] | None = None,
  347. backend_config: BackendConfig | dict[str, Any] | None = None,
  348. ) -> GraphModule:
  349. r"""Prepare a model for quantization aware training
  350. Args:
  351. * `model` (torch.nn.Module): torch.nn.Module model
  352. * `qconfig_mapping` (QConfigMapping): see :func:`~torch.ao.quantization.prepare_fx`
  353. * `example_inputs` (Tuple[Any, ...]): see :func:`~torch.ao.quantization.prepare_fx`
  354. * `prepare_custom_config` (PrepareCustomConfig): see :func:`~torch.ao.quantization.prepare_fx`
  355. * `backend_config` (BackendConfig): see :func:`~torch.ao.quantization.prepare_fx`
  356. Return:
  357. A GraphModule with fake quant modules (configured by qconfig_mapping and backend_config), ready for
  358. quantization aware training
  359. Example::
  360. import torch
  361. from torch.ao.quantization import get_default_qat_qconfig_mapping
  362. from torch.ao.quantization.quantize_fx import prepare_qat_fx
  363. class Submodule(torch.nn.Module):
  364. def __init__(self) -> None:
  365. super().__init__()
  366. self.linear = torch.nn.Linear(5, 5)
  367. def forward(self, x):
  368. x = self.linear(x)
  369. return x
  370. class M(torch.nn.Module):
  371. def __init__(self) -> None:
  372. super().__init__()
  373. self.linear = torch.nn.Linear(5, 5)
  374. self.sub = Submodule()
  375. def forward(self, x):
  376. x = self.linear(x)
  377. x = self.sub(x) + x
  378. return x
  379. # initialize a floating point model
  380. float_model = M().train()
  381. # (optional, but preferred) load the weights from pretrained model
  382. # float_model.load_weights(...)
  383. # define the training loop for quantization aware training
  384. def train_loop(model, train_data):
  385. model.train()
  386. for image, target in data_loader:
  387. ...
  388. # qconfig is the configuration for how we insert observers for a particular
  389. # operator
  390. # qconfig = get_default_qconfig("fbgemm")
  391. # Example of customizing qconfig:
  392. # qconfig = torch.ao.quantization.QConfig(
  393. # activation=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)),
  394. # weight=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)))
  395. # `activation` and `weight` are constructors of observer module
  396. # qconfig_mapping is a collection of quantization configurations, user can
  397. # set the qconfig for each operator (torch op calls, functional calls, module calls)
  398. # in the model through qconfig_mapping
  399. # the following call will get the qconfig_mapping that works best for models
  400. # that target "fbgemm" backend
  401. qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")
  402. # We can customize qconfig_mapping in different ways, please take a look at
  403. # the docstring for :func:`~torch.ao.quantization.prepare_fx` for different ways
  404. # to configure this
  405. # example_inputs is a tuple of inputs, that is used to infer the type of the
  406. # outputs in the model
  407. # currently it's not used, but please make sure model(*example_inputs) runs
  408. example_inputs = (torch.randn(1, 3, 224, 224),)
  409. # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
  410. # e.g. backend_config = get_default_backend_config("fbgemm")
  411. # `prepare_qat_fx` inserts observers in the model based on qconfig_mapping and
  412. # backend_config, if the configuration for an operator in qconfig_mapping
  413. # is supported in the backend_config (meaning it's supported by the target
  414. # hardware), we'll insert fake_quantize modules according to the qconfig_mapping
  415. # otherwise the configuration in qconfig_mapping will be ignored
  416. # see :func:`~torch.ao.quantization.prepare_fx` for a detailed explanation of
  417. # how qconfig_mapping interacts with backend_config
  418. prepared_model = prepare_qat_fx(float_model, qconfig_mapping, example_inputs)
  419. # Run training
  420. train_loop(prepared_model, train_loop)
  421. """
  422. torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx")
  423. return _prepare_fx(
  424. model,
  425. qconfig_mapping,
  426. True, # is_qat
  427. example_inputs,
  428. prepare_custom_config,
  429. backend_config=backend_config,
  430. )
  431. def _convert_fx(
  432. graph_module: GraphModule,
  433. is_reference: bool,
  434. convert_custom_config: ConvertCustomConfig | dict[str, Any] | None = None,
  435. is_standalone_module: bool = False,
  436. _remove_qconfig: bool = True,
  437. qconfig_mapping: QConfigMapping | dict[str, Any] | None = None,
  438. backend_config: BackendConfig | dict[str, Any] | None = None,
  439. is_decomposed: bool = False,
  440. keep_original_weights: bool = False,
  441. ) -> GraphModule:
  442. """`is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx`"""
  443. if convert_custom_config is None:
  444. convert_custom_config = ConvertCustomConfig()
  445. if isinstance(convert_custom_config, dict):
  446. warnings.warn(
  447. "Passing a convert_custom_config_dict to convert is deprecated and will not be supported "
  448. "in a future version. Please pass in a ConvertCustomConfig instead.",
  449. FutureWarning,
  450. stacklevel=3,
  451. )
  452. convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config)
  453. _check_is_graph_module(graph_module)
  454. preserved_attr_names = convert_custom_config.preserved_attributes
  455. preserved_attrs = {
  456. attr: getattr(graph_module, attr)
  457. for attr in preserved_attr_names
  458. if hasattr(graph_module, attr)
  459. }
  460. quantized = convert(
  461. graph_module,
  462. is_reference,
  463. convert_custom_config,
  464. is_standalone_module,
  465. _remove_qconfig_flag=_remove_qconfig,
  466. qconfig_mapping=qconfig_mapping,
  467. backend_config=backend_config,
  468. is_decomposed=is_decomposed,
  469. keep_original_weights=keep_original_weights,
  470. )
  471. attach_preserved_attrs_to_model(quantized, preserved_attrs)
  472. return quantized
  473. @typing_extensions.deprecated(DEPRECATION_WARNING)
  474. def convert_fx(
  475. graph_module: GraphModule,
  476. convert_custom_config: ConvertCustomConfig | dict[str, Any] | None = None,
  477. _remove_qconfig: bool = True,
  478. qconfig_mapping: QConfigMapping | dict[str, Any] | None = None,
  479. backend_config: BackendConfig | dict[str, Any] | None = None,
  480. keep_original_weights: bool = False,
  481. ) -> GraphModule:
  482. r"""Convert a calibrated or trained model to a quantized model
  483. Args:
  484. * `graph_module` (torch.fx.GraphModule): A prepared and calibrated/trained model (GraphModule)
  485. * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function.
  486. See :class:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig` for more details
  487. * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert.
  488. * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.
  489. The keys must include the ones in the qconfig_mapping passed to `prepare_fx` or `prepare_qat_fx`,
  490. with the same values or `None`. Additional keys can be specified with values set to `None`.
  491. For each entry whose value is set to None, we skip quantizing that entry in the model::
  492. qconfig_mapping = QConfigMapping
  493. .set_global(qconfig_from_prepare)
  494. .set_object_type(torch.nn.functional.add, None) # skip quantizing torch.nn.functional.add
  495. .set_object_type(torch.nn.functional.linear, qconfig_from_prepare)
  496. .set_module_name("foo.bar", None) # skip quantizing module "foo.bar"
  497. * `backend_config` (BackendConfig): A configuration for the backend which describes how
  498. operators should be quantized in the backend, this includes quantization
  499. mode support (static/dynamic/weight_only), dtype support (quint8/qint8 etc.),
  500. observer placement for each operators and fused operators.
  501. See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details
  502. Return:
  503. A quantized model (torch.nn.Module)
  504. Example::
  505. # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
  506. # convert_fx converts a calibrated/trained model to a quantized model for the
  507. # target hardware, this includes converting the model first to a reference
  508. # quantized model, and then lower the reference quantized model to a backend
  509. # Currently, the supported backends are fbgemm (onednn), qnnpack (xnnpack) and
  510. # they share the same set of quantized operators, so we are using the same
  511. # lowering procedure
  512. #
  513. # backend_config defines the corresponding reference quantized module for
  514. # the weighted modules in the model, e.g. nn.Linear
  515. # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
  516. # e.g. backend_config = get_default_backend_config("fbgemm")
  517. quantized_model = convert_fx(prepared_model)
  518. """
  519. torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx")
  520. return _convert_fx(
  521. graph_module,
  522. is_reference=False,
  523. convert_custom_config=convert_custom_config,
  524. _remove_qconfig=_remove_qconfig,
  525. qconfig_mapping=qconfig_mapping,
  526. backend_config=backend_config,
  527. keep_original_weights=keep_original_weights,
  528. )
  529. def convert_to_reference_fx(
  530. graph_module: GraphModule,
  531. convert_custom_config: ConvertCustomConfig | dict[str, Any] | None = None,
  532. _remove_qconfig: bool = True,
  533. qconfig_mapping: QConfigMapping | dict[str, Any] | None = None,
  534. backend_config: BackendConfig | dict[str, Any] | None = None,
  535. ) -> GraphModule:
  536. r"""Convert a calibrated or trained model to a reference quantized model,
  537. see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details,
  538. reference quantized model is a standard representation of a quantized model provided
  539. by FX Graph Mode Quantization, it can be further lowered to run on the target
  540. hardware, like accelerators
  541. Args:
  542. * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule)
  543. * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function.
  544. See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
  545. * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert.
  546. * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.
  547. See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
  548. * `backend_config` (BackendConfig): A configuration for the backend which describes how
  549. operators should be quantized in the backend. See
  550. :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
  551. Return:
  552. A reference quantized model (GraphModule)
  553. Example::
  554. # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
  555. # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
  556. # e.g. backend_config = get_default_backend_config("fbgemm")
  557. reference_quantized_model = convert_to_reference_fx(prepared_model)
  558. """
  559. torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_to_reference_fx")
  560. return _convert_fx(
  561. graph_module,
  562. is_reference=True,
  563. convert_custom_config=convert_custom_config,
  564. _remove_qconfig=_remove_qconfig,
  565. qconfig_mapping=qconfig_mapping,
  566. backend_config=backend_config,
  567. )
  568. def _convert_to_reference_decomposed_fx(
  569. graph_module: GraphModule,
  570. convert_custom_config: ConvertCustomConfig | dict[str, Any] | None = None,
  571. qconfig_mapping: QConfigMapping | dict[str, Any] | None = None,
  572. backend_config: BackendConfig | dict[str, Any] | None = None,
  573. ) -> GraphModule:
  574. r"""Convert a calibrated or trained model to a reference quantized model, with
  575. decomposed representation for quantized Tensor
  576. see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details,
  577. reference quantized model is a standard representation of a quantized model provided
  578. by FX Graph Mode Quantization, it can be further lowered to run on the target
  579. hardware, like accelerators
  580. Note: this is not public API
  581. Args:
  582. * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule)
  583. * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function.
  584. See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
  585. * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert.
  586. * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.
  587. See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
  588. * `backend_config` (BackendConfig): A configuration for the backend which describes how
  589. operators should be quantized in the backend. See
  590. :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
  591. Return:
  592. A reference quantized model (GraphModule) with operators working with decomposed quantized Tensor
  593. Example::
  594. # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
  595. # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
  596. # e.g. backend_config = get_default_backend_config("fbgemm")
  597. reference_quantized_model = _convert_to_reference_decomposed_fx(prepared_model)
  598. """
  599. torch._C._log_api_usage_once(
  600. "quantization_api.quantize_fx._convert_to_reference_decomposed_fx"
  601. )
  602. return _convert_fx(
  603. graph_module,
  604. is_reference=True,
  605. convert_custom_config=convert_custom_config,
  606. _remove_qconfig=False,
  607. qconfig_mapping=qconfig_mapping,
  608. backend_config=backend_config,
  609. is_decomposed=True,
  610. )
  611. def _convert_standalone_module_fx(
  612. graph_module: GraphModule,
  613. is_reference: bool = False,
  614. convert_custom_config: ConvertCustomConfig | dict[str, Any] | None = None,
  615. ) -> GraphModule:
  616. r"""[Internal use only] Convert a model produced by :func:`~torch.ao.quantization.prepare_standalone_module_fx`
  617. and convert it to a quantized model
  618. Returns a quantized standalone module, whether input/output is quantized is
  619. specified by prepare_custom_config, with
  620. input_quantized_idxs, output_quantized_idxs, please
  621. see docs for prepare_fx for details
  622. """
  623. return _convert_fx(
  624. graph_module,
  625. is_reference,
  626. convert_custom_config,
  627. is_standalone_module=True,
  628. )