fake_quantize.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663
  1. # mypy: allow-untyped-decorators
  2. # mypy: allow-untyped-defs
  3. """Implements modules used to perform fake quantization."""
  4. import re
  5. from abc import ABC, abstractmethod
  6. from typing import Any
  7. import torch
  8. from torch.ao.quantization.observer import (
  9. _with_args,
  10. default_fixed_qparams_range_0to1_observer,
  11. default_fixed_qparams_range_neg1to1_observer,
  12. FixedQParamsObserver,
  13. HistogramObserver,
  14. MovingAverageMinMaxObserver,
  15. MovingAveragePerChannelMinMaxObserver,
  16. )
  17. from torch.nn import Module
  18. __all__ = [
  19. "FakeQuantizeBase",
  20. "FakeQuantize",
  21. "FixedQParamsFakeQuantize",
  22. "FusedMovingAvgObsFakeQuantize",
  23. "disable_fake_quant",
  24. "disable_observer",
  25. "enable_fake_quant",
  26. "enable_observer",
  27. "default_fake_quant",
  28. "default_weight_fake_quant",
  29. "default_dynamic_fake_quant",
  30. "default_fixed_qparams_range_neg1to1_fake_quant",
  31. "default_fixed_qparams_range_0to1_fake_quant",
  32. "default_symmetric_fixed_qparams_fake_quant",
  33. "default_affine_fixed_qparams_fake_quant",
  34. "default_per_channel_weight_fake_quant",
  35. "default_embedding_fake_quant",
  36. "default_embedding_fake_quant_4bit",
  37. "default_histogram_fake_quant",
  38. "default_fused_act_fake_quant",
  39. "default_fused_wt_fake_quant",
  40. "default_fused_per_channel_wt_fake_quant",
  41. "fused_wt_fake_quant_range_neg_127_to_127",
  42. "fused_per_channel_wt_fake_quant_range_neg_127_to_127",
  43. ]
  44. def _is_per_channel(qscheme: "torch.qscheme") -> bool:
  45. return qscheme in [
  46. torch.per_channel_symmetric,
  47. torch.per_channel_affine,
  48. torch.per_channel_affine_float_qparams,
  49. ]
  50. def _is_per_tensor(qscheme: "torch.qscheme") -> bool:
  51. return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]
  52. def _is_symmetric_quant(qscheme: "torch.qscheme") -> bool:
  53. return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric]
  54. def _is_float_qparams(qscheme: "torch.qscheme") -> bool:
  55. return qscheme == torch.per_channel_affine_float_qparams
  56. class FakeQuantizeBase(ABC, Module):
  57. r"""Base fake quantize module.
  58. Base fake quantize module
  59. Any fake quantize implementation should derive from this class.
  60. Concrete fake quantize module should follow the same API. In forward, they will update
  61. the statistics of the observed Tensor and fake quantize the input. They should also provide a
  62. `calculate_qparams` function that computes the quantization parameters given
  63. the collected statistics.
  64. """
  65. fake_quant_enabled: torch.Tensor
  66. observer_enabled: torch.Tensor
  67. def __init__(self) -> None:
  68. """Set fake_quant_enabled and observer_enabled."""
  69. super().__init__()
  70. # fake_quant_enabled and observer_enabled are buffers to support their
  71. # replication in DDP. Data type is uint8 because NCCL does not support
  72. # bool tensors.
  73. self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.uint8))
  74. self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.uint8))
  75. @abstractmethod
  76. def forward(self, x):
  77. pass
  78. @abstractmethod
  79. def calculate_qparams(self, **kwargs):
  80. pass
  81. @torch.jit.export
  82. def enable_fake_quant(self, enabled: bool = True) -> None:
  83. self.fake_quant_enabled[0] = 1 if enabled else 0
  84. @torch.jit.export
  85. def disable_fake_quant(self):
  86. self.enable_fake_quant(False)
  87. @torch.jit.export
  88. def enable_observer(self, enabled: bool = True) -> None:
  89. self.observer_enabled[0] = 1 if enabled else 0
  90. @torch.jit.export
  91. def disable_observer(self):
  92. self.enable_observer(False)
  93. @classmethod
  94. def with_args(cls, **kwargs):
  95. fake_quant_constructor = _with_args(cls, **kwargs)
  96. # need to assign the correct module to fake_quantize
  97. # constructors to satisfy public v private requirements
  98. fake_quant_constructor.__module__ = "torch.ao.quantization.fake_quantize"
  99. return fake_quant_constructor
