quant_modules.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819
  1. # Copyright 2021 The I-BERT Authors (Sehoon Kim, Amir Gholami, Zhewei Yao,
  2. # Michael Mahoney, Kurt Keutzer - UC Berkeley) and The HuggingFace Inc. team.
  3. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import decimal
  17. import numpy as np
  18. import torch
  19. from torch import nn
  20. from torch.autograd import Function
  21. from ...utils import logging
  22. logger = logging.get_logger(__name__)
  23. class QuantEmbedding(nn.Module):
  24. """
  25. Quantized version of `torch.nn.Embedding`. Adds quantization-specific arguments on top of `torch.nn.Embedding`.
  26. Args:
  27. weight_bit (`int`, *optional*, defaults to `8`):
  28. Bitwidth for the quantized weight.
  29. momentum (`float`, *optional*, defaults to `0.95`):
  30. Momentum for updating the activation quantization range.
  31. quant_mode (`bool`, *optional*, defaults to `False`):
  32. Whether or not the layer is quantized.
  33. """
  34. def __init__(
  35. self,
  36. num_embeddings,
  37. embedding_dim,
  38. padding_idx=None,
  39. max_norm=None,
  40. norm_type=2.0,
  41. scale_grad_by_freq=False,
  42. sparse=False,
  43. _weight=None,
  44. weight_bit=8,
  45. momentum=0.95,
  46. quant_mode=False,
  47. ):
  48. super().__init__()
  49. self.num_ = num_embeddings
  50. self.dim = embedding_dim
  51. self.padding_idx = padding_idx
  52. self.max_norm = max_norm
  53. self.norm_type = norm_type
  54. self.scale_grad_by_freq = scale_grad_by_freq
  55. self.sparse = sparse
  56. self.weight = nn.Parameter(torch.zeros([num_embeddings, embedding_dim]))
  57. self.register_buffer("weight_scaling_factor", torch.zeros(1))
  58. self.register_buffer("weight_integer", torch.zeros_like(self.weight))
  59. self.weight_bit = weight_bit
  60. self.momentum = momentum
  61. self.quant_mode = quant_mode
  62. self.percentile_mode = False
  63. self.weight_function = SymmetricQuantFunction.apply
  64. def forward(self, x, positions=None, incremental_state=None):
  65. if not self.quant_mode:
  66. return (
  67. nn.functional.embedding(
  68. x,
  69. self.weight,
  70. self.padding_idx,
  71. self.max_norm,
  72. self.norm_type,
  73. self.scale_grad_by_freq,
  74. self.sparse,
  75. ),
  76. None,
  77. )
  78. w = self.weight
  79. w_transform = w.data.detach()
  80. w_min = w_transform.min().expand(1)
  81. w_max = w_transform.max().expand(1)
  82. self.weight_scaling_factor = symmetric_linear_quantization_params(self.weight_bit, w_min, w_max, False)
  83. self.weight_integer = self.weight_function(
  84. self.weight, self.weight_bit, self.percentile_mode, self.weight_scaling_factor
  85. )
  86. emb_int = nn.functional.embedding(
  87. x,
  88. self.weight_integer,
  89. self.padding_idx,
  90. self.max_norm,
  91. self.norm_type,
  92. self.scale_grad_by_freq,
  93. self.sparse,
  94. )
  95. return emb_int * self.weight_scaling_factor, self.weight_scaling_factor
  96. class QuantAct(nn.Module):
  97. """
  98. Quantizes the given activation.
  99. Args:
  100. activation_bit (`int`):
  101. Bitwidth for the quantized activation.
  102. act_range_momentum (`float`, *optional*, defaults to `0.95`):
  103. Momentum for updating the activation quantization range.
  104. per_channel (`bool`, *optional*, defaults to `False`):
  105. Whether to or not use channel-wise quantization.
  106. channel_len (`int`, *optional*):
  107. Specify the channel length when set the *per_channel* True.
  108. quant_mode (`bool`, *optional*, defaults to `False`):
  109. Whether or not the layer is quantized.
  110. """
  111. def __init__(self, activation_bit, act_range_momentum=0.95, per_channel=False, channel_len=None, quant_mode=False):
  112. super().__init__()
  113. self.activation_bit = activation_bit
  114. self.act_range_momentum = act_range_momentum
  115. self.quant_mode = quant_mode
  116. self.per_channel = per_channel
  117. self.percentile = False
  118. self.act_function = SymmetricQuantFunction.apply
  119. if not self.per_channel:
  120. self.register_buffer("x_min", torch.zeros(1))
  121. self.register_buffer("x_max", torch.zeros(1))
  122. self.register_buffer("act_scaling_factor", torch.zeros(1))
  123. self.x_min -= 1e-5
  124. self.x_max += 1e-5
  125. else:
  126. raise NotImplementedError("per-channel mode is not currently supported for activation.")
  127. def __repr__(self):
  128. return (
  129. f"{self.__class__.__name__}(activation_bit={self.activation_bit}, "
  130. f"quant_mode: {self.quant_mode}, Act_min: {self.x_min.item():.2f}, "
  131. f"Act_max: {self.x_max.item():.2f})"
  132. )
  133. def forward(
  134. self,
  135. x,
  136. pre_act_scaling_factor=None,
  137. identity=None,
  138. identity_scaling_factor=None,
  139. specified_min=None,
  140. specified_max=None,
  141. ):
  142. x_act = x if identity is None else identity + x
  143. # collect running stats if training
  144. if self.training:
  145. assert not self.percentile, "percentile mode is not currently supported for activation."
  146. assert not self.per_channel, "per-channel mode is not currently supported for activation."
  147. x_min = x_act.data.min()
  148. x_max = x_act.data.max()
  149. assert x_max.isnan().sum() == 0 and x_min.isnan().sum() == 0, (
  150. "NaN detected when computing min/max of the activation"
  151. )
  152. # Initialization
  153. if self.x_min.min() > -1.1e-5 and self.x_max.max() < 1.1e-5:
  154. self.x_min = self.x_min + x_min
  155. self.x_max = self.x_max + x_max
  156. # exponential moving average (EMA)
  157. # use momentum to prevent the quantized values change greatly every iteration
  158. elif self.act_range_momentum == -1:
  159. self.x_min = torch.min(self.x_min, x_min)
  160. self.x_max = torch.max(self.x_max, x_max)
  161. else:
  162. self.x_min = self.x_min * self.act_range_momentum + x_min * (1 - self.act_range_momentum)
  163. self.x_max = self.x_max * self.act_range_momentum + x_max * (1 - self.act_range_momentum)
  164. if not self.quant_mode:
  165. return x_act, None
  166. x_min = self.x_min if specified_min is None else specified_min
  167. x_max = self.x_max if specified_max is None else specified_max
  168. self.act_scaling_factor = symmetric_linear_quantization_params(
  169. self.activation_bit, x_min, x_max, per_channel=self.per_channel
  170. )
  171. if pre_act_scaling_factor is None:
  172. # this is for the input quantization
  173. quant_act_int = self.act_function(x, self.activation_bit, self.percentile, self.act_scaling_factor)
  174. else:
  175. quant_act_int = FixedPointMul.apply(
  176. x,
  177. pre_act_scaling_factor,
  178. self.activation_bit,
  179. self.act_scaling_factor,
  180. identity,
  181. identity_scaling_factor,
  182. )
  183. correct_output_scale = self.act_scaling_factor.view(-1)
  184. return quant_act_int * correct_output_scale, self.act_scaling_factor
  185. class QuantLinear(nn.Module):
  186. """
  187. Quantized version of `torch.nn.Linear`. Adds quantization-specific arguments on top of `torch.nn.Linear`.
  188. Args:
  189. weight_bit (`int`, *optional*, defaults to `8`):
  190. Bitwidth for the quantized weight.
  191. bias_bit (`int`, *optional*, defaults to `32`):
  192. Bitwidth for the quantized bias.
  193. per_channel (`bool`, *optional*, defaults to `False`):
  194. Whether or not to use channel-wise quantization.
  195. quant_mode (`bool`, *optional*, defaults to `False`):
  196. Whether or not the layer is quantized.
  197. """
  198. def __init__(
  199. self, in_features, out_features, bias=True, weight_bit=8, bias_bit=32, per_channel=False, quant_mode=False
  200. ):
  201. super().__init__()
  202. self.in_features = in_features
  203. self.out_features = out_features
  204. self.weight = nn.Parameter(torch.zeros([out_features, in_features]))
  205. self.register_buffer("weight_integer", torch.zeros_like(self.weight))
  206. self.register_buffer("fc_scaling_factor", torch.zeros(self.out_features))
  207. if bias:
  208. self.bias = nn.Parameter(torch.zeros(out_features))
  209. self.register_buffer("bias_integer", torch.zeros_like(self.bias))
  210. self.weight_bit = weight_bit
  211. self.quant_mode = quant_mode
  212. self.per_channel = per_channel
  213. self.bias_bit = bias_bit
  214. self.quant_mode = quant_mode
  215. self.percentile_mode = False
  216. self.weight_function = SymmetricQuantFunction.apply
  217. def __repr__(self):
  218. s = super().__repr__()
  219. s = f"({s} weight_bit={self.weight_bit}, quant_mode={self.quant_mode})"
  220. return s
  221. def forward(self, x, prev_act_scaling_factor=None):
  222. if not self.quant_mode:
  223. return nn.functional.linear(x, weight=self.weight, bias=self.bias), None
  224. # assert that prev_act_scaling_factor is a scalar tensor
  225. assert prev_act_scaling_factor is not None and prev_act_scaling_factor.shape == (1,), (
  226. "Input activation to the QuantLinear layer should be globally (non-channel-wise) quantized. "
  227. "Please add a QuantAct layer with `per_channel = True` before this QuantAct layer"
  228. )
  229. w = self.weight
  230. w_transform = w.data.detach()
  231. if self.per_channel:
  232. w_min, _ = torch.min(w_transform, dim=1, out=None)
  233. w_max, _ = torch.max(w_transform, dim=1, out=None)
  234. else:
  235. w_min = w_transform.min().expand(1)
  236. w_max = w_transform.max().expand(1)
  237. self.fc_scaling_factor = symmetric_linear_quantization_params(self.weight_bit, w_min, w_max, self.per_channel)
  238. self.weight_integer = self.weight_function(
  239. self.weight, self.weight_bit, self.percentile_mode, self.fc_scaling_factor
  240. )
  241. bias_scaling_factor = self.fc_scaling_factor * prev_act_scaling_factor
  242. if self.bias is not None:
  243. self.bias_integer = self.weight_function(self.bias, self.bias_bit, False, bias_scaling_factor)
  244. prev_act_scaling_factor = prev_act_scaling_factor.view(1, -1)
  245. x_int = x / prev_act_scaling_factor
  246. return (
  247. nn.functional.linear(x_int, weight=self.weight_integer, bias=self.bias_integer) * bias_scaling_factor,
  248. bias_scaling_factor,
  249. )
  250. class IntGELU(nn.Module):
  251. """
  252. Quantized version of `torch.nn.GELU`. Adds quantization-specific arguments on top of `torch.nn.GELU`.
  253. Args:
  254. quant_mode (`bool`, *optional*, defaults to `False`):
  255. Whether or not the layer is quantized.
  256. force_dequant (`str`, *optional*, defaults to `"none"`):
  257. Force dequantize the layer if either "gelu" or "nonlinear" is given.
  258. """
  259. def __init__(self, quant_mode=True, force_dequant="none"):
  260. super().__init__()
  261. self.quant_mode = quant_mode
  262. if force_dequant in ["nonlinear", "gelu"]:
  263. logger.info("Force dequantize gelu")
  264. self.quant_mode = False
  265. if not self.quant_mode:
  266. self.activation_fn = nn.GELU()
  267. self.k = 1.4142
  268. self.const = 14 # dummy integer constant
  269. self.coeff = [-0.2888, -1.769, 1] # a(x+b)**2 + c
  270. self.coeff[2] /= self.coeff[0]
  271. def int_erf(self, x_int, scaling_factor):
  272. b_int = torch.floor(self.coeff[1] / scaling_factor)
  273. c_int = torch.floor(self.coeff[2] / scaling_factor**2)
  274. sign = torch.sign(x_int)
  275. abs_int = torch.min(torch.abs(x_int), -b_int)
  276. y_int = sign * ((abs_int + b_int) ** 2 + c_int)
  277. scaling_factor = scaling_factor**2 * self.coeff[0]
  278. # avoid overflow
  279. y_int = floor_ste.apply(y_int / 2**self.const)
  280. scaling_factor = scaling_factor * 2**self.const
  281. return y_int, scaling_factor
  282. def forward(self, x, scaling_factor=None):
  283. if not self.quant_mode:
  284. return self.activation_fn(x), None
  285. x_int = x / scaling_factor
  286. sigmoid_int, sigmoid_scaling_factor = self.int_erf(x_int, scaling_factor / self.k)
  287. shift_int = 1.0 // sigmoid_scaling_factor
  288. x_int = x_int * (sigmoid_int + shift_int)
  289. scaling_factor = scaling_factor * sigmoid_scaling_factor / 2
  290. return x_int * scaling_factor, scaling_factor
  291. class IntSoftmax(nn.Module):
  292. """
  293. Quantized version of `torch.nn.Softmax`. Adds quantization-specific arguments on top of `torch.nn.Softmax`.
  294. Args:
  295. output_bit (`int`):
  296. Bitwidth for the layer output activation.
  297. quant_mode (`bool`, *optional*, defaults to `False`):
  298. Whether or not the layer is quantized.
  299. force_dequant (`str`, *optional*, defaults to `"none"`):
  300. Force dequantize the layer if either "softmax" or "nonlinear" is given.
  301. """
  302. def __init__(self, output_bit, quant_mode=False, force_dequant="none"):
  303. super().__init__()
  304. self.output_bit = output_bit
  305. self.max_bit = 32
  306. self.quant_mode = quant_mode
  307. if force_dequant in ["nonlinear", "softmax"]:
  308. logger.info("Force dequantize softmax")
  309. self.quant_mode = False
  310. self.act = QuantAct(16, quant_mode=self.quant_mode)
  311. self.x0 = -0.6931 # -ln2
  312. self.const = 30 # dummy integer constant
  313. self.coef = [0.35815147, 0.96963238, 1.0] # ax**2 + bx + c
  314. self.coef[1] /= self.coef[0]
  315. self.coef[2] /= self.coef[0]
  316. def int_polynomial(self, x_int, scaling_factor):
  317. with torch.no_grad():
  318. b_int = torch.floor(self.coef[1] / scaling_factor)
  319. c_int = torch.floor(self.coef[2] / scaling_factor**2)
  320. z = (x_int + b_int) * x_int + c_int
  321. scaling_factor = self.coef[0] * scaling_factor**2
  322. return z, scaling_factor
  323. def int_exp(self, x_int, scaling_factor):
  324. with torch.no_grad():
  325. x0_int = torch.floor(self.x0 / scaling_factor)
  326. x_int = torch.max(x_int, self.const * x0_int)
  327. q = floor_ste.apply(x_int / x0_int)
  328. r = x_int - x0_int * q
  329. exp_int, exp_scaling_factor = self.int_polynomial(r, scaling_factor)
  330. exp_int = torch.clamp(floor_ste.apply(exp_int * 2 ** (self.const - q)), min=0)
  331. scaling_factor = exp_scaling_factor / 2**self.const
  332. return exp_int, scaling_factor
  333. def forward(self, x, scaling_factor):
  334. if not self.quant_mode:
  335. return nn.functional.softmax(x, dim=-1), None
  336. x_int = x / scaling_factor
  337. x_int_max, _ = x_int.max(dim=-1, keepdim=True)
  338. x_int = x_int - x_int_max
  339. exp_int, exp_scaling_factor = self.int_exp(x_int, scaling_factor)
  340. # Avoid overflow
  341. exp, exp_scaling_factor = self.act(exp_int, exp_scaling_factor)
  342. exp_int = exp / exp_scaling_factor
  343. exp_int_sum = exp_int.sum(dim=-1, keepdim=True)
  344. factor = floor_ste.apply(2**self.max_bit / exp_int_sum)
  345. exp_int = floor_ste.apply(exp_int * factor / 2 ** (self.max_bit - self.output_bit))
  346. scaling_factor = 1 / 2**self.output_bit
  347. return exp_int * scaling_factor, scaling_factor
  348. class IntLayerNorm(nn.Module):
  349. """
  350. Quantized version of `torch.nn.LayerNorm`. Adds quantization-specific arguments on top of `torch.nn.LayerNorm`.
  351. Args:
  352. output_bit (`int`, *optional*, defaults to `8`):
  353. Bitwidth for the layer output activation.
  354. quant_mode (`bool`, *optional*, defaults to `False`):
  355. Whether or not the layer is quantized.
  356. force_dequant (`str`, *optional*, defaults to `"none"`):
  357. Force dequantize the layer if either "layernorm" or "nonlinear" is given.
  358. """
  359. def __init__(self, normalized_shape, eps, output_bit=8, quant_mode=False, force_dequant="none"):
  360. super().__init__()
  361. self.normalized_shape = normalized_shape
  362. self.eps = eps
  363. self.weight = nn.Parameter(torch.zeros(normalized_shape))
  364. self.bias = nn.Parameter(torch.zeros(normalized_shape))
  365. self.quant_mode = quant_mode
  366. if force_dequant in ["nonlinear", "layernorm"]:
  367. logger.info("Force dequantize layernorm")
  368. self.quant_mode = False
  369. self.register_buffer("shift", torch.zeros(1))
  370. self.output_bit = output_bit
  371. self.max_bit = 32
  372. self.dim_sqrt = None
  373. self.activation = QuantAct(self.output_bit, quant_mode=self.quant_mode)
  374. def set_shift(self, y_int):
  375. with torch.no_grad():
  376. y_sq_int = y_int**2
  377. var_int = torch.sum(y_sq_int, axis=2, keepdim=True)
  378. shift = (torch.log2(torch.sqrt(var_int / 2**self.max_bit)).ceil()).max()
  379. shift_old = self.shift
  380. self.shift = torch.max(self.shift, shift)
  381. logger.info(f"Dynamic shift adjustment: {int(shift_old)} -> {int(self.shift)}")
  382. def overflow_fallback(self, y_int):
  383. """
  384. This fallback function is called when overflow is detected during training time, and adjusts the `self.shift`
  385. to avoid overflow in the subsequent runs.
  386. """
  387. self.set_shift(y_int) # adjusts `self.shift`
  388. y_int_shifted = floor_ste.apply(y_int / 2**self.shift)
  389. y_sq_int = y_int_shifted**2
  390. var_int = torch.sum(y_sq_int, axis=2, keepdim=True)
  391. return var_int
  392. def forward(self, x, scaling_factor=None):
  393. if not self.quant_mode:
  394. mean = x.mean(axis=2, keepdim=True)
  395. y = x - mean
  396. var = torch.mean(y**2, axis=2, keepdim=True)
  397. x = y / torch.sqrt(self.eps + var)
  398. x = x * self.weight + self.bias
  399. return x, None
  400. # compute sqrt of the feature dimension if it is the first run
  401. if self.dim_sqrt is None:
  402. n = torch.tensor(x.shape[2], dtype=torch.float)
  403. self.dim_sqrt = torch.sqrt(n).to(x.device)
  404. # Normalization: computes mean and variance(std)
  405. x_int = x / scaling_factor
  406. mean_int = round_ste.apply(x_int.mean(axis=2, keepdim=True))
  407. y_int = x_int - mean_int
  408. y_int_shifted = floor_ste.apply(y_int / 2**self.shift)
  409. y_sq_int = y_int_shifted**2
  410. var_int = torch.sum(y_sq_int, axis=2, keepdim=True)
  411. # overflow handling in training time
  412. if self.training:
  413. # if overflow is detected
  414. if var_int.max() >= 2**self.max_bit:
  415. var_int = self.overflow_fallback(y_int)
  416. assert var_int.max() < 2**self.max_bit + 0.1, (
  417. "Error detected in overflow handling: "
  418. "`var_int` exceeds `self.max_bit` (the maximum possible bit width)"
  419. )
  420. # To be replaced with integer-sqrt kernel that produces the same output
  421. std_int = floor_ste.apply(torch.sqrt(var_int)) * 2**self.shift
  422. factor = floor_ste.apply(2**31 / std_int)
  423. y_int = floor_ste.apply(y_int * factor / 2)
  424. scaling_factor = self.dim_sqrt / 2**30
  425. # scaling and shifting
  426. bias = self.bias.data.detach() / (self.weight.data.detach())
  427. bias_int = floor_ste.apply(bias / scaling_factor)
  428. y_int = y_int + bias_int
  429. scaling_factor = scaling_factor * self.weight
  430. x = y_int * scaling_factor
  431. return x, scaling_factor
  432. def get_percentile_min_max(input, lower_percentile, upper_percentile, output_tensor=False):
  433. """
  434. Calculate the percentile max and min values in a given tensor
  435. Args:
  436. input (`torch.Tensor`):
  437. The target tensor to calculate percentile max and min.
  438. lower_percentile (`float`):
  439. If 0.1, means we return the value of the smallest 0.1% value in the tensor as percentile min.
  440. upper_percentile (`float`):
  441. If 99.9, means we return the value of the largest 0.1% value in the tensor as percentile max.
  442. output_tensor (`bool`, *optional*, defaults to `False`):
  443. If True, this function returns tensors, otherwise it returns values.
  444. Returns:
  445. `Tuple(torch.Tensor, torch.Tensor)`: Percentile min and max value of *input*
  446. """
  447. input_length = input.shape[0]
  448. lower_index = round(input_length * (1 - lower_percentile * 0.01))
  449. upper_index = round(input_length * upper_percentile * 0.01)
  450. upper_bound = torch.kthvalue(input, k=upper_index).values
  451. if lower_percentile == 0:
  452. lower_bound = upper_bound * 0
  453. # lower_index += 1
  454. else:
  455. lower_bound = -torch.kthvalue(-input, k=lower_index).values
  456. if not output_tensor:
  457. lower_bound = lower_bound.item()
  458. upper_bound = upper_bound.item()
  459. return lower_bound, upper_bound
  460. def linear_quantize(input, scale, zero_point, inplace=False):
  461. """
  462. Quantize single-precision input tensor to integers with the given scaling factor and zeropoint.
  463. Args:
  464. input (`torch.Tensor`):
  465. Single-precision input tensor to be quantized.
  466. scale (`torch.Tensor`):
  467. Scaling factor for quantization.
  468. zero_pint (`torch.Tensor`):
  469. Shift for quantization.
  470. inplace (`bool`, *optional*, defaults to `False`):
  471. Whether to compute inplace or not.
  472. Returns:
  473. `torch.Tensor`: Linearly quantized value of *input* according to *scale* and *zero_point*.
  474. """
  475. # reshape scale and zeropoint for convolutional weights and activation
  476. if len(input.shape) == 4:
  477. scale = scale.view(-1, 1, 1, 1)
  478. zero_point = zero_point.view(-1, 1, 1, 1)
  479. # reshape scale and zeropoint for linear weights
  480. elif len(input.shape) == 2:
  481. scale = scale.view(-1, 1)
  482. zero_point = zero_point.view(-1, 1)
  483. else:
  484. scale = scale.view(-1)
  485. zero_point = zero_point.view(-1)
  486. # quantized = float / scale + zero_point
  487. if inplace:
  488. input.mul_(1.0 / scale).add_(zero_point).round_()
  489. return input
  490. return torch.round(1.0 / scale * input + zero_point)
  491. def symmetric_linear_quantization_params(num_bits, saturation_min, saturation_max, per_channel=False):
  492. """
  493. Compute the scaling factor with the given quantization range for symmetric quantization.
  494. Args:
  495. saturation_min (`torch.Tensor`):
  496. Lower bound for quantization range.
  497. saturation_max (`torch.Tensor`):
  498. Upper bound for quantization range.
  499. per_channel (`bool`, *optional*, defaults to `False`):
  500. Whether to or not use channel-wise quantization.
  501. Returns:
  502. `torch.Tensor`: Scaling factor that linearly quantizes the given range between *saturation_min* and
  503. *saturation_max*.
  504. """
  505. # in this part, we do not need any gradient computation,
  506. # in order to enforce this, we put torch.no_grad()
  507. with torch.no_grad():
  508. n = 2 ** (num_bits - 1) - 1
  509. if per_channel:
  510. scale, _ = torch.max(torch.stack([saturation_min.abs(), saturation_max.abs()], dim=1), dim=1)
  511. scale = torch.clamp(scale, min=1e-8) / n
  512. else:
  513. scale = max(saturation_min.abs(), saturation_max.abs())
  514. scale = torch.clamp(scale, min=1e-8) / n
  515. return scale
  516. class SymmetricQuantFunction(Function):
  517. """
  518. Class to quantize the given floating-point values using symmetric quantization with given range and bitwidth.
  519. """
  520. @staticmethod
  521. def forward(ctx, x, k, percentile_mode, scale):
  522. """
  523. Args:
  524. x (`torch.Tensor`):
  525. Floating point tensor to be quantized.
  526. k (`int`):
  527. Quantization bitwidth.
  528. percentile_mode (`bool`):
  529. Whether or not to use percentile calibration.
  530. scale (`torch.Tensor`):
  531. Pre-calculated scaling factor for *x*. Note that the current implementation of SymmetricQuantFunction
  532. requires pre-calculated scaling factor.
  533. Returns:
  534. `torch.Tensor`: Symmetric-quantized value of *input*.
  535. """
  536. zero_point = torch.tensor(0.0, device=scale.device)
  537. n = 2 ** (k - 1) - 1
  538. new_quant_x = linear_quantize(x, scale, zero_point, inplace=False)
  539. new_quant_x = torch.clamp(new_quant_x, -n, n - 1)
  540. ctx.scale = scale
  541. return new_quant_x
  542. @staticmethod
  543. def backward(ctx, grad_output):
  544. scale = ctx.scale
  545. if len(grad_output.shape) == 4:
  546. scale = scale.view(-1, 1, 1, 1)
  547. # reshape scale and zeropoint for linear weights
  548. elif len(grad_output.shape) == 2:
  549. scale = scale.view(-1, 1)
  550. else:
  551. scale = scale.view(-1)
  552. return grad_output.clone() / scale, None, None, None, None
  553. class floor_ste(Function):
  554. """
  555. Straight-through Estimator(STE) for torch.floor()
  556. """
  557. @staticmethod
  558. def forward(ctx, x):
  559. return torch.floor(x)
  560. @staticmethod
  561. def backward(ctx, grad_output):
  562. return grad_output.clone()
  563. class round_ste(Function):
  564. """
  565. Straight-through Estimator(STE) for torch.round()
  566. """
  567. @staticmethod
  568. def forward(ctx, x):
  569. return torch.round(x)
  570. @staticmethod
  571. def backward(ctx, grad_output):
  572. return grad_output.clone()
  573. def batch_frexp(inputs, max_bit=31):
  574. """
  575. Decompose the scaling factor into mantissa and twos exponent.
  576. Args:
  577. scaling_factor (`torch.Tensor`):
  578. Target scaling factor to decompose.
  579. Returns:
  580. ``Tuple(torch.Tensor, torch.Tensor)`: mantisa and exponent
  581. """
  582. shape_of_input = inputs.size()
  583. # trans the input to be a 1-d tensor
  584. inputs = inputs.view(-1)
  585. output_m, output_e = np.frexp(inputs.cpu().numpy())
  586. tmp_m = []
  587. for m in output_m:
  588. int_m_shifted = int(
  589. decimal.Decimal(m * (2**max_bit)).quantize(decimal.Decimal(1), rounding=decimal.ROUND_HALF_UP)
  590. )
  591. tmp_m.append(int_m_shifted)
  592. output_m = np.array(tmp_m)
  593. output_e = float(max_bit) - output_e
  594. return (
  595. torch.from_numpy(output_m).to(inputs.device).view(shape_of_input),
  596. torch.from_numpy(output_e).to(inputs.device).view(shape_of_input),
  597. )
  598. class FixedPointMul(Function):
  599. """
  600. Function to perform fixed-point arithmetic that can match integer arithmetic on hardware.
  601. Args:
  602. pre_act (`torch.Tensor`):
  603. Input tensor.
  604. pre_act_scaling_factor (`torch.Tensor`):
  605. Scaling factor of the input tensor *pre_act*.
  606. bit_num (`int`):
  607. Quantization bitwidth.
  608. z_scaling_factor (`torch.Tensor`):
  609. Scaling factor of the output tensor.
  610. identity (`torch.Tensor`, *optional*):
  611. Identity tensor, if exists.
  612. identity_scaling_factor (`torch.Tensor`, *optional*):
  613. Scaling factor of the identity tensor *identity*, if exists.
  614. Returns:
  615. `torch.Tensor`: Output tensor(*pre_act* if *identity* is not given, otherwise the addition of *pre_act* and
  616. *identity*), whose scale is rescaled to *z_scaling_factor*.
  617. """
  618. @staticmethod
  619. def forward(
  620. ctx,
  621. pre_act,
  622. pre_act_scaling_factor,
  623. bit_num,
  624. z_scaling_factor,
  625. identity=None,
  626. identity_scaling_factor=None,
  627. ):
  628. if len(pre_act_scaling_factor.shape) == 3:
  629. reshape = lambda x: x # noqa: E731
  630. else:
  631. reshape = lambda x: x.view(1, 1, -1) # noqa: E731
  632. ctx.identity = identity
  633. n = 2 ** (bit_num - 1) - 1
  634. with torch.no_grad():
  635. pre_act_scaling_factor = reshape(pre_act_scaling_factor)
  636. if identity is not None:
  637. identity_scaling_factor = reshape(identity_scaling_factor)
  638. ctx.z_scaling_factor = z_scaling_factor
  639. z_int = torch.round(pre_act / pre_act_scaling_factor)
  640. _A = pre_act_scaling_factor.type(torch.double)
  641. _B = (z_scaling_factor.type(torch.float)).type(torch.double)
  642. new_scale = _A / _B
  643. new_scale = reshape(new_scale)
  644. m, e = batch_frexp(new_scale)
  645. output = z_int.type(torch.double) * m.type(torch.double)
  646. output = torch.round(output / (2.0**e))
  647. if identity is not None:
  648. # needs addition of identity activation
  649. wx_int = torch.round(identity / identity_scaling_factor)
  650. _A = identity_scaling_factor.type(torch.double)
  651. _B = (z_scaling_factor.type(torch.float)).type(torch.double)
  652. new_scale = _A / _B
  653. new_scale = reshape(new_scale)
  654. m1, e1 = batch_frexp(new_scale)
  655. output1 = wx_int.type(torch.double) * m1.type(torch.double)
  656. output1 = torch.round(output1 / (2.0**e1))
  657. output = output1 + output
  658. return torch.clamp(output.type(torch.float), -n - 1, n)
  659. @staticmethod
  660. def backward(ctx, grad_output):
  661. identity_grad = None
  662. if ctx.identity is not None:
  663. identity_grad = grad_output.clone() / ctx.z_scaling_factor
  664. return grad_output.clone() / ctx.z_scaling_factor, None, None, None, None, identity_grad, None