observer.py 79 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159
  1. # mypy: allow-untyped-decorators
  2. # mypy: allow-untyped-defs
# Temporarily skip RUF (Ruff-specific) checks for this file; we can re-enable
# them after moving the affine-quantization-related code to torchao.
  5. # noqa: RUF
  6. """
  7. This module implements observers which are used to collect statistics about
  8. the values observed during calibration (PTQ) or training (QAT).
  9. """
  10. import operator
  11. import re
  12. import warnings
  13. from abc import ABCMeta, abstractmethod
  14. from collections import OrderedDict
  15. from functools import partial
  16. from typing import Any
  17. import torch
  18. import torch.nn as nn
  19. from torch.ao.quantization.utils import (
  20. calculate_qmin_qmax,
  21. check_min_max_valid,
  22. is_per_channel,
  23. is_per_tensor,
  24. validate_qmin_qmax,
  25. )
  26. from torch.fx import Node
# Names exported by a star-import of this module. Kept as a plain literal list
# (order preserved) so linters, type checkers and doc tools can read it statically.
__all__ = [
    "default_affine_fixed_qparams_observer",
    "default_debug_observer",
    "default_dynamic_quant_observer",
    "default_fixed_qparams_range_0to1_observer",
    "default_fixed_qparams_range_neg1to1_observer",
    "default_float_qparams_observer",
    "default_float_qparams_observer_4bit",
    "default_histogram_observer",
    "default_observer",
    "default_per_channel_weight_observer",
    "default_placeholder_observer",
    "default_reuse_input_observer",
    "default_symmetric_fixed_qparams_observer",
    "default_weight_observer",
    "get_observer_state_dict",
    "load_observer_state_dict",
    "per_channel_weight_observer_range_neg_127_to_127",
    "weight_observer_range_neg_127_to_127",
    "FixedQParamsObserver",
    "HistogramObserver",
    "MinMaxObserver",
    "MovingAverageMinMaxObserver",
    "MovingAveragePerChannelMinMaxObserver",
    "NoopObserver",
    "ObserverBase",
    "PerChannelMinMaxObserver",
    "PlaceholderObserver",
    "RecordingObserver",
    "ReuseInputObserver",
    "UniformQuantizationObserverBase",
    "AffineQuantizedObserverBase",
    "Granularity",
    "MappingType",
    "PerAxis",
    "PerBlock",
    "PerGroup",
    "PerRow",
    "PerTensor",
    "PerToken",
    "TorchAODType",
    "ZeroPointDomain",
    "get_block_size",
]
  71. class _PartialWrapper:
  72. def __init__(self, p):
  73. self.p = p
  74. self.callable_args = {}
  75. def __call__(self, *args, **keywords):
  76. # call each arg in callable_args and add them partial, then run with keywords
  77. # skip if arg_name in keywords so its possible to overwrite
  78. for arg_name in self.callable_args:
  79. if arg_name not in keywords:
  80. keywords = {**keywords, arg_name: self.callable_args[arg_name]()}
  81. return self.p(*args, **keywords)
  82. def __repr__(self):
  83. return self.p.__repr__() + self.callable_args.__repr__()
  84. def with_args(self, **kwargs):
  85. return _with_args(self, **kwargs)
  86. def with_callable_args(self, **kwargs):
  87. result = _PartialWrapper(p=self.p)
  88. result.callable_args = {**self.callable_args, **kwargs}
  89. return result
  90. def _with_args(cls_or_self, **kwargs):
  91. r"""Wrapper that allows creation of class factories.
  92. This can be useful when there is a need to create classes with the same
  93. constructor arguments, but different instances. Can be used in conjunction with
  94. _callable_args
  95. Example::
  96. >>> # xdoctest: +SKIP("Undefined vars")
  97. >>> Foo.with_args = classmethod(_with_args)
  98. >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42)
  99. >>> foo_instance1 = foo_builder()
  100. >>> foo_instance2 = foo_builder()
  101. >>> id(foo_instance1) == id(foo_instance2)
  102. False
  103. """
  104. r = _PartialWrapper(partial(cls_or_self, **kwargs))
  105. return r
  106. def _with_callable_args(cls_or_self, **kwargs):
  107. r"""Wrapper that allows creation of class factories args that need to be
  108. called at construction time.
  109. This can be useful when there is a need to create classes with the same
  110. constructor arguments, but different instances and those arguments should only
  111. be calculated at construction time. Can be used in conjunction with _with_args
  112. Example::
  113. >>> # xdoctest: +SKIP("Undefined vars")
  114. >>> Foo.with_callable_args = classmethod(_with_callable_args)
  115. >>> Foo.with_args = classmethod(_with_args)
  116. >>> foo_builder = Foo.with_callable_args(cur_time=get_time_func).with_args(name="dan")
  117. >>> foo_instance1 = foo_builder()
  118. >>> # wait 50
  119. >>> foo_instance2 = foo_builder()
  120. >>> id(foo_instance1.creation_time) == id(foo_instance2.creation_time)
  121. False
  122. """
  123. r = _PartialWrapper(partial(cls_or_self))
  124. return r.with_callable_args(**kwargs)