class FakeQuantize(FakeQuantizeBase):
    r"""Simulate the quantize and dequantize operations in training time.

    The output of this module is given by::

        x_out = (
            clamp(round(x / scale + zero_point), quant_min, quant_max) - zero_point
        ) * scale

    * :attr:`is_dynamic` indicates whether the fake quantize is a placeholder for dynamic quantization
      operators (choose_qparams -> q -> dq) or static quantization operators (q -> dq)
    * :attr:`scale` defines the scale factor used for quantization.
    * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps to
    * :attr:`fake_quant_enabled` controls the application of fake quantization on tensors, note that
      statistics can still be updated.
    * :attr:`observer_enabled` controls statistics collection on tensors
    * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization,
      e.g. torch.qint8 or torch.quint8 (the module-level defaults also use torch.quint4x2).

    Args:
        observer (module): Module for observing statistics on input tensors and calculating scale
            and zero-point.
        quant_min (int, optional): Lower bound of the quantized range. Forwarded to the observer
            only when ``quant_max`` is also given, after being checked against
            ``torch.iinfo`` of the dtype.
        quant_max (int, optional): Upper bound of the quantized range; same rules as ``quant_min``.
        is_dynamic (bool): Forwarded to the observer as its ``is_dynamic`` argument.
        observer_kwargs (optional): Arguments for the observer module

    Attributes:
        activation_post_process (Module): User provided module that collects statistics on the input tensor and
            provides a method to calculate scale and zero-point.

    Raises:
        AssertionError: if ``quant_min > quant_max``, if either bound lies outside the
            range of the dtype, or if the observer's qscheme is neither per-tensor
            nor per-channel.
    """

    # Registered as buffers in __init__ and refreshed from the observer in forward.
    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(
        self,
        observer=MovingAverageMinMaxObserver,
        quant_min=None,
        quant_max=None,
        is_dynamic=False,
        **observer_kwargs,
    ):
        """Build the observer and the ``scale`` / ``zero_point`` buffers.

        See the class docstring for the arguments and the errors raised.
        """
        super().__init__()
        # Populate quant_min/quant_max to observer_kwargs if valid. This only
        # happens when BOTH bounds are given; otherwise the observer keeps its
        # own range.
        if quant_min is not None and quant_max is not None:
            if quant_min > quant_max:
                raise AssertionError(
                    "quant_min must be less than or equal to quant_max"
                )
            # The dtype used for the bound check starts from observer_kwargs,
            # defaulting to torch.quint8.
            dtype = observer_kwargs.get("dtype", torch.quint8)
            if hasattr(observer, "p"):
                # In case observer is _PartialWrapper, dtype can be stored in
                # observer.p.keywords["dtype"]
                # NOTE(review): the partial's dtype wins for this check, while
                # the observer is later called with **observer_kwargs (which
                # presumably override the partial's keywords). If both carry
                # different dtypes, the check may validate against the wrong
                # one. Confirm before relying on it.
                dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get(
                    "dtype", dtype
                )
            if torch.iinfo(dtype).min > quant_min:
                raise AssertionError("quant_min out of bound")
            if quant_max > torch.iinfo(dtype).max:
                raise AssertionError("quant_max out of bound")
            observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max})
        observer_kwargs["is_dynamic"] = is_dynamic
        self.activation_post_process = observer(**observer_kwargs)
        # TODO: keeping self.quant_min/max for BC; remove after a couple releases
        # Users should use self.activation_post_process.quant_min
        self.quant_min = self.activation_post_process.quant_min
        self.quant_max = self.activation_post_process.quant_max
        self.is_dynamic = self.activation_post_process.is_dynamic
        # The float-qparams scheme keeps a floating-point zero point; every other
        # scheme uses an int zero point.
        if _is_float_qparams(self.activation_post_process.qscheme):
            zero_point_dtype = torch.float
        else:
            zero_point_dtype = torch.int
        # Single-element placeholders. forward() and _load_from_state_dict()
        # resize them when the computed or loaded qparams have a different
        # shape, e.g. one entry per channel.
        self.register_buffer("scale", torch.tensor([1.0], dtype=torch.float))
        self.register_buffer("zero_point", torch.tensor([0], dtype=zero_point_dtype))
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        # Observers that do not define ch_axis get -1.
        self.ch_axis = (
            self.activation_post_process.ch_axis
            if hasattr(self.activation_post_process, "ch_axis")
            else -1
        )
        if not (_is_per_channel(self.qscheme) or _is_per_tensor(self.qscheme)):
            raise AssertionError(
                "Only per channel and per tensor quantization are supported in fake quantize"
                + " got qscheme: "
                + str(self.qscheme)
            )
        self.is_per_channel = _is_per_channel(self.qscheme)

    @torch.jit.export
    def calculate_qparams(self):  # type: ignore[override]
        """Return ``(scale, zero_point)`` as computed by the observer."""
        return self.activation_post_process.calculate_qparams()

    def forward(self, X):
        """Optionally observe ``X``, then optionally fake-quantize it.

        When the observer is enabled:

        1. A detached copy of ``X`` is fed to the observer.
        2. The qparams are recomputed.
        3. They are copied into the ``scale`` / ``zero_point`` buffers.

        When fake quantization is enabled, ``X`` is fake-quantized with the
        stored buffers, per channel along ``ch_axis`` or per tensor. Otherwise
        ``X`` is returned unchanged.
        """
        if self.observer_enabled[0] == 1:
            self.activation_post_process(X.detach())
            _scale, _zero_point = self.calculate_qparams()
            # Move the fresh qparams onto the device where the buffers live.
            _scale, _zero_point = (
                _scale.to(self.scale.device),
                _zero_point.to(self.zero_point.device),
            )
            # Resize in place (keeping the registered buffers) when the number
            # of qparams changed, before copying the values over.
            if self.scale.shape != _scale.shape:
                self.scale.resize_(_scale.shape)
                self.zero_point.resize_(_zero_point.shape)
            self.scale.copy_(_scale)
            self.zero_point.copy_(_zero_point)
        if self.fake_quant_enabled[0] == 1:
            if self.is_per_channel:
                X = torch.fake_quantize_per_channel_affine(
                    X,
                    self.scale,
                    self.zero_point,
                    self.ch_axis,
                    self.activation_post_process.quant_min,
                    self.activation_post_process.quant_max,
                )
            else:
                X = torch.fake_quantize_per_tensor_affine(
                    X,
                    self.scale,
                    self.zero_point,
                    self.activation_post_process.quant_min,
                    self.activation_post_process.quant_max,
                )
        return X

    @torch.jit.export
    def extra_repr(self):
        """Describe the switches, range, dtype/qscheme and current qparams."""
        return (
            f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, "
            f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, "
            f"dtype={self.dtype}, qscheme={self.qscheme}, ch_axis={self.ch_axis}, "
            f"scale={self.scale}, zero_point={self.zero_point}"
        )

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Save module state, writing ``scale`` / ``zero_point`` explicitly."""
        # We cannot currently register scalar values as buffers, so need to manually
        # specify serialization here.
        # NOTE(review): scale/zero_point are registered buffers (see __init__),
        # so the parent call already saves them. These two lines overwrite
        # those entries with the buffer tensors themselves, regardless of
        # keep_vars.
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + "scale"] = self.scale
        destination[prefix + "zero_point"] = self.zero_point

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Load state, first resizing ``scale`` / ``zero_point`` to the saved shapes.

        If a key is absent and ``strict`` is set, the key is appended to
        ``missing_keys``.
        """
        # Removing this function throws an error that the size of the loaded tensor does not match the original size,
        # i.e. the buffers are created with a single element (see __init__) but
        # a saved per-channel scale/zero_point can have N elements.
        local_state = ["scale", "zero_point"]
        for name in local_state:
            key = prefix + name
            if key in state_dict:
                val = state_dict[key]
                # Custom handling to allow loading scale and zero_point of size N
                # into buffers of a different size. The buffers are resized here,
                # and the values are copied in the default state_dict loading
                # code of the parent.
                if name == "scale":
                    self.scale.resize_(val.shape)
                else:
                    if name != "zero_point":
                        raise AssertionError(
                            "Expected 'zero_point' but got different state key"
                        )
                    self.zero_point.resize_(val.shape)
                # For torchscript module we need to update the attributes here since we do not
                # call the `_load_from_state_dict` function defined module.py
                if torch.jit.is_scripting():
                    if name == "scale":
                        self.scale.copy_(val)
                    else:
                        if name != "zero_point":
                            raise AssertionError(
                                "Expected 'zero_point' but got different state key"
                            )
                        self.zero_point.copy_(val)
            elif strict:
                missing_keys.append(key)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
  279. class FixedQParamsFakeQuantize(FakeQuantize):
  280. """Simulate quantize and dequantize in training time.
  281. Simulate quantize and dequantize with fixed quantization
  282. parameters in training time. Only per tensor quantization
  283. is supported.
  284. """
  285. # TODO: rename observer to observer_ctr
  286. def __init__(self, observer):
  287. super().__init__(observer=observer)
  288. if type(self.activation_post_process) is not FixedQParamsObserver:
  289. raise AssertionError(
  290. f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}"
  291. )
  292. self._observer_ctr = observer
  293. self.scale = self.activation_post_process.scale
  294. self.zero_point = self.activation_post_process.zero_point
  295. if not _is_per_tensor(self.qscheme):
  296. raise AssertionError(
  297. "Only per tensor quantization is supported"
  298. + " FixedQParamsFakeQuantize module, got qscheme:"
  299. + str(self.qscheme)
  300. )
  301. @torch.jit.export
  302. def calculate_qparams(self): # type: ignore[override]
  303. return self.scale, self.zero_point
  304. @torch.jit.export
  305. def extra_repr(self):
  306. """Define a string representation of the object's attributes."""
  307. return (
  308. f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, "
  309. f"scale={self.scale}, zero_point={self.zero_point}, "
  310. f"dtype={self.dtype}, quant_min={self.activation_post_process.quant_min}, "
  311. f"quant_max={self.activation_post_process.quant_max}, qscheme={self.qscheme}"
  312. )
  313. class FusedMovingAvgObsFakeQuantize(FakeQuantize):
  314. r"""Define a fused module to observe the tensor.
  315. Fused module that is used to observe the input tensor (compute min/max), compute
  316. scale/zero_point and fake_quantize the tensor.
  317. This module uses calculation similar MovingAverageMinMaxObserver for the inputs,
  318. to compute the min/max values in order to compute the scale/zero_point.
  319. The qscheme input in the observer is used to differentiate between symmetric/affine
  320. quantization scheme.
  321. The output of this module is given by
  322. x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale
  323. Similar to :class:`~torch.ao.quantization.FakeQuantize`, and accepts the same attributes as the
  324. base class.
  325. """
  326. def __init__(
  327. self,
  328. observer: Any = MovingAverageMinMaxObserver,
  329. quant_min: int = 0,
  330. quant_max: int = 255,
  331. **observer_kwargs: Any,
  332. ) -> None:
  333. super().__init__(observer, quant_min, quant_max, **observer_kwargs)
  334. if not isinstance(
  335. self.activation_post_process,
  336. (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver),
  337. ):
  338. raise AssertionError(
  339. "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver"
  340. )
  341. self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long))
  342. self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long))
  343. self.is_symmetric_quant = _is_symmetric_quant(
  344. self.activation_post_process.qscheme
  345. )
  346. @torch.jit.export
  347. def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: # type: ignore[override]
  348. return self.activation_post_process.calculate_qparams()
  349. @torch.jit.export
  350. def extra_repr(self) -> str:
  351. return (
  352. f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, "
  353. f"scale={self.scale}, zero_point={self.zero_point}, dtype={self.dtype}, "
  354. f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, "
  355. f"qscheme={self.qscheme}, reduce_range={self.activation_post_process.reduce_range}"
  356. )
  357. def forward(self, X: torch.Tensor) -> torch.Tensor:
  358. return torch.fused_moving_avg_obs_fake_quant(
  359. X,
  360. self.observer_enabled,
  361. self.fake_quant_enabled,
  362. self.activation_post_process.min_val,
  363. self.activation_post_process.max_val,
  364. self.scale,
  365. self.zero_point,
  366. self.activation_post_process.averaging_constant,
  367. self.activation_post_process.quant_min,
  368. self.activation_post_process.quant_max,
  369. self.ch_axis,
  370. self.is_per_channel,
  371. self.is_symmetric_quant,
  372. )
  373. default_fake_quant = FakeQuantize.with_args(
  374. observer=MovingAverageMinMaxObserver,
  375. quant_min=0,
  376. quant_max=255,
  377. dtype=torch.quint8,
  378. qscheme=torch.per_tensor_affine,
  379. reduce_range=True,
  380. )
  381. """
  382. Default fake_quant for activations.
  383. """
  384. default_weight_fake_quant = FakeQuantize.with_args(
  385. observer=MovingAverageMinMaxObserver,
  386. quant_min=-128,
  387. quant_max=127,
  388. dtype=torch.qint8,
  389. qscheme=torch.per_tensor_symmetric,
  390. reduce_range=False,
  391. )
  392. """
  393. Default fake_quant for weights.
  394. Observer is memoryless since averaging_constant is 1.
  395. """
  396. default_dynamic_fake_quant = FakeQuantize.with_args(
  397. observer=MovingAverageMinMaxObserver,
  398. quant_min=0,
  399. quant_max=255,
  400. is_dynamic=True,
  401. dtype=torch.quint8,
  402. averaging_constant=1,
  403. )
  404. """
  405. Default dynamic fake_quant for activations.
  406. """
  407. default_fixed_qparams_range_neg1to1_fake_quant = FixedQParamsFakeQuantize.with_args(
  408. observer=default_fixed_qparams_range_neg1to1_observer
  409. )
  410. default_fixed_qparams_range_0to1_fake_quant = FixedQParamsFakeQuantize.with_args(
  411. observer=default_fixed_qparams_range_0to1_observer
  412. )
  413. # TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases
  414. default_symmetric_fixed_qparams_fake_quant = (
  415. default_fixed_qparams_range_neg1to1_fake_quant
  416. )
  417. default_affine_fixed_qparams_fake_quant = default_fixed_qparams_range_0to1_fake_quant
  418. default_per_channel_weight_fake_quant = FakeQuantize.with_args(
  419. observer=MovingAveragePerChannelMinMaxObserver,
  420. quant_min=-128,
  421. quant_max=127,
  422. dtype=torch.qint8,
  423. qscheme=torch.per_channel_symmetric,
  424. reduce_range=False,
  425. ch_axis=0,
  426. )
  427. """
  428. Default fake_quant for per-channel weights.
  429. Observer is memoryless since averaging_constant is 1.
  430. """
  431. default_embedding_fake_quant = FakeQuantize.with_args(
  432. observer=MovingAveragePerChannelMinMaxObserver,
  433. qscheme=torch.per_channel_affine_float_qparams,
  434. dtype=torch.quint8,
  435. quant_min=0,
  436. quant_max=255,
  437. ch_axis=0,
  438. averaging_constant=1,
  439. )
  440. """
  441. Default fake_quant for embeddings.
  442. Observer is memoryless since averaging_constant is 1.
  443. """
  444. default_embedding_fake_quant_4bit = FakeQuantize.with_args(
  445. observer=MovingAveragePerChannelMinMaxObserver,
  446. qscheme=torch.per_channel_affine_float_qparams,
  447. ch_axis=0,
  448. dtype=torch.quint4x2,
  449. averaging_constant=1,
  450. )
  451. default_histogram_fake_quant = FakeQuantize.with_args(
  452. observer=HistogramObserver,
  453. quant_min=0,
  454. quant_max=255,
  455. dtype=torch.quint8,
  456. qscheme=torch.per_tensor_affine,
  457. reduce_range=True,
  458. )
  459. """
  460. Fake_quant for activations using a histogram..
  461. """
  462. default_fused_act_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(
  463. observer=MovingAverageMinMaxObserver,
  464. quant_min=0,
  465. quant_max=255,
  466. dtype=torch.quint8,
  467. )
  468. """
  469. Fused version of `default_fake_quant`, with improved performance.
  470. """
  471. default_fused_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(
  472. observer=MovingAverageMinMaxObserver,
  473. quant_min=-128,
  474. quant_max=127,
  475. dtype=torch.qint8,
  476. qscheme=torch.per_tensor_symmetric,
  477. )
  478. """
  479. Fused version of `default_weight_fake_quant`, with improved performance.
  480. """
  481. default_fused_per_channel_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(
  482. observer=MovingAveragePerChannelMinMaxObserver,
  483. quant_min=-128,
  484. quant_max=127,
  485. dtype=torch.qint8,
  486. qscheme=torch.per_channel_symmetric,
  487. )
  488. """
  489. Fused version of `default_per_channel_weight_fake_quant`, with improved performance.
  490. """
  491. fused_wt_fake_quant_range_neg_127_to_127 = FusedMovingAvgObsFakeQuantize.with_args(
  492. observer=MovingAverageMinMaxObserver,
  493. quant_min=-127,
  494. quant_max=127,
  495. dtype=torch.qint8,
  496. qscheme=torch.per_tensor_symmetric,
  497. eps=2**-12,
  498. )
  499. """
  500. Fused version of `default_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128.
  501. """
  502. fused_per_channel_wt_fake_quant_range_neg_127_to_127 = (
  503. FusedMovingAvgObsFakeQuantize.with_args(
  504. observer=MovingAveragePerChannelMinMaxObserver,
  505. quant_min=-127,
  506. quant_max=127,
  507. dtype=torch.qint8,
  508. qscheme=torch.per_channel_symmetric,
  509. eps=2**-12,
  510. )
  511. )
  512. """
  513. Fused version of `default_per_channel_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128.
  514. """
  515. def _is_fake_quant_script_module(mod):
  516. """Return true if given mod is an instance of FakeQuantize script module."""
  517. if isinstance(mod, torch.jit.RecursiveScriptModule):
  518. # qualified name looks like '__torch__.torch.ao.quantization.fake_quantize.___torch_mangle_2.FakeQuantize'
  519. suffix = mod._c.qualified_name.split(".", 1)[1]
  520. name = re.sub(r"\.___torch_mangle_\d+", "", suffix)
  521. return (
  522. name == "torch.ao.quantization.fake_quantize.FakeQuantize"
  523. or name
  524. == "torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize"
  525. )
  526. return False
  527. def disable_fake_quant(mod):
  528. """Disable fake quantization for the module.
  529. Disable fake quantization for this module, if applicable. Example usage::
  530. # model is any PyTorch model
  531. model.apply(torch.ao.quantization.disable_fake_quant)
  532. """
  533. if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
  534. mod.disable_fake_quant()
  535. def enable_fake_quant(mod):
  536. """Enable fake quantization for the module.
  537. Enable fake quantization for this module, if applicable. Example usage::
  538. # model is any PyTorch model
  539. model.apply(torch.ao.quantization.enable_fake_quant)
  540. """
  541. if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
  542. mod.enable_fake_quant()
  543. def disable_observer(mod):
  544. """Disable observation for this module.
  545. Disable observation for this module, if applicable. Example usage::
  546. # model is any PyTorch model
  547. model.apply(torch.ao.quantization.disable_observer)
  548. """
  549. if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
  550. mod.disable_observer()
  551. def enable_observer(mod):
  552. """Enable observation for this module.
  553. Enable observation for this module, if applicable. Example usage::
  554. # model is any PyTorch model
  555. model.apply(torch.ao.quantization.enable_observer)
  556. """
  557. if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
  558. mod.enable_observer()