functional.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782
  1. # mypy: allow-untyped-defs
  2. r"""Functional interface (quantized)."""
  3. import warnings
  4. import torch
  5. from torch import Tensor
  6. from torch.jit.annotations import BroadcastingList2
  7. from torch.nn.modules.utils import _pair, _triple
  8. from .modules.utils import _pair_from_first
  9. # Although some of the functions and docstrings are mirrored from the torch.nn,
  10. # we want to have them here for future changes.
# Public API of this module: the names exported by ``from ... import *``.
# Every entry corresponds to a function defined further down in this file.
__all__ = [
    "avg_pool2d",
    "avg_pool3d",
    "adaptive_avg_pool2d",
    "adaptive_avg_pool3d",
    "conv1d",
    "conv2d",
    "conv3d",
    "interpolate",
    "linear",
    "max_pool1d",
    "max_pool2d",
    "celu",
    "leaky_relu",
    "hardtanh",
    "hardswish",
    "threshold",
    "elu",
    "hardsigmoid",
    "clamp",
    "upsample",
    "upsample_bilinear",
    "upsample_nearest",
]
  35. def avg_pool2d(
  36. input,
  37. kernel_size,
  38. stride=None,
  39. padding=0,
  40. ceil_mode=False,
  41. count_include_pad=True,
  42. divisor_override=None,
  43. ):
  44. r"""
  45. Applies 2D average-pooling operation in :math:`kH \times kW` regions by step size
  46. :math:`sH \times sW` steps. The number of output features is equal to the number of
  47. input planes.
  48. .. note:: The input quantization parameters propagate to the output.
  49. See :class:`~torch.ao.nn.quantized.AvgPool2d` for details and output shape.
  50. Args:
  51. input: quantized input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
  52. kernel_size: size of the pooling region. Can be a single number or a
  53. tuple `(kH, kW)`
  54. stride: stride of the pooling operation. Can be a single number or a
  55. tuple `(sH, sW)`. Default: :attr:`kernel_size`
  56. padding: implicit zero paddings on both sides of the input. Can be a
  57. single number or a tuple `(padH, padW)`. Default: 0
  58. ceil_mode: when True, will use `ceil` instead of `floor` in the formula
  59. to compute the output shape. Default: ``False``
  60. count_include_pad: when True, will include the zero-padding in the
  61. averaging calculation. Default: ``True``
  62. divisor_override: if specified, it will be used as divisor, otherwise
  63. size of the pooling region will be used. Default: None
  64. """
  65. if not input.is_quantized:
  66. raise ValueError("Input to 'quantized.avg_pool2d' must be quantized!")
  67. return torch.nn.functional.avg_pool2d(
  68. input,
  69. kernel_size,
  70. stride,
  71. padding,
  72. ceil_mode,
  73. count_include_pad,
  74. divisor_override,
  75. )
  76. def avg_pool3d(
  77. input,
  78. kernel_size,
  79. stride=None,
  80. padding=0,
  81. ceil_mode=False,
  82. count_include_pad=True,
  83. divisor_override=None,
  84. ):
  85. r"""
  86. Applies 3D average-pooling operation in :math:`kD \ times kH \times kW` regions by step size
  87. :math:`sD \times sH \times sW` steps. The number of output features is equal to the number of
  88. input planes.
  89. .. note:: The input quantization parameters propagate to the output.
  90. Args:
  91. input: quantized input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
  92. kernel_size: size of the pooling region. Can be a single number or a
  93. tuple `(kD, kH, kW)`
  94. stride: stride of the pooling operation. Can be a single number or a
  95. tuple `(sD, sH, sW)`. Default: :attr:`kernel_size`
  96. padding: implicit zero paddings on both sides of the input. Can be a
  97. single number or a tuple `(padD, padH, padW)`. Default: 0
  98. ceil_mode: when True, will use `ceil` instead of `floor` in the formula
  99. to compute the output shape. Default: ``False``
  100. count_include_pad: when True, will include the zero-padding in the
  101. averaging calculation. Default: ``True``
  102. divisor_override: if specified, it will be used as divisor, otherwise
  103. size of the pooling region will be used. Default: None
  104. """
  105. if not input.is_quantized:
  106. raise ValueError("Input to 'quantized.avg_pool3d' must be quantized!")
  107. return torch.nn.functional.avg_pool3d(
  108. input,
  109. kernel_size,
  110. stride,
  111. padding,
  112. ceil_mode,
  113. count_include_pad,
  114. divisor_override,
  115. )
  116. def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
  117. r"""
  118. Applies a 2D adaptive average pooling over a quantized input signal composed
  119. of several quantized input planes.
  120. .. note:: The input quantization parameters propagate to the output.
  121. See :class:`~torch.ao.nn.quantized.AdaptiveAvgPool2d` for details and output shape.
  122. Args:
  123. output_size: the target output size (single integer or
  124. double-integer tuple)
  125. """
  126. if not input.is_quantized:
  127. raise ValueError(
  128. "Input to 'quantized.functional.adaptive_avg_pool2d' must be quantized!"
  129. )
  130. return torch.nn.functional.adaptive_avg_pool2d(input, output_size)
  131. def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
  132. r"""
  133. Applies a 3D adaptive average pooling over a quantized input signal composed
  134. of several quantized input planes.
  135. .. note:: The input quantization parameters propagate to the output.
  136. See :class:`~torch.ao.nn.quantized.AdaptiveAvgPool3d` for details and output shape.
  137. Args:
  138. output_size: the target output size (single integer or
  139. double-integer tuple)
  140. """
  141. if not input.is_quantized:
  142. raise ValueError(
  143. "Input to 'quantized.functional.adaptive_avg_pool3d' must be quantized!"
  144. )
  145. return torch.nn.functional.adaptive_avg_pool3d(input, output_size)
  146. def conv1d(
  147. input,
  148. weight,
  149. bias,
  150. stride=1,
  151. padding=0,
  152. dilation=1,
  153. groups=1,
  154. padding_mode="zeros",
  155. scale=1.0,
  156. zero_point=0,
  157. dtype=torch.quint8,
  158. ):
  159. r"""
  160. Applies a 1D convolution over a quantized 1D input composed of several input
  161. planes.
  162. See :class:`~torch.ao.nn.quantized.Conv1d` for details and output shape.
  163. Args:
  164. input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
  165. weight: quantized filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , iW)`
  166. bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.float`.
  167. stride: the stride of the convolving kernel. Can be a single number or a
  168. tuple `(sW,)`. Default: 1
  169. padding: implicit paddings on both sides of the input. Can be a
  170. single number or a tuple `(padW,)`. Default: 0
  171. dilation: the spacing between kernel elements. Can be a single number or
  172. a tuple `(dW,)`. Default: 1
  173. groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
  174. number of groups. Default: 1
  175. padding_mode: the padding mode to use. Only "zeros" is supported for quantized convolution at the moment. Default: "zeros"
  176. scale: quantization scale for the output. Default: 1.0
  177. zero_point: quantization zero_point for the output. Default: 0
  178. dtype: quantization data type to use. Default: ``torch.quint8``
  179. Examples::
  180. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
  181. >>> from torch.ao.nn.quantized import functional as qF
  182. >>> filters = torch.randn(33, 16, 3, dtype=torch.float)
  183. >>> inputs = torch.randn(20, 16, 50, dtype=torch.float)
  184. >>> bias = torch.randn(33, dtype=torch.float)
  185. >>>
  186. >>> scale, zero_point = 1.0, 0
  187. >>> dtype_inputs = torch.quint8
  188. >>> dtype_filters = torch.qint8
  189. >>>
  190. >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters)
  191. >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs)
  192. >>> qF.conv1d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)
  193. """ # noqa: E501
  194. if padding_mode != "zeros":
  195. raise NotImplementedError("Only zero-padding is supported!")
  196. if input.dtype != torch.quint8:
  197. raise NotImplementedError(
  198. "Only torch.quint8 is supported for activation tensor!"
  199. )
  200. if weight.dtype != torch.qint8:
  201. raise NotImplementedError("Only torch.qint8 is supported for weight tensor!")
  202. if input.ndim != 3:
  203. raise ValueError("Input shape must be `(N, C, L)`!")
  204. stride = _pair_from_first(stride)
  205. padding = _pair_from_first(padding)
  206. dilation = _pair_from_first(dilation)
  207. packed_params = torch.ops.quantized.conv1d_prepack(
  208. weight, bias, stride, padding, dilation, groups
  209. )
  210. return torch.ops.quantized.conv1d(input, packed_params, scale, zero_point)
  211. def conv2d(
  212. input,
  213. weight,
  214. bias,
  215. stride=1,
  216. padding=0,
  217. dilation=1,
  218. groups=1,
  219. padding_mode="zeros",
  220. scale=1.0,
  221. zero_point=0,
  222. dtype=torch.quint8,
  223. ):
  224. r"""
  225. Applies a 2D convolution over a quantized 2D input composed of several input
  226. planes.
  227. See :class:`~torch.ao.nn.quantized.Conv2d` for details and output shape.
  228. Args:
  229. input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
  230. weight: quantized filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)`
  231. bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.float`.
  232. stride: the stride of the convolving kernel. Can be a single number or a
  233. tuple `(sH, sW)`. Default: 1
  234. padding: implicit paddings on both sides of the input. Can be a
  235. single number or a tuple `(padH, padW)`. Default: 0
  236. dilation: the spacing between kernel elements. Can be a single number or
  237. a tuple `(dH, dW)`. Default: 1
  238. groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
  239. number of groups. Default: 1
  240. padding_mode: the padding mode to use. Only "zeros" is supported for quantized convolution at the moment. Default: "zeros"
  241. scale: quantization scale for the output. Default: 1.0
  242. zero_point: quantization zero_point for the output. Default: 0
  243. dtype: quantization data type to use. Default: ``torch.quint8``
  244. Examples::
  245. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
  246. >>> from torch.ao.nn.quantized import functional as qF
  247. >>> filters = torch.randn(8, 4, 3, 3, dtype=torch.float)
  248. >>> inputs = torch.randn(1, 4, 5, 5, dtype=torch.float)
  249. >>> bias = torch.randn(8, dtype=torch.float)
  250. >>>
  251. >>> scale, zero_point = 1.0, 0
  252. >>> dtype_inputs = torch.quint8
  253. >>> dtype_filters = torch.qint8
  254. >>>
  255. >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters)
  256. >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs)
  257. >>> qF.conv2d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)
  258. """ # noqa: E501
  259. if padding_mode != "zeros":
  260. raise NotImplementedError("Only zero-padding is supported!")
  261. if input.dtype != torch.quint8:
  262. raise NotImplementedError(
  263. "Only torch.quint8 is supported for activation tensor!"
  264. )
  265. if weight.dtype != torch.qint8:
  266. raise NotImplementedError("Only torch.qint8 is supported for weight tensor!")
  267. if input.ndim != 4:
  268. raise ValueError("Input shape must be `(N, C, H, W)`!")
  269. stride = _pair(stride)
  270. padding = _pair(padding)
  271. dilation = _pair(dilation)
  272. packed_params = torch.ops.quantized.conv2d_prepack(
  273. weight, bias, stride, padding, dilation, groups
  274. )
  275. return torch.ops.quantized.conv2d(input, packed_params, scale, zero_point)
  276. def conv3d(
  277. input,
  278. weight,
  279. bias,
  280. stride=1,
  281. padding=0,
  282. dilation=1,
  283. groups=1,
  284. padding_mode="zeros",
  285. scale=1.0,
  286. zero_point=0,
  287. dtype=torch.quint8,
  288. ):
  289. r"""
  290. Applies a 3D convolution over a quantized 3D input composed of several input
  291. planes.
  292. See :class:`~torch.ao.nn.quantized.Conv3d` for details and output shape.
  293. Args:
  294. input: quantized input tensor of shape
  295. :math:`(\text{minibatch} , \text{in\_channels} , iD , iH , iW)`
  296. weight: quantized filters of shape
  297. :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kD , kH , kW)`
  298. bias: **non-quantized** bias tensor of shape
  299. :math:`(\text{out\_channels})`. The tensor type must be `torch.float`.
  300. stride: the stride of the convolving kernel. Can be a single number or a
  301. tuple `(sD, sH, sW)`. Default: 1
  302. padding: implicit paddings on both sides of the input. Can be a
  303. single number or a tuple `(padD, padH, padW)`. Default: 0
  304. dilation: the spacing between kernel elements. Can be a single number or
  305. a tuple `(dD, dH, dW)`. Default: 1
  306. groups: split input into groups, :math:`\text{in\_channels}` should be
  307. divisible by the number of groups. Default: 1
  308. padding_mode: the padding mode to use. Only "zeros" is supported for
  309. quantized convolution at the moment. Default: "zeros"
  310. scale: quantization scale for the output. Default: 1.0
  311. zero_point: quantization zero_point for the output. Default: 0
  312. dtype: quantization data type to use. Default: ``torch.quint8``
  313. Examples::
  314. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
  315. >>> from torch.ao.nn.quantized import functional as qF
  316. >>> filters = torch.randn(8, 4, 3, 3, 3, dtype=torch.float)
  317. >>> inputs = torch.randn(1, 4, 5, 5, 5, dtype=torch.float)
  318. >>> bias = torch.randn(8, dtype=torch.float)
  319. >>>
  320. >>> scale, zero_point = 1.0, 0
  321. >>> dtype_inputs = torch.quint8
  322. >>> dtype_filters = torch.qint8
  323. >>>
  324. >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters)
  325. >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs)
  326. >>> qF.conv3d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)
  327. """ # noqa: E501
  328. if padding_mode != "zeros":
  329. raise NotImplementedError("Only zero-padding is supported!")
  330. if input.dtype != torch.quint8:
  331. raise NotImplementedError(
  332. "Only torch.quint8 is supported for activation tensor!"
  333. )
  334. if weight.dtype != torch.qint8:
  335. raise NotImplementedError("Only torch.qint8 is supported for weight tensor!")
  336. if input.ndim != 5:
  337. raise ValueError("Input shape must be `(N, C, D, H, W)`!")
  338. stride = _triple(stride)
  339. padding = _triple(padding)
  340. dilation = _triple(dilation)
  341. packed_params = torch.ops.quantized.conv3d_prepack(
  342. weight, bias, stride, padding, dilation, groups
  343. )
  344. return torch.ops.quantized.conv3d(input, packed_params, scale, zero_point)
  345. def interpolate(
  346. input, size=None, scale_factor=None, mode="nearest", align_corners=None
  347. ):
  348. r"""Down/up samples the input to either the given :attr:`size` or the given
  349. :attr:`scale_factor`
  350. See :func:`torch.nn.functional.interpolate` for implementation details.
  351. The input dimensions are interpreted in the form:
  352. `mini-batch x channels x [optional depth] x [optional height] x width`.
  353. .. note:: The input quantization parameters propagate to the output.
  354. .. note:: Only 2D/3D input is supported for quantized inputs
  355. .. note:: Only the following modes are supported for the quantized inputs:
  356. - `bilinear`
  357. - `nearest`
  358. Args:
  359. input (Tensor): the input tensor
  360. size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
  361. output spatial size.
  362. scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple.
  363. mode (str): algorithm used for upsampling:
  364. ``'nearest'`` | ``'bilinear'``
  365. align_corners (bool, optional): Geometrically, we consider the pixels of the
  366. input and output as squares rather than points.
  367. If set to ``True``, the input and output tensors are aligned by the
  368. center points of their corner pixels, preserving the values at the corner pixels.
  369. If set to ``False``, the input and output tensors are aligned by the corner
  370. points of their corner pixels, and the interpolation uses edge value padding
  371. for out-of-boundary values, making this operation *independent* of input size
  372. when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode`
  373. is ``'bilinear'``.
  374. Default: ``False``
  375. """
  376. if not input.is_quantized:
  377. raise ValueError("Input to 'quantized.interpolate' must be quantized!")
  378. return torch.nn.functional.interpolate(
  379. input, size, scale_factor, mode, align_corners
  380. )
  381. def linear(
  382. input: Tensor,
  383. weight: Tensor,
  384. bias: Tensor | None = None,
  385. scale: float | None = None,
  386. zero_point: int | None = None,
  387. ) -> Tensor:
  388. r"""
  389. Applies a linear transformation to the incoming quantized data:
  390. :math:`y = xA^T + b`.
  391. See :class:`~torch.ao.nn.quantized.Linear`
  392. .. note::
  393. Current implementation packs weights on every call, which has penalty on performance.
  394. If you want to avoid the overhead, use :class:`~torch.ao.nn.quantized.Linear`.
  395. Args:
  396. input (Tensor): Quantized input of type `torch.quint8`
  397. weight (Tensor): Quantized weight of type `torch.qint8`
  398. bias (Tensor): None or fp32 bias of type `torch.float`
  399. scale (double): output scale. If None, derived from the input scale
  400. zero_point (long): output zero point. If None, derived from the input zero_point
  401. Shape:
  402. - Input: :math:`(N, *, in\_features)` where `*` means any number of
  403. additional dimensions
  404. - Weight: :math:`(out\_features, in\_features)`
  405. - Bias: :math:`(out\_features)`
  406. - Output: :math:`(N, *, out\_features)`
  407. """
  408. if scale is None:
  409. scale = input.q_scale()
  410. if zero_point is None:
  411. zero_point = input.q_zero_point()
  412. _packed_params = torch.ops.quantized.linear_prepack(weight, bias)
  413. return torch.ops.quantized.linear(input, _packed_params, scale, zero_point)
  414. def max_pool1d(
  415. input,
  416. kernel_size,
  417. stride=None,
  418. padding=0,
  419. dilation=1,
  420. ceil_mode=False,
  421. return_indices=False,
  422. ):
  423. r"""Applies a 1D max pooling over a quantized input signal composed of
  424. several quantized input planes.
  425. .. note:: The input quantization parameters are propagated to the output.
  426. See :class:`~torch.ao.nn.quantized.MaxPool1d` for details.
  427. """
  428. if return_indices:
  429. raise NotImplementedError("return_indices is not yet implemented!")
  430. if stride is None:
  431. stride = torch.jit.annotate(list[int], [])
  432. return torch.nn.functional.max_pool1d(
  433. input,
  434. kernel_size,
  435. stride,
  436. padding,
  437. dilation,
  438. ceil_mode=ceil_mode,
  439. return_indices=return_indices,
  440. )
  441. def max_pool2d(
  442. input,
  443. kernel_size,
  444. stride=None,
  445. padding=0,
  446. dilation=1,
  447. ceil_mode=False,
  448. return_indices=False,
  449. ):
  450. r"""Applies a 2D max pooling over a quantized input signal composed of
  451. several quantized input planes.
  452. .. note:: The input quantization parameters are propagated to the output.
  453. See :class:`~torch.ao.nn.quantized.MaxPool2d` for details.
  454. """
  455. if return_indices:
  456. raise NotImplementedError("return_indices is not yet implemented!")
  457. if stride is None:
  458. stride = torch.jit.annotate(list[int], [])
  459. return torch.nn.functional.max_pool2d(
  460. input,
  461. kernel_size,
  462. stride,
  463. padding,
  464. dilation,
  465. ceil_mode=ceil_mode,
  466. return_indices=return_indices,
  467. )
  468. def celu(input: Tensor, scale: float, zero_point: int, alpha: float = 1.0) -> Tensor:
  469. r"""celu(input, scale, zero_point, alpha=1.) -> Tensor
  470. Applies the quantized CELU function element-wise.
  471. .. math::
  472. \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x / \alpha) - 1))
  473. Args:
  474. input: quantized input
  475. alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
  476. """
  477. if not input.is_quantized:
  478. raise ValueError("Input to 'quantized.celu' must be quantized!")
  479. return torch.ops.quantized.celu(input, scale, zero_point, alpha)
  480. def leaky_relu(
  481. input: Tensor,
  482. negative_slope: float = 0.01,
  483. inplace: bool = False,
  484. scale: float | None = None,
  485. zero_point: int | None = None,
  486. ):
  487. r"""
  488. Quantized version of the.
  489. leaky_relu(input, negative_slope=0.01, inplace=False, scale, zero_point) -> Tensor
  490. Applies element-wise,
  491. :math:`\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)`
  492. Args:
  493. input: Quantized input
  494. negative_slope: The slope of the negative input
  495. inplace: Inplace modification of the input tensor
  496. scale, zero_point: Scale and zero point of the output tensor.
  497. See :class:`~torch.nn.LeakyReLU` for more details.
  498. """
  499. if scale is not None and zero_point is not None:
  500. if inplace:
  501. raise AssertionError("Cannot rescale with `inplace`")
  502. output = torch._empty_affine_quantized(
  503. input.shape, scale=scale, zero_point=int(zero_point), dtype=input.dtype
  504. )
  505. torch._C._nn.leaky_relu(input, negative_slope, out=output)
  506. return output
  507. if inplace:
  508. result = torch._C._nn.leaky_relu_(input, negative_slope)
  509. else:
  510. result = torch._C._nn.leaky_relu(input, negative_slope)
  511. return result
  512. def hardtanh(
  513. input: Tensor, min_val: float = -1.0, max_val: float = 1.0, inplace: bool = False
  514. ) -> Tensor:
  515. r"""This is the quantized version of :func:`~torch.nn.functional.hardtanh`."""
  516. if not input.is_quantized:
  517. raise ValueError("Input to 'quantized.hardtanh' must be quantized!")
  518. if inplace:
  519. return torch._C._nn.hardtanh_(input, min_val, max_val)
  520. return torch._C._nn.hardtanh(input, min_val, max_val)
  521. def hardswish(input: Tensor, scale: float, zero_point: int) -> Tensor:
  522. r"""This is the quantized version of :func:`~torch.nn.functional.hardswish`.
  523. Args:
  524. input: quantized input
  525. scale: quantization scale of the output tensor
  526. zero_point: quantization zero point of the output tensor
  527. """
  528. if not input.is_quantized:
  529. raise ValueError("Input to 'quantized.hardswish' must be quantized!")
  530. return torch._ops.ops.quantized.hardswish(input, scale, zero_point)
  531. def threshold(input: Tensor, threshold: float, value: float) -> Tensor:
  532. r"""Applies the quantized version of the threshold function element-wise:
  533. .. math::
  534. x = \begin{cases}
  535. x & \text{if~} x > \text{threshold} \\
  536. \text{value} & \text{otherwise}
  537. \end{cases}
  538. See :class:`~torch.nn.Threshold` for more details.
  539. """
  540. if not input.is_quantized:
  541. raise ValueError("Input to 'quantized.threshold' must be quantized!")
  542. if threshold is None:
  543. raise ValueError("Input to 'threshold' must be specified!")
  544. if value is None:
  545. raise ValueError("Input to 'value' must be specified!")
  546. return torch._ops.ops.quantized.threshold(input, threshold, value)
  547. def elu(input: Tensor, scale: float, zero_point: int, alpha: float = 1.0) -> Tensor:
  548. r"""This is the quantized version of :func:`~torch.nn.functional.elu`.
  549. Args:
  550. input: quantized input
  551. scale: quantization scale of the output tensor
  552. zero_point: quantization zero point of the output tensor
  553. alpha: the alpha constant
  554. """
  555. if not input.is_quantized:
  556. raise ValueError("Input to 'quantized.elu' must be quantized!")
  557. return torch.ops.quantized.elu(input, scale, zero_point, alpha)
  558. def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor:
  559. r"""This is the quantized version of :func:`~torch.nn.functional.hardsigmoid`."""
  560. if not input.is_quantized:
  561. raise ValueError("Input to 'quantized.hardsigmoid' must be quantized!")
  562. if inplace:
  563. return torch._C._nn.hardsigmoid_(input) # type: ignore[attr-defined]
  564. return torch._C._nn.hardsigmoid(input)
  565. def clamp(input: Tensor, min_: float, max_: float) -> Tensor:
  566. r"""float(input, min\_, max\_) -> Tensor
  567. Applies the clamp function element-wise.
  568. See :class:`~torch.ao.nn.quantized.clamp` for more details.
  569. Args:
  570. input: quantized input
  571. min_: minimum value for clamping
  572. max_: maximum value for clamping
  573. """
  574. if not input.is_quantized:
  575. raise ValueError("Input to 'quantized.clamp' must be quantized!")
  576. return torch.clamp(input, min_, max_)
  577. def upsample(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
  578. r"""Upsamples the input to either the given :attr:`size` or the given
  579. :attr:`scale_factor`
  580. .. warning::
  581. This function is deprecated in favor of
  582. :func:`torch.ao.nn.quantized.functional.interpolate`.
  583. This is equivalent with ``nn.quantized.functional.interpolate(...)``.
  584. See :func:`torch.nn.functional.interpolate` for implementation details.
  585. The input dimensions are interpreted in the form:
  586. `mini-batch x channels x [optional depth] x [optional height] x width`.
  587. .. note:: The input quantization parameters propagate to the output.
  588. .. note:: Only 2D input is supported for quantized inputs
  589. .. note:: Only the following modes are supported for the quantized inputs:
  590. - `bilinear`
  591. - `nearest`
  592. Args:
  593. input (Tensor): quantized input tensor
  594. size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
  595. output spatial size.
  596. scale_factor (float or Tuple[float]): multiplier for spatial size. Has to be an integer.
  597. mode (str): algorithm used for upsampling:
  598. ``'nearest'`` | ``'bilinear'``
  599. align_corners (bool, optional): Geometrically, we consider the pixels of the
  600. input and output as squares rather than points.
  601. If set to ``True``, the input and output tensors are aligned by the
  602. center points of their corner pixels, preserving the values at the corner pixels.
  603. If set to ``False``, the input and output tensors are aligned by the corner
  604. points of their corner pixels, and the interpolation uses edge value padding
  605. for out-of-boundary values, making this operation *independent* of input size
  606. when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode`
  607. is ``'bilinear'``.
  608. Default: ``False``
  609. .. warning::
  610. With ``align_corners = True``, the linearly interpolating modes
  611. (`bilinear`) don't proportionally align the
  612. output and input pixels, and thus the output values can depend on the
  613. input size. This was the default behavior for these modes up to version
  614. 0.3.1. Since then, the default behavior is ``align_corners = False``.
  615. See :class:`~torch.nn.Upsample` for concrete examples on how this
  616. affects the outputs.
  617. """
  618. warnings.warn(
  619. "nn.quantized.functional.upsample is deprecated. Use nn.quantized.functional.interpolate instead.",
  620. stacklevel=2,
  621. )
  622. return interpolate(input, size, scale_factor, mode, align_corners)
  623. def upsample_bilinear(input, size=None, scale_factor=None):
  624. r"""Upsamples the input, using bilinear upsampling.
  625. .. warning::
  626. This function is deprecated in favor of
  627. :func:`torch.ao.nn.quantized.functional.interpolate`.
  628. This is equivalent with
  629. ``nn.quantized.functional.interpolate(..., mode='bilinear', align_corners=True)``.
  630. .. note:: The input quantization parameters propagate to the output.
  631. .. note:: Only 2D inputs are supported
  632. Args:
  633. input (Tensor): quantized input
  634. size (int or Tuple[int, int]): output spatial size.
  635. scale_factor (int or Tuple[int, int]): multiplier for spatial size
  636. """
  637. # DeprecationWarning is ignored by default
  638. warnings.warn(
  639. "nn.quantized.functional.upsample_bilinear is deprecated. Use nn.quantized.functional.interpolate instead.",
  640. stacklevel=2,
  641. )
  642. return interpolate(input, size, scale_factor, mode="bilinear", align_corners=True)
  643. def upsample_nearest(input, size=None, scale_factor=None):
  644. r"""Upsamples the input, using nearest neighbours' pixel values.
  645. .. warning::
  646. This function is deprecated in favor of
  647. :func:`torch.ao.nn.quantized.functional.interpolate`.
  648. This is equivalent with ``nn.quantized.functional.interpolate(..., mode='nearest')``.
  649. .. note:: The input quantization parameters propagate to the output.
  650. .. note:: Only 2D inputs are supported
  651. Args:
  652. input (Tensor): quantized input
  653. size (int or Tuple[int, int] or Tuple[int, int, int]): output spatial
  654. size.
  655. scale_factor (int): multiplier for spatial size. Has to be an integer.
  656. """
  657. # DeprecationWarning is ignored by default
  658. warnings.warn(
  659. "nn.quantized.functional.upsample_nearest is deprecated. Use nn.quantized.functional.interpolate instead.",
  660. stacklevel=2,
  661. )
  662. return interpolate(input, size, scale_factor, mode="nearest")