# ``ABC`` is a bare base class whose metaclass is ``ABCMeta``; deriving from it
# (as ``ObserverBase`` does) makes its ``@abstractmethod`` declarations
# enforceable. It is built explicitly with ``ABCMeta(...)`` rather than imported
# as ``abc.ABC`` (historically for Python 2 *and* 3 compatibility), so it is a
# distinct class object from ``abc.ABC``.
ABC: Any = ABCMeta("ABC", (object,), {})
  126. class ObserverBase(ABC, nn.Module):
  127. r"""Base observer Module.
  128. Any observer implementation should derive from this class.
  129. Concrete observers should follow the same API. In forward, they will update
  130. the statistics of the observed Tensor. And they should provide a
  131. `calculate_qparams` function that computes the quantization parameters given
  132. the collected statistics.
  133. Args:
  134. dtype: dtype argument to the `quantize` node needed to implement the
  135. reference model spec.
  136. is_dynamic: indicator for whether the observer is a placeholder for dynamic quantization
  137. or static quantization
  138. """
  139. def __init__(self, dtype, is_dynamic: bool = False):
  140. super().__init__()
  141. self.dtype = dtype
  142. self.is_dynamic = is_dynamic
  143. @abstractmethod
  144. def forward(self, x):
  145. pass
  146. @abstractmethod
  147. def calculate_qparams(self, **kwargs):
  148. pass
  149. with_args = classmethod(_with_args)
  150. with_callable_args = classmethod(_with_callable_args)
  151. class UniformQuantizationObserverBase(ObserverBase):
  152. r"""Common base for all observers using uniform quantization to calculate
  153. scale and zero_point.
  154. Args:
  155. dtype: dtype argument to the `quantize` node needed to implement the
  156. reference model spec.
  157. qscheme: Quantization scheme to be used.
  158. reduce_range: Reduces the range of the quantized data type by 1 bit.
  159. This is sometimes required to avoid instruction overflow.
  160. quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
  161. quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
  162. eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`.
  163. .. warning::
  164. :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``.
  165. or `torch.int8` or `torch.uint8`
  166. .. warning::
  167. :attr:`qscheme` can only take one of the following options:
  168. - ``torch.per_tensor_affine``
  169. - ``torch.per_tensor_symmetric``
  170. - ``torch.per_channel_affine``
  171. - ``torch.per_channel_symmetric``
  172. """
  173. # Note: the version is shared by all observer types
  174. #
  175. # Version 1/None
  176. # self
  177. #
  178. # Version 2 (base class only, does not include child class buffers)
  179. # self
  180. # |--- eps : Tensor
  181. #
  182. # Version 3
  183. # for HistogramObserver only, changed the shape of uninitialized
  184. # min_val and max_val buffers from torch.Size([0]) to torch.Size([])
  185. # for PerChannelObservers, changed the name of the buffers from min_vals
  186. # to min_val and from max_vals to max_val.
  187. _version = 3
  188. eps: torch.Tensor
  189. def __init__(
  190. self,
  191. dtype=torch.quint8,
  192. qscheme=torch.per_tensor_affine,
  193. reduce_range=False,
  194. quant_min=None,
  195. quant_max=None,
  196. factory_kwargs=None,
  197. eps=torch.finfo(torch.float32).eps,
  198. is_dynamic=False,
  199. **kwargs,
  200. ) -> None:
  201. factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
  202. super().__init__(dtype=dtype, is_dynamic=is_dynamic, **kwargs)
  203. self.qscheme = qscheme
  204. if reduce_range:
  205. warnings.warn(
  206. "Please use quant_min and quant_max to specify the range for observers. \
  207. reduce_range will be deprecated in a future release of PyTorch.",
  208. stacklevel=2,
  209. )
  210. self.reduce_range = reduce_range
  211. self.register_buffer("eps", torch.tensor([eps], **factory_kwargs))
  212. if self.qscheme not in (
  213. torch.per_tensor_affine,
  214. torch.per_tensor_symmetric,
  215. torch.per_channel_affine,
  216. torch.per_channel_symmetric,
  217. torch.per_channel_affine_float_qparams,
  218. ):
  219. raise AssertionError(
  220. "Default Observer only works for per_tensor_affine, per_tensor_symmetric, "
  221. "per_channel_affine, per_channel_symmetric and per_channel_float_qparams quantization scheme"
  222. )
  223. _ALLOWED_DTYPES = (
  224. torch.qint8,
  225. torch.quint8,
  226. torch.quint4x2,
  227. torch.qint32,
  228. torch.int8,
  229. torch.uint8,
  230. torch.int16,
  231. torch.int32,
  232. torch.float8_e5m2,
  233. torch.float8_e4m3fn,
  234. torch.uint16,
  235. )
  236. if self.dtype not in _ALLOWED_DTYPES:
  237. raise AssertionError(
  238. f"Default Observer only works for {_ALLOWED_DTYPES} data type"
  239. )
  240. self.has_customized_qrange = (quant_min is not None) and (quant_max is not None)
  241. if self.has_customized_qrange:
  242. # pyrefly: ignore [bad-argument-type]
  243. validate_qmin_qmax(quant_min, quant_max)
  244. self.quant_min, self.quant_max = calculate_qmin_qmax(
  245. # pyrefly: ignore [bad-argument-type]
  246. quant_min,
  247. # pyrefly: ignore [bad-argument-type]
  248. quant_max,
  249. self.has_customized_qrange,
  250. self.dtype,
  251. self.reduce_range,
  252. )
  253. def _load_from_state_dict(
  254. self,
  255. state_dict,
  256. prefix,
  257. local_metadata,
  258. strict,
  259. missing_keys,
  260. unexpected_keys,
  261. error_msgs,
  262. ):
  263. version = local_metadata.get("version", None)
  264. if version is None or version == 1:
  265. # eps was moved to a buffer in version 2
  266. eps = torch.tensor([torch.finfo(torch.float32).eps])
  267. state_dict[prefix + "eps"] = eps
  268. super()._load_from_state_dict(
  269. state_dict,
  270. prefix,
  271. local_metadata,
  272. strict,
  273. missing_keys,
  274. unexpected_keys,
  275. error_msgs,
  276. )
  277. @torch.jit.export
  278. def _validate_qmin_qmax(self, quant_min: int, quant_max: int) -> None:
  279. r"""Validates that the user-specified quantization range is properly initialized
  280. and within the given bound supported by the observer dtype.
  281. To accommodate lower-bit quantization with respect to the existing torch.qint8 and
  282. torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing
  283. in a tuple of initial qmin and qmax values. One use case is these customized qmin and qmax
  284. values are used to calculate static estimates of the scale and zero point for aggressive lower-bit
  285. fake quantization. These estimates are compared against parameters learned through backpropagation.
  286. The related literatures for scale and zero point via backpropagation are as follows:
  287. Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS
  288. Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf
  289. """
  290. # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted
  291. # based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer.
  292. if not quant_min <= 0 <= quant_max:
  293. raise AssertionError("Used-specified quantization range must include 0.")
  294. if quant_min >= quant_max:
  295. raise AssertionError(
  296. "qmin must be strictly less than qmax for user-specified quantization range."
  297. )
  298. @torch.jit.export
  299. def _calculate_qparams(
  300. self, min_val: torch.Tensor, max_val: torch.Tensor
  301. ) -> tuple[torch.Tensor, torch.Tensor]:
  302. r"""Calculates the quantization parameters, given min and max
  303. value tensors. Works for both per tensor and per channel cases
  304. Args:
  305. min_val: Minimum values per channel
  306. max_val: Maximum values per channel
  307. Returns:
  308. scales: Scales tensor of shape (#channels,)
  309. zero_points: Zero points tensor of shape (#channels,)
  310. """
  311. # Functionally equivalent to 'determine_qparams' in utils.py. Observers must be torchscriptable however and qscheme
  312. # as far as I can tell is not allowed to passed as a parameter in torchscript functions. This makes refactoring observer
  313. # to use this utility a massive pain and very gross. For now Im opting just to duplicate as this code
  314. # seems unlikely to change (last update over 1 year ago) and when torchscript is fully deprecated we can refactor.
  315. # TODO(jakeszwe, jerryzh168)
  316. if not check_min_max_valid(min_val, max_val):
  317. return torch.tensor([1.0], device=min_val.device.type), torch.tensor(
  318. [0], device=min_val.device.type
  319. )
  320. quant_min, quant_max = self.quant_min, self.quant_max
  321. min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
  322. max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
  323. device = min_val_neg.device
  324. scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device)
  325. zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
  326. if (
  327. self.qscheme == torch.per_tensor_symmetric
  328. or self.qscheme == torch.per_channel_symmetric
  329. ):
  330. max_val_pos = torch.max(-min_val_neg, max_val_pos)
  331. scale = max_val_pos / (float(quant_max - quant_min) / 2)
  332. scale = torch.max(scale, self.eps)
  333. if self.dtype in [torch.quint8, torch.uint8]:
  334. if self.has_customized_qrange:
  335. # When customized quantization range is used, down-rounded midpoint of the range is chosen.
  336. zero_point = zero_point.new_full(
  337. zero_point.size(), (quant_min + quant_max) // 2
  338. )
  339. else:
  340. zero_point = zero_point.new_full(zero_point.size(), 128)
  341. elif self.dtype == torch.uint16:
  342. zero_point = zero_point.new_full(zero_point.size(), 2**15)
  343. elif self.qscheme == torch.per_channel_affine_float_qparams:
  344. scale = (max_val - min_val) / float(quant_max - quant_min)
  345. scale = torch.where(scale > self.eps, scale, torch.ones_like(scale))
  346. # We use the quantize function
  347. # xq = Round(Xf * inv_scale + zero_point),
  348. # setting zero_point to (-1 * min *inv_scale) we get
  349. # Xq = Round((Xf - min) * inv_scale)
  350. zero_point = -1 * min_val / scale
  351. else:
  352. scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
  353. scale = torch.max(scale, self.eps)
  354. zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
  355. zero_point = torch.clamp(zero_point, quant_min, quant_max)
  356. # For scalar values, cast them to Tensors of size 1 to keep the shape
  357. # consistent with default values in FakeQuantize.
  358. if len(scale.shape) == 0:
  359. # TODO: switch to scale.item() after adding JIT support
  360. scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
  361. if len(zero_point.shape) == 0:
  362. # TODO: switch to zero_point.item() after adding JIT support
  363. zero_point = torch.tensor(
  364. [int(zero_point)], dtype=zero_point.dtype, device=device
  365. )
  366. if self.qscheme == torch.per_channel_affine_float_qparams:
  367. zero_point = torch.tensor(
  368. [float(zero_point)], dtype=zero_point.dtype, device=device
  369. )
  370. return scale, zero_point
  371. @torch.jit.export
  372. def reset_min_max_vals(self):
  373. raise NotImplementedError("Cannot reset min/max values in the given observer.")
# Originally, this class was called `_ObserverBase`. Keeping the old name around
# as an alias (the very same class object) so existing references to
# `_ObserverBase` keep working.
# TODO(after v1.13): delete this
_ObserverBase = UniformQuantizationObserverBase
  378. class MinMaxObserver(UniformQuantizationObserverBase):
  379. r"""Observer module for computing the quantization parameters based on the
  380. running min and max values.
  381. This observer uses the tensor min/max statistics to compute the quantization
  382. parameters. The module records the running minimum and maximum of incoming
  383. tensors, and uses this statistic to compute the quantization parameters.
  384. Args:
  385. dtype: dtype argument to the `quantize` node needed to implement the
  386. reference model spec.
  387. qscheme: Quantization scheme to be used
  388. reduce_range: Reduces the range of the quantized data type by 1 bit
  389. quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
  390. quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
  391. eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`.
  392. Given running min/max as :math:`x_\text{min}` and :math:`x_\text{max}`,
  393. scale :math:`s` and zero point :math:`z` are computed as:
  394. The running minimum/maximum :math:`x_\text{min/max}` is computed as:
  395. .. math::
  396. \begin{array}{ll}
  397. x_\text{min} &= \begin{cases}
  398. \min(X) & \text{if~}x_\text{min} = \text{None} \\
  399. \min\left(x_\text{min}, \min(X)\right) & \text{otherwise}
  400. \end{cases}\\
  401. x_\text{max} &= \begin{cases}
  402. \max(X) & \text{if~}x_\text{max} = \text{None} \\
  403. \max\left(x_\text{max}, \max(X)\right) & \text{otherwise}
  404. \end{cases}\\
  405. \end{array}
  406. where :math:`X` is the observed tensor.
  407. The scale :math:`s` and zero point :math:`z` are then computed as:
  408. .. math::
  409. \begin{aligned}
  410. \text{if Symmetric:}&\\
  411. &s = 2 \max(|x_\text{min}|, x_\text{max}) /
  412. \left( Q_\text{max} - Q_\text{min} \right) \\
  413. &z = \begin{cases}
  414. 0 & \text{if dtype is qint8} \\
  415. 128 & \text{otherwise}
  416. \end{cases}\\
  417. \text{Otherwise:}&\\
  418. &s = \left( x_\text{max} - x_\text{min} \right ) /
  419. \left( Q_\text{max} - Q_\text{min} \right ) \\
  420. &z = Q_\text{min} - \text{round}(x_\text{min} / s)
  421. \end{aligned}
  422. where :math:`Q_\text{min}` and :math:`Q_\text{max}` are the minimum and
  423. maximum of the quantized data type.
  424. .. warning:: :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``.
  425. .. note:: If the running minimum equals to the running maximum, the scale
  426. and zero_point are set to 1.0 and 0.
  427. """
  428. min_val: torch.Tensor
  429. max_val: torch.Tensor
  430. def __init__(
  431. self,
  432. dtype=torch.quint8,
  433. qscheme=torch.per_tensor_affine,
  434. reduce_range=False,
  435. quant_min=None,
  436. quant_max=None,
  437. factory_kwargs=None,
  438. eps=torch.finfo(torch.float32).eps,
  439. is_dynamic=False,
  440. **kwargs,
  441. ) -> None:
  442. if not is_per_tensor(qscheme):
  443. raise NotImplementedError(
  444. "MinMaxObserver's qscheme only support torch.per_tensor_symmetric \
  445. and torch.per_tensor_affine."
  446. )
  447. # TODO: MinMaxObserver by itself doesn't support dynamic quantization, but
  448. # if it's inherited by MovingAverageObserver, and averaging_constant is 1, it
  449. # supports dynamic quantization, we may need to better error checking here
  450. # For x86 quantized kernels, we need to ensure that the vpmaddubsw
  451. # instruction does not overflow. We allow for a reduce_range argument to
  452. # observers that reduces the quantized range to (0,127) or (-64, 63).
  453. # For more details see aten/src/ATen/native/quantized/cpu/qconv.cpp
  454. # This is not an optimal choice for non x86 backends as it loses a bit
  455. # of precision for activations.
  456. super().__init__(
  457. dtype=dtype,
  458. qscheme=qscheme,
  459. reduce_range=reduce_range,
  460. quant_min=quant_min,
  461. quant_max=quant_max,
  462. factory_kwargs=factory_kwargs,
  463. eps=eps,
  464. is_dynamic=is_dynamic,
  465. **kwargs,
  466. )
  467. factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
  468. self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
  469. self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
  470. if (
  471. self.qscheme == torch.per_tensor_symmetric
  472. and self.reduce_range
  473. and self.dtype == torch.quint8
  474. ):
  475. raise NotImplementedError(
  476. "Cannot reduce range for symmetric \
  477. quantization for quint8"
  478. )
  479. def forward(self, x_orig):
  480. r"""Records the running minimum and maximum of ``x``."""
  481. if x_orig.numel() == 0:
  482. return x_orig
  483. x = x_orig.detach() # avoid keeping autograd tape
  484. x = x.to(self.min_val.dtype)
  485. min_val_cur, max_val_cur = torch.aminmax(x)
  486. min_val = torch.min(min_val_cur, self.min_val)
  487. max_val = torch.max(max_val_cur, self.max_val)
  488. self.min_val.copy_(min_val)
  489. self.max_val.copy_(max_val)
  490. return x_orig
  491. @torch.jit.export
  492. def calculate_qparams(self): # type: ignore[override]
  493. r"""Calculates the quantization parameters."""
  494. return self._calculate_qparams(self.min_val, self.max_val)
  495. @torch.jit.export
  496. def extra_repr(self):
  497. return f"min_val={self.min_val}, max_val={self.max_val}"
  498. @torch.jit.export
  499. def reset_min_max_vals(self):
  500. """Resets the min/max values."""
  501. self.min_val.copy_(torch.tensor(float("inf")))
  502. self.max_val.copy_(torch.tensor(float("-inf")))
  503. class MovingAverageMinMaxObserver(MinMaxObserver):
  504. r"""Observer module for computing the quantization parameters based on the
  505. moving average of the min and max values.
  506. This observer computes the quantization parameters based on the moving
  507. averages of minimums and maximums of the incoming tensors. The module
  508. records the average minimum and maximum of incoming tensors, and uses this
  509. statistic to compute the quantization parameters.
  510. Args:
  511. averaging_constant: Averaging constant for min/max.
  512. dtype: dtype argument to the `quantize` node needed to implement the
  513. reference model spec.
  514. qscheme: Quantization scheme to be used
  515. reduce_range: Reduces the range of the quantized data type by 1 bit
  516. quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
  517. quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
  518. eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`.
  519. The moving average min/max is computed as follows
  520. .. math::
  521. \begin{array}{ll}
  522. x_\text{min} = \begin{cases}
  523. \min(X) & \text{if~}x_\text{min} = \text{None} \\
  524. (1 - c) x_\text{min} + c \min(X) & \text{otherwise}
  525. \end{cases}\\
  526. x_\text{max} = \begin{cases}
  527. \max(X) & \text{if~}x_\text{max} = \text{None} \\
  528. (1 - c) x_\text{max} + c \max(X) & \text{otherwise}
  529. \end{cases}\\
  530. \end{array}
  531. where :math:`x_\text{min/max}` is the running average min/max, :math:`X` is
  532. is the incoming tensor, and :math:`c` is the ``averaging_constant``.
  533. The scale and zero point are then computed as in
  534. :class:`~torch.ao.quantization.observer.MinMaxObserver`.
  535. .. note:: Only works with ``torch.per_tensor_affine`` quantization scheme.
  536. .. note:: If the running minimum equals to the running maximum, the scale
  537. and zero_point are set to 1.0 and 0.
  538. """
  539. def __init__(
  540. self,
  541. averaging_constant=0.01,
  542. dtype=torch.quint8,
  543. qscheme=torch.per_tensor_affine,
  544. reduce_range=False,
  545. quant_min=None,
  546. quant_max=None,
  547. eps=torch.finfo(torch.float32).eps,
  548. is_dynamic=False,
  549. **kwargs,
  550. ) -> None:
  551. if not is_per_tensor(qscheme):
  552. raise NotImplementedError(
  553. f"MovingAverageMinMaxObserver's qscheme only support \
  554. torch.per_tensor_symmetric and torch.per_tensor_affine. \
  555. but got: {qscheme}"
  556. )
  557. self.averaging_constant = averaging_constant
  558. if is_dynamic and self.averaging_constant != 1:
  559. raise NotImplementedError(
  560. "MovingAverageMinMaxObserver doesn't support dynamic quantization for "
  561. f"averaging constant of {self.averaging_constant}"
  562. )
  563. super().__init__(
  564. dtype=dtype,
  565. qscheme=qscheme,
  566. reduce_range=reduce_range,
  567. quant_min=quant_min,
  568. quant_max=quant_max,
  569. eps=eps,
  570. is_dynamic=is_dynamic,
  571. **kwargs,
  572. )
  573. def forward(self, x_orig):
  574. if x_orig.numel() == 0:
  575. return x_orig
  576. x = x_orig.detach() # avoid keeping autograd tape
  577. x = x.to(self.min_val.dtype)
  578. min_val = self.min_val
  579. max_val = self.max_val
  580. if min_val == float("inf") and max_val == float("-inf"):
  581. min_val, max_val = torch.aminmax(x)
  582. else:
  583. min_val_cur, max_val_cur = torch.aminmax(x)
  584. min_val = min_val + self.averaging_constant * (min_val_cur - min_val)
  585. max_val = max_val + self.averaging_constant * (max_val_cur - max_val)
  586. self.min_val.copy_(min_val)
  587. self.max_val.copy_(max_val)
  588. return x_orig
  589. class PerChannelMinMaxObserver(UniformQuantizationObserverBase):
  590. r"""Observer module for computing the quantization parameters based on the
  591. running per channel min and max values.
  592. This observer uses the tensor min/max statistics to compute the per channel
  593. quantization parameters. The module records the running minimum and maximum
  594. of incoming tensors, and uses this statistic to compute the quantization
  595. parameters.
  596. Args:
  597. ch_axis: Channel axis
  598. dtype: dtype argument to the `quantize` node needed to implement the
  599. reference model spec.
  600. qscheme: Quantization scheme to be used
  601. reduce_range: Reduces the range of the quantized data type by 1 bit
  602. quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
  603. quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
  604. eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`.
  605. The quantization parameters are computed the same way as in
  606. :class:`~torch.ao.quantization.observer.MinMaxObserver`, with the difference
  607. that the running min/max values are stored per channel.
  608. Scales and zero points are thus computed per channel as well.
  609. .. note:: If the running minimum equals to the running maximum, the scales
  610. and zero_points are set to 1.0 and 0.
  611. """
  612. min_val: torch.Tensor
  613. max_val: torch.Tensor
  614. def __init__(
  615. self,
  616. ch_axis=0,
  617. dtype=torch.quint8,
  618. qscheme=torch.per_channel_affine,
  619. reduce_range=False,
  620. quant_min=None,
  621. quant_max=None,
  622. factory_kwargs=None,
  623. eps=torch.finfo(torch.float32).eps,
  624. is_dynamic=False,
  625. **kwargs,
  626. ) -> None:
  627. if not is_per_channel(qscheme):
  628. raise NotImplementedError(
  629. "PerChannelMinMaxObserver's qscheme only support \
  630. torch.per_channel_symmetric, torch.per_channel_affine and torch.per_channel_affine_float_qparams."
  631. )
  632. if is_dynamic:
  633. raise NotImplementedError(
  634. "PerChannelMinMaxObserver doesn't support dynamic quantization"
  635. )
  636. super().__init__(
  637. dtype=dtype,
  638. qscheme=qscheme,
  639. reduce_range=reduce_range,
  640. quant_min=quant_min,
  641. quant_max=quant_max,
  642. factory_kwargs=factory_kwargs,
  643. eps=eps,
  644. is_dynamic=is_dynamic,
  645. **kwargs,
  646. )
  647. factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
  648. self.ch_axis = ch_axis
  649. self.register_buffer("min_val", torch.tensor([], **factory_kwargs))
  650. self.register_buffer("max_val", torch.tensor([], **factory_kwargs))
  651. if (
  652. self.qscheme == torch.per_channel_symmetric
  653. and self.reduce_range
  654. and self.dtype == torch.quint8
  655. ):
  656. raise NotImplementedError(
  657. "Cannot reduce range for symmetric quantization for quint8"
  658. )
  659. def forward(self, x_orig):
  660. return self._forward(x_orig)
  661. def _forward(self, x_orig):
  662. if x_orig.numel() == 0:
  663. return x_orig
  664. x = x_orig.detach() # avoid keeping autograd tape
  665. min_val = self.min_val
  666. max_val = self.max_val
  667. x_dim = x.size()
  668. new_axis_list = [i for i in range(len(x_dim))] # noqa: C416
  669. new_axis_list[self.ch_axis] = 0
  670. new_axis_list[0] = self.ch_axis
  671. y = x.permute(new_axis_list)
  672. # Need to match dtype of min/max because the updates to buffers
  673. # are done in place and types need to match for comparisons
  674. y = y.to(self.min_val.dtype)
  675. y = torch.flatten(y, start_dim=1)
  676. if min_val.numel() == 0 or max_val.numel() == 0:
  677. min_val, max_val = torch.aminmax(y, dim=1)
  678. else:
  679. min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
  680. min_val = torch.min(min_val_cur, min_val)
  681. max_val = torch.max(max_val_cur, max_val)
  682. self.min_val.resize_(min_val.shape)
  683. self.max_val.resize_(max_val.shape)
  684. self.min_val.copy_(min_val)
  685. self.max_val.copy_(max_val)
  686. return x_orig
  687. @torch.jit.export
  688. def calculate_qparams(self): # type: ignore[override]
  689. return self._calculate_qparams(self.min_val, self.max_val)
  690. def extra_repr(self):
  691. return f"min_val={self.min_val}, max_val={self.max_val}"
  692. def _load_from_state_dict(
  693. self,
  694. state_dict: dict[str, Any],
  695. prefix: str,
  696. local_metadata: dict[str, torch.Tensor],
  697. strict: bool,
  698. missing_keys: list[str],
  699. unexpected_keys: list[str],
  700. error_msgs: list[str],
  701. ):
  702. version = local_metadata.get("version")
  703. if version is not None and version < 3:
  704. local_state = ["min_vals", "max_vals"]
  705. expected_min_name = "min_vals"
  706. expected_max_name = "max_vals"
  707. else:
  708. local_state = ["min_val", "max_val"]
  709. expected_min_name = "min_val"
  710. expected_max_name = "max_val"
  711. for name in local_state:
  712. key = prefix + name
  713. if key in state_dict:
  714. val = state_dict[key]
  715. # Custom handling to allow loading min_val or max_val
  716. # of size N into uninitialized buffers of size 0. The
  717. # buffers are resized here, and the values are copied in
  718. # the default state_dict loading code of the parent.
  719. if name == expected_min_name:
  720. self.min_val.resize_(val.shape)
  721. elif name == expected_max_name:
  722. self.max_val.resize_(val.shape)
  723. else:
  724. warnings.warn(
  725. f"Observer load_from_state_dict got unexpected name {name}",
  726. stacklevel=2,
  727. )
  728. # For torchscript module we need to update the attributes here since we do not
  729. # call the `_load_from_state_dict` function defined module.py
  730. if torch.jit.is_scripting():
  731. if name == expected_min_name:
  732. self.min_val.copy_(val)
  733. elif name == expected_max_name:
  734. self.max_val.copy_(val)
  735. else:
  736. warnings.warn(
  737. f"Observer load_from_state_dict got unexpected name {name}",
  738. stacklevel=2,
  739. )
  740. elif strict:
  741. missing_keys.append(key)
  742. if not torch.jit.is_scripting():
  743. super()._load_from_state_dict(
  744. state_dict,
  745. prefix,
  746. local_metadata,
  747. False,
  748. missing_keys,
  749. unexpected_keys,
  750. error_msgs,
  751. )
  752. def _load_from_state_dict_script(
  753. self,
  754. state_dict: dict[str, Any],
  755. prefix: str,
  756. local_metadata: dict[str, torch.Tensor],
  757. strict: bool,
  758. missing_keys: list[str],
  759. unexpected_keys: list[str],
  760. error_msgs: list[str],
  761. ):
  762. self._load_from_state_dict(
  763. state_dict,
  764. prefix,
  765. local_metadata,
  766. strict,
  767. missing_keys,
  768. unexpected_keys,
  769. error_msgs,
  770. )
  771. @torch.jit.export
  772. def reset_min_max_vals(self):
  773. """Resets the min/max values."""
  774. # This used to be torch.ones but that does not work because
  775. # JIT compiler can optimize it via common subexpression elimination
  776. # in which case both min_val and max_val point to the same tensor.
  777. self.min_val = torch.rand(
  778. 0,
  779. )
  780. self.max_val = torch.rand(
  781. 0,
  782. )
  783. class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver):
  784. r"""Observer module for computing the quantization parameters based on the
  785. running per channel min and max values.
  786. This observer uses the tensor min/max statistics to compute the per channel
  787. quantization parameters. The module records the running minimum and maximum
  788. of incoming tensors, and uses this statistic to compute the quantization
  789. parameters.
  790. Args:
  791. averaging_constant: Averaging constant for min/max.
  792. ch_axis: Channel axis
  793. dtype: Quantized data type
  794. qscheme: Quantization scheme to be used
  795. reduce_range: Reduces the range of the quantized data type by 1 bit
  796. quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
  797. quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
  798. eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`.
  799. The quantization parameters are computed the same way as in
  800. :class:`~torch.ao.quantization.observer.MovingAverageMinMaxObserver`, with the
  801. difference that the running min/max values are stored per channel.
  802. Scales and zero points are thus computed per channel as well.
  803. .. note:: If the running minimum equals to the running maximum, the scales
  804. and zero_points are set to 1.0 and 0.
  805. """
  806. def __init__(
  807. self,
  808. averaging_constant=0.01,
  809. ch_axis=0,
  810. dtype=torch.quint8,
  811. qscheme=torch.per_channel_affine,
  812. reduce_range=False,
  813. quant_min=None,
  814. quant_max=None,
  815. eps=torch.finfo(torch.float32).eps,
  816. is_dynamic=False,
  817. **kwargs,
  818. ) -> None:
  819. if not is_per_channel(qscheme):
  820. raise NotImplementedError(
  821. "MovingAveragePerChannelMinMaxObserver's qscheme only support \
  822. torch.per_channel_symmetric, torch.per_channel_affine and torch.per_channel_affine_float_qparams."
  823. )
  824. if is_dynamic:
  825. raise NotImplementedError(
  826. "MovingAveragePerChannelMinMaxObserver doesn't support dynamic quantization"
  827. )
  828. super().__init__(
  829. ch_axis=ch_axis,
  830. dtype=dtype,
  831. qscheme=qscheme,
  832. reduce_range=reduce_range,
  833. quant_min=quant_min,
  834. quant_max=quant_max,
  835. eps=eps,
  836. is_dynamic=is_dynamic,
  837. **kwargs,
  838. )
  839. self.averaging_constant = averaging_constant
  840. def forward(self, x_orig):
  841. if x_orig.numel() == 0:
  842. return x_orig
  843. x = x_orig.detach() # avoid keeping autograd tape
  844. x = x.to(self.min_val.dtype)
  845. min_val = self.min_val
  846. max_val = self.max_val
  847. x_dim = x.size()
  848. new_axis_list = [i for i in range(len(x_dim))] # noqa: C416
  849. new_axis_list[self.ch_axis] = 0
  850. new_axis_list[0] = self.ch_axis
  851. y = x.permute(new_axis_list)
  852. y = torch.flatten(y, start_dim=1)
  853. if min_val.numel() == 0 or max_val.numel() == 0:
  854. min_val, max_val = torch.aminmax(y, dim=1)
  855. else:
  856. min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
  857. min_val = min_val + self.averaging_constant * (min_val_cur - min_val)
  858. max_val = max_val + self.averaging_constant * (max_val_cur - max_val)
  859. self.min_val.resize_(min_val.shape)
  860. self.max_val.resize_(max_val.shape)
  861. self.min_val.copy_(min_val)
  862. self.max_val.copy_(max_val)
  863. return x_orig
class HistogramObserver(UniformQuantizationObserverBase):
    r"""
    The module records the running histogram of tensor values along with
    min/max values. ``calculate_qparams`` will calculate scale and zero_point.

    Args:
        bins: Number of bins to use for the histogram
        dtype: dtype argument to the `quantize` node needed to implement the
               reference model spec
        qscheme: Quantization scheme to be used (per-tensor schemes only)
        reduce_range: Reduces the range of the quantized data type by 1 bit
        quant_min: Minimum quantization value, forwarded to the base class.
        quant_max: Maximum quantization value, forwarded to the base class.
        factory_kwargs: Passed through ``torch.nn.factory_kwargs`` and used when
            creating the ``histogram``, ``min_val`` and ``max_val`` buffers.
        eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`.
        is_dynamic: Must be ``False``; dynamic quantization is not supported.

    The scale and zero point are computed as follows:

    1. Create the histogram of the incoming inputs.
        The histogram is computed continuously, and the ranges per bin change
        with every new tensor observed.
    2. Search the distribution in the histogram for optimal min/max values.
        The search is a heuristic approximation of L2 quantization-error
        minimization with respect to the floating point model; it trims
        outliers from the tails of the distribution.
    3. Compute the scale and zero point the same way as in the
        :class:`~torch.ao.quantization.MinMaxObserver`

    Raises:
        NotImplementedError: if ``qscheme`` is not per-tensor or ``is_dynamic``
            is ``True``.
    """

    histogram: torch.Tensor
    min_val: torch.Tensor
    max_val: torch.Tensor

    def __init__(
        self,
        bins: int = 2048,
        dtype: torch.dtype = torch.quint8,
        qscheme=torch.per_tensor_affine,
        reduce_range=False,
        quant_min=None,
        quant_max=None,
        factory_kwargs=None,
        eps=torch.finfo(torch.float32).eps,
        is_dynamic=False,
        **kwargs,
    ) -> None:
        if not is_per_tensor(qscheme):
            raise NotImplementedError(
                "HistogramObserver's qscheme only support torch.per_tensor_symmetric \
and torch.per_tensor_affine."
            )
        if is_dynamic:
            raise NotImplementedError(
                "HistogramObserver doesn't support dynamic quantization"
            )
        # bins: The number of bins used for histogram calculation.
        super().__init__(
            dtype=dtype,
            qscheme=qscheme,
            reduce_range=reduce_range,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs,
            eps=eps,
            is_dynamic=is_dynamic,
            **kwargs,
        )
        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
        self.bins = bins
        self.register_buffer("histogram", torch.zeros(self.bins, **factory_kwargs))
        # +inf / -inf mark the observer as "not yet run" (checked in forward
        # and calculate_qparams).
        self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
        self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
        # Number of destination (quantized) bins: 2 ** bit-width of ``dtype``.
        self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits
        self.upsample_rate = (
            16  # used to reduce quantization errors when upscaling histogram
        )

    def _get_norm(
        self, delta_begin: torch.Tensor, delta_end: torch.Tensor, density: torch.Tensor
    ) -> torch.Tensor:
        r"""
        Compute the norm of the values uniformly distributed between
        delta_begin and delta_end.
        Currently only L2 norm is supported.

        norm = density * (integral_{begin, end} x^2)
             = density * (end^3 - begin^3) / 3
        """
        norm = (
            delta_end * delta_end * delta_end - delta_begin * delta_begin * delta_begin
        ) / 3
        return density * norm

    def _compute_quantization_error(self, next_start_bin: int, next_end_bin: int):
        r"""
        Compute the quantization error if we use start_bin to end_bin as the
        min and max to do the quantization.

        Returns a Python float (the summed L2 error over all source bins), or
        ``0.0`` when the destination bin width is zero.
        """
        # Width of one source (histogram) bin over [min_val, max_val].
        bin_width = (self.max_val.item() - self.min_val.item()) / self.bins

        # Width of one destination bin when source bins
        # [next_start_bin, next_end_bin] are spread over ``dst_nbins`` bins.
        dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
        if dst_bin_width == 0.0:
            return 0.0

        src_bin = torch.arange(self.bins, device=self.histogram.device)
        # distances from the beginning of first dst_bin to the beginning and
        # end of src_bin
        src_bin_begin = (src_bin - next_start_bin) * bin_width
        src_bin_end = src_bin_begin + bin_width

        # which dst_bins the beginning and end of src_bin belong to?
        dst_bin_of_begin = torch.clamp(
            torch.div(src_bin_begin, dst_bin_width, rounding_mode="floor"),
            0,
            self.dst_nbins - 1,
        )
        dst_bin_of_begin_center = (dst_bin_of_begin + 0.5) * dst_bin_width

        dst_bin_of_end = torch.clamp(
            torch.div(src_bin_end, dst_bin_width, rounding_mode="floor"),
            0,
            self.dst_nbins - 1,
        )
        # Each source bin's mass is treated as uniformly spread over its width.
        density = self.histogram / bin_width

        norm = torch.zeros(self.bins, device=self.histogram.device)

        # Term 1: from the start of the source bin to the end of the first
        # destination bin it overlaps (error measured from that bin's center).
        delta_begin = src_bin_begin - dst_bin_of_begin_center
        delta_end = dst_bin_width / 2
        norm += self._get_norm(
            delta_begin,
            torch.ones(self.bins, device=self.histogram.device) * delta_end,
            density,
        )

        # Term 2: destination bins fully covered in between. When begin and end
        # fall in the same destination bin the multiplier is -1, which cancels
        # the extra full-bin contributions of terms 1 and 3.
        norm += (dst_bin_of_end - dst_bin_of_begin - 1) * self._get_norm(
            torch.tensor(-dst_bin_width / 2), torch.tensor(dst_bin_width / 2), density
        )

        # Term 3: from the start of the last overlapped destination bin to the
        # end of the source bin.
        dst_bin_of_end_center = dst_bin_of_end * dst_bin_width + dst_bin_width / 2

        delta_begin = -dst_bin_width / 2
        delta_end = src_bin_end - dst_bin_of_end_center
        norm += self._get_norm(torch.tensor(delta_begin), delta_end, density)

        return norm.sum().item()

    def _non_linear_param_search(self) -> tuple[torch.Tensor, torch.Tensor]:
        r"""Non-linear parameter search.

        An approximation for L2 error minimization for selecting min/max.
        By selecting new min/max, we filter out outliers in input distribution.
        This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in
        caffe2/quantization/server/norm_minimization.cc

        Returns:
            ``(new_min, new_max)`` as tensors in the original value domain.

        Raises:
            AssertionError: if the histogram length differs from ``self.bins``.
        """
        if self.histogram.size()[0] != self.bins:
            raise AssertionError("bins mismatch")
        bin_width = (self.max_val - self.min_val) / self.bins

        # cumulative sum
        total = torch.sum(self.histogram).item()
        cSum = torch.cumsum(self.histogram, dim=0)

        stepsize = 1e-5  # granularity
        alpha = 0.0  # lower bound
        beta = 1.0  # upper bound
        start_bin = 0
        end_bin = self.bins - 1
        norm_min = float("inf")

        # Greedily shrink [start_bin, end_bin] one quantile step at a time,
        # stopping as soon as the quantization error starts increasing.
        while alpha < beta:
            # Find the next step
            next_alpha = alpha + stepsize
            next_beta = beta - stepsize

            # find the left and right bins between the quantile bounds
            l = start_bin
            r = end_bin
            # pyrefly: ignore [bad-assignment]
            while l < end_bin and cSum[l] < next_alpha * total:
                l = l + 1
            while r > start_bin and cSum[r] > next_beta * total:
                r = r - 1

            # decide the next move
            # (trim whichever side the quantile step removes more bins from)
            next_start_bin = start_bin
            next_end_bin = end_bin
            if (l - start_bin) > (end_bin - r):
                # move the start bin
                next_start_bin = l
                alpha = next_alpha
            else:
                # move the end bin
                next_end_bin = r
                beta = next_beta

            # The quantile step did not change the bin range; keep stepping.
            if next_start_bin == start_bin and next_end_bin == end_bin:
                continue

            # calculate the quantization error using next_start_bin and next_end_bin
            norm = self._compute_quantization_error(next_start_bin, next_end_bin)

            if norm > norm_min:
                break
            norm_min = norm
            start_bin = next_start_bin
            end_bin = next_end_bin

        # Convert the chosen bin indices back to values.
        new_min = self.min_val + bin_width * start_bin
        new_max = self.min_val + bin_width * (end_bin + 1)
        return new_min, new_max

    def _upscale_histogram(
        self,
        histogram: torch.Tensor,
        orig_min: torch.Tensor,
        orig_max: torch.Tensor,
        update_min: torch.Tensor,
        update_max: torch.Tensor,
    ):
        """Re-bin ``histogram`` (spanning [orig_min, orig_max]) onto ``self.bins``
        bins spanning [update_min, update_max]."""
        # this turns the histogram into a more fine-coarsed histogram to reduce
        # bin quantization errors
        histogram = histogram.repeat_interleave(self.upsample_rate) / self.upsample_rate
        bin_size = (orig_max - orig_min) / (self.bins * self.upsample_rate)
        mid_points_histogram = (
            torch.linspace(
                orig_min,
                orig_max,
                self.bins * self.upsample_rate + 1,
                device=orig_min.device,
            )[:-1].to(histogram.device)
            + 0.5 * bin_size
        )
        boundaries_new_histogram = torch.linspace(
            update_min, update_max, self.bins + 1, device=update_min.device
        ).to(histogram.device)
        # this maps the mid-points of the histogram to the new histogram's space
        bucket_assignments = (
            torch.bucketize(mid_points_histogram, boundaries_new_histogram, right=True)
            - 1
        )
        # this then maps the histogram mid-points in the new space, weighted by the original histogram's values
        # this is just the old histogram in the new histogram's space

        # In case due to numerical issues the values land higher/lower than the maximum/minimum
        bucket_assignments[bucket_assignments >= self.bins] = self.bins - 1
        bucket_assignments[bucket_assignments < 0] = 0

        update_histogram = torch.bincount(
            bucket_assignments, weights=histogram, minlength=self.bins
        )
        return update_histogram

    def _combine_histograms(
        self,
        orig_hist: torch.Tensor,
        orig_min: torch.Tensor,
        orig_max: torch.Tensor,
        update_hist: torch.Tensor,
        update_min: torch.Tensor,
        update_max: torch.Tensor,
    ) -> torch.Tensor:
        """Merge ``orig_hist`` into ``update_hist``'s range [update_min, update_max].

        Raises:
            AssertionError: if [orig_min, orig_max] is not contained in
                [update_min, update_max] (checked only on the re-binning path).
        """
        # If the new min and max are the same as the current min and max,
        # we can just add the new histogram to the original histogram
        if update_min == orig_min and update_max == orig_max:
            return orig_hist + update_hist

        # If the orig hist only has one value (i.e., the min and max are the same)
        # we can just add it into new histogram
        if orig_min == orig_max:
            bin_value = torch.sum(orig_hist)
            transformed_orig_hist = (
                torch.histc(orig_min, bins=self.bins, min=update_min, max=update_max)  # type: ignore[arg-type]
                * bin_value
            )
            return transformed_orig_hist + update_hist

        # We assume the update_hist is already in the target range, we will map the orig_max to it
        if update_min > orig_min:
            raise AssertionError("update_min must be <= orig_min")
        if update_max < orig_max:
            raise AssertionError("update_max must be >= orig_max")

        # Now we need to turn the old_histogram, into the range of the new histogram
        transformed_orig_hist = self._upscale_histogram(
            orig_hist,
            orig_min,
            orig_max,
            update_min,
            update_max,
        )

        return update_hist + transformed_orig_hist

    def reset_histogram(
        self, x: torch.Tensor, min_val: torch.Tensor, max_val: torch.Tensor
    ) -> None:
        """Replace min/max and rebuild the histogram from ``x`` alone.

        Raises:
            AssertionError: if ``min_val`` or ``max_val`` is not a single element.
        """
        # NOTE(review): the scalar check below runs after the buffers have been
        # resized/overwritten, so a failing call still mutates min_val/max_val.
        self.min_val.resize_(min_val.shape)
        self.min_val.copy_(min_val)
        self.max_val.resize_(max_val.shape)
        self.max_val.copy_(max_val)
        if min_val.numel() != 1 or max_val.numel() != 1:
            raise AssertionError("histogram min/max values must be scalar.")
        new_histogram = torch.histc(x, self.bins, min=min_val, max=max_val)  # type: ignore[arg-type]
        self.histogram.detach_().resize_(new_histogram.shape)
        self.histogram.copy_(new_histogram)

    def forward(self, x_orig: torch.Tensor) -> torch.Tensor:  # pyre-ignore[14]
        """Accumulate ``x_orig`` into the running histogram; return it unchanged.

        Infinite values are dropped (with a warning) before binning.
        """
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.detach()
        x_min, x_max = torch.aminmax(x)
        # want to ignore torch.inf since we don't actually
        # want to make our quantization range infinite
        # and in practice those values will be clamped
        if x_min == -torch.inf or x_max == torch.inf:
            warnings.warn(
                "torch.inf detected in input tensor, ignoring input", stacklevel=2
            )
            x = x[x.abs() != torch.inf]
            if x.numel() == 0:
                return x_orig
            x_min, x_max = torch.aminmax(x)
        current_min = self.min_val
        current_max = self.max_val

        is_uninitialized = self.min_val == float("inf") or self.max_val == float("-inf")
        if is_uninitialized:
            self.reset_histogram(x, x_min, x_max)
        else:
            update_min, update_max = x_min, x_max
            new_min = torch.min(current_min, update_min)
            new_max = torch.max(current_max, update_max)

            # TODO: For some reason, this is required for it to pass torchscript test
            # new_min and new_max should already have requires_grad set to False
            new_min, new_max = new_min.detach(), new_max.detach()
            # Bin the new data directly over the widened range [new_min, new_max].
            update_histogram = torch.histc(
                x,
                self.bins,
                min=new_min,  # type: ignore[arg-type]
                max=new_max,  # type: ignore[arg-type]
            ).to(self.histogram.device)
            if new_min == current_min and new_max == current_max:
                # Range unchanged: the bins line up, so just add counts.
                combined_histogram = self.histogram + update_histogram
                self.histogram.detach_().resize_(combined_histogram.shape)
                self.histogram.copy_(combined_histogram)
            else:
                # Range grew: re-bin the old histogram onto the new range first.
                combined_histogram = self._combine_histograms(
                    self.histogram,
                    current_min,
                    current_max,
                    update_histogram,
                    new_min,
                    new_max,
                )
                self.histogram.detach_().resize_(combined_histogram.shape)
                self.histogram.copy_(combined_histogram)
                self.min_val.detach_().resize_(new_min.shape)
                self.min_val.copy_(new_min)
                self.max_val.detach_().resize_(new_max.shape)
                self.max_val.copy_(new_max)

        return x_orig

    @torch.jit.export
    def calculate_qparams(self):  # type: ignore[override]
        """Return ``(scale, zero_point)`` from the searched min/max range.

        If the observer has not run yet, warns and returns tensors ``[1.0]``
        and ``[0]``.

        Raises:
            AssertionError: if the histogram length differs from ``self.bins``.
        """
        # NOTE(review): this uses ``and`` while forward() uses ``or`` for the
        # same "uninitialized" test; the two only differ if exactly one of the
        # buffers is infinite.
        is_uninitialized = self.min_val == float("inf") and self.max_val == float(
            "-inf"
        )
        if is_uninitialized:
            warnings.warn(
                "must run observer before calling calculate_qparams.\
Returning default scale and zero point ",
                stacklevel=2,
            )
            # NOTE(review): ``device.type`` drops the device index
            # (e.g. "cuda:1" -> "cuda").
            return torch.tensor([1.0], device=self.min_val.device.type), torch.tensor(
                [0], device=self.min_val.device.type
            )
        if self.bins != len(self.histogram):
            raise AssertionError(
                "The number of bins in histogram should be equal to the number of bins "
                "supplied while making this observer"
            )

        new_min, new_max = self._non_linear_param_search()

        return self._calculate_qparams(new_min, new_max)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # NOTE(review): min_val/max_val are registered buffers, so the base
        # implementation presumably already emits them; these lines overwrite
        # those entries with the buffer tensors themselves — confirm intent.
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + "min_val"] = self.min_val
        destination[prefix + "max_val"] = self.max_val

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Load min/max (upgrading pre-v3 empty entries to +/-inf), then defer
        to the parent loader."""
        version = local_metadata.get("version", None)

        if version is None or version < 3:
            # if min_val and max_val are not initialized, update their shape
            # to account for the differences between v2 and v3
            min_val_name, max_val_name = prefix + "min_val", prefix + "max_val"
            if min_val_name in state_dict:
                if state_dict[min_val_name].shape == torch.Size([0]):
                    state_dict[min_val_name] = torch.tensor(float("inf"))
            if max_val_name in state_dict:
                if state_dict[max_val_name].shape == torch.Size([0]):
                    state_dict[max_val_name] = torch.tensor(float("-inf"))

        local_state = ["min_val", "max_val"]
        for name in local_state:
            key = prefix + name
            if key in state_dict:
                val = state_dict[key]
                # Rebinds the buffer to the state-dict tensor (shape may differ).
                setattr(self, name, val)
            elif strict:
                missing_keys.append(key)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def extra_repr(self):
        return f"min_val={self.min_val}, max_val={self.max_val}"
  1247. class FixedQParamsObserver(ObserverBase):
  1248. r"""
  1249. Observer that simulates quantize and dequantize with fixed
  1250. quantization parameters in training time. Only per tensor
  1251. quantization is supported.
  1252. Args:
  1253. `scale` (float): fixed scale for the observer
  1254. `zero_point` (int): fixed zero point for the observer
  1255. `dtype`, `qscheme`, `quant_min`, `quant_max`
  1256. """
  1257. scale: torch.Tensor
  1258. zero_point: torch.Tensor
  1259. def __init__(
  1260. self,
  1261. scale,
  1262. zero_point,
  1263. dtype=torch.quint8,
  1264. qscheme=torch.per_tensor_affine,
  1265. quant_min=0,
  1266. quant_max=255,
  1267. is_dynamic=False,
  1268. **kwargs,
  1269. ):
  1270. if is_dynamic:
  1271. raise NotImplementedError(
  1272. "FixedQParamsObserver doesn't support dynamic quantization"
  1273. )
  1274. super().__init__(dtype=dtype, is_dynamic=is_dynamic, **kwargs)
  1275. self.quant_min = quant_min
  1276. self.quant_max = quant_max
  1277. self.register_buffer("scale", torch.tensor([scale], dtype=torch.float))
  1278. self.register_buffer("zero_point", torch.tensor([zero_point], dtype=torch.int))
  1279. self.dtype = dtype
  1280. self.qscheme = qscheme
  1281. def forward(self, X):
  1282. return X
  1283. @torch.jit.export
  1284. def calculate_qparams(self): # type: ignore[override]
  1285. return self.scale, self.zero_point
  1286. class PlaceholderObserver(ObserverBase):
  1287. r"""
  1288. Observer that doesn't do anything and just passes its configuration to the
  1289. quantized module's ``.from_float()``.
  1290. Can be used for quantization to float16 which doesn't require determining
  1291. ranges.
  1292. Args:
  1293. dtype: dtype argument to the `quantize` node needed to implement the
  1294. reference model spec.
  1295. quant_min: minimum value in quantized domain (TODO: align behavior with other observers)
  1296. quant_max: maximum value in quantized domain
  1297. custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation
  1298. (Can be used in Graph Mode Passes for special case ops).
  1299. compute_dtype (deprecated): if set, marks the future quantize function to use
  1300. dynamic quantization instead of static quantization.
  1301. This field is deprecated, use `is_dynamic=True` instead.
  1302. is_dynamic: if True, the `quantize` function in the reference model
  1303. representation taking stats from this observer instance will
  1304. use dynamic quantization.
  1305. """
  1306. def __init__(
  1307. self,
  1308. dtype=torch.float32,
  1309. custom_op_name="",
  1310. compute_dtype=None,
  1311. quant_min=None,
  1312. quant_max=None,
  1313. qscheme=None,
  1314. eps=None,
  1315. is_dynamic=False,
  1316. ) -> None:
  1317. super().__init__(dtype=dtype, is_dynamic=is_dynamic)
  1318. if qscheme is None:
  1319. qscheme = torch.per_tensor_affine
  1320. if eps is None:
  1321. eps = torch.finfo(torch.float32).eps
  1322. # dtype of input of the target operator, e.g. for dynamic quantization
  1323. # ops, the dtype will be float32
  1324. self.dtype = dtype
  1325. self.qscheme = qscheme
  1326. self.quant_min = quant_min
  1327. self.quant_max = quant_max
  1328. self.eps = eps
  1329. self.custom_op = custom_op_name
  1330. # used for configuration of computation type for dynamic quantization
  1331. if compute_dtype:
  1332. is_dynamic = True
  1333. warnings.warn(
  1334. "Please use `is_dynamic` instead of `compute_dtype`. \
  1335. `compute_dtype` will be deprecated in a future release \
  1336. of PyTorch.",
  1337. stacklevel=2,
  1338. )
  1339. def forward(self, x):
  1340. return x
  1341. @torch.jit.export
  1342. def extra_repr(self):
  1343. return f"dtype={self.dtype}, is_dynamic={self.is_dynamic}"
  1344. @torch.jit.export
  1345. def calculate_qparams(self): # type: ignore[override]
  1346. raise Exception( # noqa: TRY002
  1347. "calculate_qparams should not be called for PlaceholderObserver"
  1348. )
  1349. class RecordingObserver(ObserverBase):
  1350. r"""
  1351. The module is mainly for debug and records the tensor values during runtime.
  1352. Args:
  1353. dtype: Quantized data type
  1354. qscheme: Quantization scheme to be used
  1355. reduce_range: Reduces the range of the quantized data type by 1 bit
  1356. """
  1357. __annotations__ = {"tensor_val": list[torch.Tensor | None]}
  1358. def __init__(self, dtype=torch.quint8):
  1359. super().__init__(dtype=dtype, is_dynamic=False)
  1360. self.tensor_val = []
  1361. def forward(self, x):
  1362. self.tensor_val.append(x.clone())
  1363. return x
  1364. @torch.jit.export
  1365. def calculate_qparams(self): # type: ignore[override]
  1366. raise Exception( # noqa: TRY002
  1367. "calculate_qparams should not be called for RecordingObserver"
  1368. )
  1369. @torch.jit.export
  1370. def get_tensor_value(self):
  1371. return self.tensor_val
  1372. class NoopObserver(ObserverBase):
  1373. r"""
  1374. Observer that doesn't do anything and just passes its configuration to the
  1375. quantized module's ``.from_float()``.
  1376. Primarily used for quantization to float16 which doesn't require determining
  1377. ranges.
  1378. Args:
  1379. dtype: Quantized data type
  1380. custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation
  1381. (Can be used in Graph Mode Passes for special case ops).
  1382. """
  1383. def __init__(self, dtype=torch.float16, custom_op_name="") -> None:
  1384. super().__init__(dtype=dtype, is_dynamic=False)
  1385. self.dtype = dtype
  1386. self.custom_op = custom_op_name
  1387. def forward(self, x):
  1388. return x
  1389. @torch.jit.export
  1390. def calculate_qparams(self): # type: ignore[override]
  1391. raise Exception( # noqa: TRY002
  1392. "calculate_qparams should not be called for NoopObserver"
  1393. )
  1394. class ReuseInputObserver(ObserverBase):
  1395. r"""This observer is used when we want to reuse the observer from the operator
  1396. that produces the input Tensor, typically used for operators like reshape, e.g.
  1397. ```
  1398. x0 = ...
  1399. x1 = x0.reshape()
  1400. ```
  1401. if we configure x0 to be observed by some observer, let's say MinMaxObserver,
  1402. and reshape is configured with ReuseInputObserver, we'll reuse the observer instance
  1403. for x0 for x1 (output of reshape). If x0 is not observed, we also won't observe x1.
  1404. Note: this is only enabled in FX Graph Mode Quantization
  1405. """
  1406. def __init__(self) -> None:
  1407. super().__init__(torch.quint8, is_dynamic=False)
  1408. def forward(self, x):
  1409. return x
  1410. @torch.jit.export
  1411. def calculate_qparams(self): # type: ignore[override]
  1412. raise Exception( # noqa: TRY002
  1413. "calculate_qparams should not be called for ReuseInputObserver"
  1414. )
"""
# Experimental Affine Quantization Feature START
We plan to merge the following into the torchao repo after the pt2e flow moves to torchao.
Copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py
"""
  1420. from dataclasses import dataclass
  1421. from enum import auto, Enum
  1422. class MappingType(Enum):
  1423. """How floating point number is mapped to integer number
  1424. symmetric mapping means floating point range is symmetrically mapped to integer range
  1425. let's say we have floating point range (-3.5, 10.2) and integer range (-8, 7) (int4)
  1426. we'll use (-10.2, 10.2) as the range for floating point and map that to (-8, 7)
  1427. e.g. scale = (10.2 - (-10.2)) / (7 - (-8))
  1428. SYMMETRIC_NO_CLIPPING_ERR is a variant of symmetric mapping, where the scale is the max of smin
  1429. and smax, where smin = min_val_neg / quant_min, and smax = max_val_pos / quant_max. By calculating
  1430. smin and smax individually, there can be less round error on negative values, and no out-of-range
  1431. of all floating point values.
  1432. asymmetric mapping means we just directly map the floating point range to integer range,
  1433. for the above example, we will map (-3.5, 10.2) to (-8, 7) and calculate quantization parameter
  1434. based on this mapping
  1435. e.g. scale = (10.2 - (-3.5)) / (7 - (-8))
  1436. """
  1437. SYMMETRIC = auto()
  1438. SYMMETRIC_NO_CLIPPING_ERR = auto()
  1439. ASYMMETRIC = auto()
  1440. class ZeroPointDomain(Enum):
  1441. """Enum that indicate whether zero_point is in integer domain or floating point domain
  1442. integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer)
  1443. float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
  1444. none domain: quantized_val = (float_val / scale)
  1445. """
  1446. INT = auto()
  1447. FLOAT = auto()
  1448. NONE = auto()
  1449. class TorchAODType(Enum):
  1450. """
  1451. Placeholder for dtypes that do not exist in PyTorch core yet.
  1452. """
  1453. # torch.int1 to torch.int7 will be added to PyTorch 2.6
  1454. # These will remain here for BC with older PyTorch versions
  1455. INT1 = auto()
  1456. INT2 = auto()
  1457. INT3 = auto()
  1458. INT4 = auto()
  1459. INT5 = auto()
  1460. INT6 = auto()
  1461. INT7 = auto()
  1462. @dataclass(frozen=True)
  1463. class Granularity:
  1464. """
  1465. Base class for representing the granularity of quantization.
  1466. This class serves as a parent for specific granularity types used in
  1467. quantization operations, such as per-tensor or per-axis quantization.
  1468. """
  1469. @dataclass(frozen=True)
  1470. class PerBlock(Granularity):
  1471. """
  1472. Represents per-block granularity in quantization. See
  1473. :func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for
  1474. `block_size`
  1475. Attributes:
  1476. block_size (Tuple[int, ...]): The size of each quantization group
  1477. """
  1478. block_size: tuple[int, ...]
  1479. @dataclass(frozen=True)
  1480. class PerTensor(Granularity):
  1481. """
  1482. Represents per-tensor granularity in quantization.
  1483. This granularity type calculates the quantization parameters
  1484. based off the entire tensor.
  1485. """
  1486. @dataclass(frozen=True)
  1487. class PerAxis(Granularity):
  1488. """
  1489. Represents per-axis granularity in quantization.
  1490. This granularity type calculates different quantization parameters
  1491. along a specified axis of the tensor.
  1492. For example if the input tensor is shape [8, 16] and axis=0, then
  1493. the quantization parameters are calculated for each row of the tensor.
  1494. Giving a total of 8 quantization parameters.
  1495. Attributes:
  1496. axis (int): The axis along which reduction is performed.
  1497. """
  1498. axis: int
  1499. @dataclass(frozen=True)
  1500. class PerGroup(Granularity):
  1501. """
  1502. Represents per-channel group granularity in quantization.
  1503. This granularity type calculates different quantization parameters
  1504. for each group of <group_size> elements.
  1505. For example if the input tensor is shape [8, 16], and the group size is 4, then
  1506. the input tensor is reshaped to [64, 4]
  1507. quantization parameters are calculated for each group of 4 elements,
  1508. giving a total of 64 quantization parameters.
  1509. Attributes:
  1510. group_size (int): The size of each quantization group
  1511. """
  1512. group_size: int
  1513. class PerRow(Granularity):
  1514. """
  1515. Represents row-wise granularity in quantization.
  1516. This is a special case of per-axis quantization and is unique to Float8 matmuls
  1517. where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight
  1518. is quantized with a block_size of (1, weight.shape[1]).
  1519. """
  1520. class PerToken(Granularity):
  1521. """
  1522. Represents per-token granularity in quantization.
  1523. This granularity type calculates a different set of quantization parameters
  1524. for each token, which is represented as the last dimension of the tensor.
  1525. For example, if the input tensor has shape [2, 3, 4], then there are 6 tokens
  1526. with 4 elements each, and we will calculate 6 sets of quantization parameters,
  1527. one for each token.
  1528. If the input tensor has only two dimensions, e.g. [8, 16], then this is
  1529. equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization parameters.
  1530. """
  1531. def get_block_size(
  1532. input_shape: tuple[int, ...], granularity: Granularity
  1533. ) -> tuple[int, ...]:
  1534. """Get the block size based on the input shape and granularity type.
  1535. Args:
  1536. input_shape: The input tensor shape possibly more than 2 dimensions
  1537. granularity: The granularity type of the quantization
  1538. """
  1539. if not isinstance(granularity, Granularity):
  1540. raise AssertionError(
  1541. "Please provide an instance of Granularity, not subclass of it"
  1542. )
  1543. if isinstance(granularity, PerTensor):
  1544. return input_shape
  1545. elif isinstance(granularity, PerAxis):
  1546. block_size = list(input_shape)
  1547. block_size[granularity.axis] = 1
  1548. return tuple(block_size)
  1549. elif isinstance(granularity, PerRow):
  1550. return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
  1551. elif isinstance(granularity, PerGroup):
  1552. if len(input_shape) != 2:
  1553. raise AssertionError(
  1554. f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}"
  1555. )
  1556. return (1, granularity.group_size)
  1557. elif isinstance(granularity, PerToken):
  1558. block_size = [1] * len(input_shape)
  1559. block_size[-1] = input_shape[-1]
  1560. return tuple(block_size)
  1561. raise ValueError(f"Unsupported Granularity: {granularity}")
  1562. class AffineQuantizedObserverBase(ABC, torch.nn.Module):
  1563. """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization)
  1564. Args:
  1565. `granularity` and `block_size`: The granularity of the quantization,
  1566. must specify at least one, if both are specified `block_size` takes precedence
  1567. Current supported granularity type are `PerTensor` and `PerAxis`
  1568. other args: please see `:class:torchao.dtypes.AffineQuantizedTensor`
  1569. """
  1570. with_args = classmethod(_with_args)
  1571. def __init__(
  1572. self,
  1573. mapping_type: MappingType,
  1574. target_dtype: torch.dtype,
  1575. granularity: Granularity,
  1576. quant_min: int | None = None,
  1577. quant_max: int | None = None,
  1578. eps: float | None = None,
  1579. scale_dtype: torch.dtype | None = None,
  1580. zero_point_dtype: torch.dtype | None = None,
  1581. preserve_zero: bool = True,
  1582. zero_point_domain: ZeroPointDomain | None = ZeroPointDomain.INT,
  1583. # there could be some extra args that's ignored
  1584. **kwargs,
  1585. ):
  1586. super().__init__()
  1587. if granularity is None:
  1588. raise AssertionError("granularity is None")
  1589. self.mapping_type = mapping_type
  1590. self.target_dtype = target_dtype
  1591. self.granularity = granularity
  1592. self.quant_min = quant_min
  1593. self.quant_max = quant_max
  1594. self.eps = eps
  1595. self.scale_dtype = scale_dtype
  1596. self.zero_point_dtype = zero_point_dtype
  1597. self.preserve_zero = preserve_zero
  1598. self.zero_point_domain = zero_point_domain
  1599. # populatd during forward
  1600. self.block_size = None
  1601. self.original_dtype = None
  1602. @abstractmethod
  1603. def forward(self, input: torch.Tensor) -> torch.Tensor:
  1604. """forward function should take the input tensor
  1605. and updates internal stats and return the original input Tensor
  1606. """
  1607. @abstractmethod
  1608. def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]:
  1609. """Calculate quantization parameter based on the stats attached to the observer module
  1610. and returns a tuple of scale and zero_point Tensor
  1611. """
  1612. def convert(self, model: torch.fx.GraphModule, observer_node: Node):
  1613. """
  1614. Converts the observer node in the graph into its quantized representation
  1615. Args:
  1616. model: graph module to convert the observer node in
  1617. observer_node: the observer node to convert
  1618. """
  1619. from torch.ao.quantization.fx.utils import create_getattr_from_value
  1620. with model.graph.inserting_before(observer_node):
  1621. if self.block_size is None:
  1622. raise AssertionError("Expecting block_size to be populated")
  1623. if self.original_dtype is None:
  1624. raise AssertionError("Expecting original_dtype to be populated")
  1625. if hasattr(self, "is_dynamic") and self.is_dynamic:
  1626. choose_qparams_affine = model.graph.call_function(
  1627. torch.ops.pt2e_quant.choose_qparams_affine,
  1628. (
  1629. observer_node.args[0],
  1630. self.mapping_type.name,
  1631. self.block_size,
  1632. self.target_dtype,
  1633. self.quant_min,
  1634. self.quant_max,
  1635. self.eps,
  1636. self.scale_dtype,
  1637. self.zero_point_dtype,
  1638. self.preserve_zero,
  1639. # pyrefly: ignore [missing-attribute]
  1640. self.zero_point_domain.name,
  1641. ),
  1642. )
  1643. scale_node = model.graph.call_function(
  1644. operator.getitem, (choose_qparams_affine, 0)
  1645. )
  1646. zero_point_node = model.graph.call_function(
  1647. operator.getitem, (choose_qparams_affine, 1)
  1648. )
  1649. else:
  1650. scale, zero_point = self.calculate_qparams()
  1651. scale_node = create_getattr_from_value(
  1652. model,
  1653. model.graph,
  1654. "_scale",
  1655. scale,
  1656. scale.device if isinstance(scale, torch.Tensor) else None,
  1657. )
  1658. zero_point_node = create_getattr_from_value(
  1659. model,
  1660. model.graph,
  1661. "_zero_point",
  1662. zero_point,
  1663. zero_point.device if isinstance(zero_point, torch.Tensor) else None,
  1664. )
  1665. q_node = model.graph.call_function(
  1666. torch.ops.pt2e_quant.quantize_affine,
  1667. (
  1668. observer_node.args[0],
  1669. self.block_size,
  1670. scale_node,
  1671. zero_point_node,
  1672. self.target_dtype,
  1673. self.quant_min,
  1674. self.quant_max,
  1675. # pyrefly: ignore [missing-attribute]
  1676. self.zero_point_domain.name,
  1677. ),
  1678. {},
  1679. )
  1680. dq_node = model.graph.call_function(
  1681. torch.ops.pt2e_quant.dequantize_affine,
  1682. (
  1683. q_node,
  1684. self.block_size,
  1685. scale_node,
  1686. zero_point_node,
  1687. self.target_dtype,
  1688. self.quant_min,
  1689. self.quant_max,
  1690. # pyrefly: ignore [missing-attribute]
  1691. self.zero_point_domain.name,
  1692. ),
  1693. {"output_dtype": self.original_dtype},
  1694. )
  1695. observer_node.replace_all_uses_with(dq_node)
  1696. model.graph.erase_node(observer_node)
  1697. def _is_observer_script_module(mod, obs_type_name):
  1698. """Returns true if given mod is an instance of Observer script module."""
  1699. if isinstance(mod, torch.jit.RecursiveScriptModule):
  1700. # qualified name looks like '__torch__.torch.ao.quantization.observer.___torch_mangle_2.MinMaxObserver'
  1701. suffix = mod._c.qualified_name.split(".", 1)[1]
  1702. name = re.sub(r"\.___torch_mangle_\d+", "", suffix)
  1703. return obs_type_name in name
  1704. return False
  1705. # Experimental Affine Quantization Feature END
  1706. def _is_activation_post_process(module):
  1707. return isinstance(
  1708. module,
  1709. (
  1710. torch.ao.quantization.ObserverBase,
  1711. torch.ao.quantization.FakeQuantizeBase,
  1712. AffineQuantizedObserverBase,
  1713. ),
  1714. ) or _is_observer_script_module(module, "quantization.observer")
  1715. def _is_per_channel_script_obs_instance(module):
  1716. if isinstance(module, torch.jit.RecursiveScriptModule):
  1717. return _is_observer_script_module(
  1718. module, "quantization.observer.PerChannelMinMaxObserver"
  1719. ) or _is_observer_script_module(
  1720. module, "quantization.observer.MovingAveragePerChannelMinMaxObserver"
  1721. )
  1722. return False
  1723. def get_observer_state_dict(mod):
  1724. r"""
  1725. Returns the state dict corresponding to the observer stats.
  1726. Traverse the model state_dict and extract out the stats.
  1727. """
  1728. od = OrderedDict()
  1729. if isinstance(mod, torch.jit.RecursiveScriptModule):
  1730. for k, v in mod.state_dict().items():
  1731. if "observer" in k:
  1732. od[k] = v
  1733. else:
  1734. # path for GraphModule and nn.Module (eager mode)
  1735. for k, v in mod.state_dict().items():
  1736. if "activation_post_process" in k:
  1737. od[k] = v
  1738. od._metadata = mod.state_dict()._metadata # type: ignore[attr-defined]
  1739. return od
  1740. def load_observer_state_dict(mod, obs_dict):
  1741. r"""
  1742. Given input model and a state_dict containing model observer stats,
  1743. load the stats back into the model. The observer state_dict can be saved
  1744. using torch.ao.quantization.get_observer_state_dict
  1745. """
  1746. missing_keys: list[str] = []
  1747. unexpected_keys: list[str] = []
  1748. for name, module in mod.named_modules():
  1749. prefix = name + "."
  1750. if _is_activation_post_process(module):
  1751. if _is_per_channel_script_obs_instance(module):
  1752. # For per-channel observers we need to call a custom load_from_state_dict to resize the tensor.
  1753. # However this is not called when the module is scripted and we end up calling the default one in module.py
  1754. module._load_from_state_dict_script(
  1755. obs_dict, prefix, {}, True, missing_keys, unexpected_keys, []
  1756. )
  1757. else:
  1758. module._load_from_state_dict(
  1759. obs_dict, prefix, {}, False, missing_keys, unexpected_keys, []
  1760. )
  1761. for k in missing_keys:
  1762. if "observer" in k or "activation_post_process" in k:
  1763. raise Exception( # noqa: TRY002
  1764. f"Missing keys for observer {k} in state_dict"
  1765. )
  1766. for k in unexpected_keys:
  1767. if "observer" in k or "activation_post_process" in k:
  1768. raise Exception( # noqa: TRY002
  1769. f"Unexpected keys for observer {k} in state_dict"
  1770. )
# Module-level default observer configurations. Each ``.with_args(...)`` value
# is a constructor with preset keyword arguments; bare classes are used
# directly as observer constructors.

# Restrict activations to be in the range (0,127)
default_observer = MinMaxObserver.with_args(quant_min=0, quant_max=127)
"""
Default observer for static quantization, usually used for debugging.
"""
default_placeholder_observer = PlaceholderObserver
"""
Default placeholder observer, usually used for quantization to torch.float16.
"""
default_debug_observer = RecordingObserver
"""
Default debug-only observer.
"""
default_weight_observer = MinMaxObserver.with_args(
    dtype=torch.qint8, qscheme=torch.per_tensor_symmetric
)
"""
Default weight observer.
"""
# Narrowed symmetric int8 range that excludes -128; eps is overridden to 2**-12.
weight_observer_range_neg_127_to_127 = MinMaxObserver.with_args(
    dtype=torch.qint8,
    qscheme=torch.per_tensor_symmetric,
    quant_min=-127,
    quant_max=127,
    eps=2**-12,
)
"""
Symmetric weight observer with the 8-bit values restricted to [-127, +127], excluding -128.
"""
default_histogram_observer = HistogramObserver.with_args(quant_min=0, quant_max=127)
"""
Default histogram observer, usually used for PTQ.
"""
default_per_channel_weight_observer = PerChannelMinMaxObserver.with_args(
    dtype=torch.qint8, qscheme=torch.per_channel_symmetric
)
"""
Default per-channel weight observer, usually used on backends where per-channel
weight quantization is supported, such as `fbgemm`.
"""
per_channel_weight_observer_range_neg_127_to_127 = PerChannelMinMaxObserver.with_args(
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
    quant_min=-127,
    quant_max=127,
    eps=2**-12,
)
"""
Per-channel, symmetric weight observer with the 8-bit values restricted to [-127, +127], excluding -128.
"""
# is_dynamic=True marks the reference-model quantize op as dynamic
# (see the PlaceholderObserver ``is_dynamic`` argument).
default_dynamic_quant_observer = PlaceholderObserver.with_args(
    dtype=torch.quint8,
    quant_min=0,
    quant_max=255,
    is_dynamic=True,
)
"""
Default observer for dynamic quantization.
"""
default_float_qparams_observer = PerChannelMinMaxObserver.with_args(
    dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0
)
"""
Default observer for a floating point zero-point.
"""
# NOTE(review): the docstring below says "4 bit activations", but this is a
# per-channel (ch_axis=0) float-qparams config that looks weight-oriented --
# confirm the intended use before relying on that wording.
default_float_qparams_observer_4bit = PerChannelMinMaxObserver.with_args(
    dtype=torch.quint4x2, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0
)
"""
Default observer for a floating point zero-point and 4 bit activations.
"""
# TODO(future PR): remove these defaults and enforce activation functions
# to explicitly specify their output range
# With scale=2/256 and zero_point=128 over [0, 255], dequantized values span
# (q - 128) * 2/256, i.e. [-1.0, 0.9921875].
default_fixed_qparams_range_neg1to1_observer = FixedQParamsObserver.with_args(
    scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8, quant_min=0, quant_max=255
)
# With scale=1/256 and zero_point=0 over [0, 255], dequantized values span
# q / 256, i.e. [0.0, 0.99609375].
default_fixed_qparams_range_0to1_observer = FixedQParamsObserver.with_args(
    scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=0, quant_max=255
)
# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases
default_symmetric_fixed_qparams_observer = default_fixed_qparams_range_neg1to1_observer
default_affine_fixed_qparams_observer = default_fixed_qparams_range_0to1_observer
"""
Default observers for fixed qparams operations.
"""
default_reuse_input_observer = ReuseInputObserver
"""
Default observer for operators like reshape that reuses the observer of input to
the operator
"""