jpeg.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. import math
  19. import torch
  20. import torch.nn.functional as F
  21. from kornia.color import rgb_to_ycbcr, ycbcr_to_rgb
  22. from kornia.constants import pi
  23. from kornia.core import Device, Dtype, Parameter, Tensor
  24. from kornia.core import ImageModule as Module
  25. from kornia.core.check import (
  26. KORNIA_CHECK,
  27. KORNIA_CHECK_IS_TENSOR,
  28. KORNIA_CHECK_SHAPE,
  29. )
  30. from kornia.geometry.transform.affwarp import rescale
  31. from kornia.utils.image import perform_keep_shape_image
  32. from kornia.utils.misc import (
  33. differentiable_clipping,
  34. differentiable_polynomial_floor,
  35. differentiable_polynomial_rounding,
  36. )
  37. __all__ = ["JPEGCodecDifferentiable", "jpeg_codec_differentiable"]
  38. def _get_default_qt_y(device: Device, dtype: Dtype) -> Tensor:
  39. """Generate default Quantization table of Y channel."""
  40. return torch.tensor(
  41. [
  42. [16, 11, 10, 16, 24, 40, 51, 61],
  43. [12, 12, 14, 19, 26, 58, 60, 55],
  44. [14, 13, 16, 24, 40, 57, 69, 56],
  45. [14, 17, 22, 29, 51, 87, 80, 62],
  46. [18, 22, 37, 56, 68, 109, 103, 77],
  47. [24, 35, 55, 64, 81, 104, 113, 92],
  48. [49, 64, 78, 87, 103, 121, 120, 101],
  49. [72, 92, 95, 98, 112, 100, 103, 99],
  50. ],
  51. device=device,
  52. dtype=dtype,
  53. )
  54. def _get_default_qt_c(device: Device, dtype: Dtype) -> Tensor:
  55. """Generate default Quantization table of C channels."""
  56. return torch.tensor(
  57. [
  58. [17, 18, 24, 47, 99, 99, 99, 99],
  59. [18, 21, 26, 66, 99, 99, 99, 99],
  60. [24, 26, 56, 99, 99, 99, 99, 99],
  61. [47, 66, 99, 99, 99, 99, 99, 99],
  62. [99, 99, 99, 99, 99, 99, 99, 99],
  63. [99, 99, 99, 99, 99, 99, 99, 99],
  64. [99, 99, 99, 99, 99, 99, 99, 99],
  65. [99, 99, 99, 99, 99, 99, 99, 99],
  66. ],
  67. device=device,
  68. dtype=dtype,
  69. )
  70. def _patchify_8x8(input: Tensor) -> Tensor:
  71. """Extract non-overlapping 8 x 8 patches from the given input image.
  72. Args:
  73. input (Tensor): Input image of the shape :math:`(B, H, W)`.
  74. Returns:
  75. output (Tensor): Image patchify of the shape :math:`(B, N, 8, 8)`.
  76. """
  77. # Get input shape
  78. B, H, W = input.shape
  79. # Patchify to shape [B, N, H // 8, W // 8]
  80. output: Tensor = input.view(B, H // 8, 8, W // 8, 8).permute(0, 1, 3, 2, 4).reshape(B, -1, 8, 8)
  81. return output
  82. def _unpatchify_8x8(input: Tensor, H: int, W: int) -> Tensor:
  83. """Reverse non-overlapping 8 x 8 patching.
  84. Args:
  85. input (Tensor): Input image of the shape :math:`(B, N, 8, 8)`.
  86. H: height of resulting tensor.
  87. W: width of resulting tensor.
  88. Returns:
  89. output (Tensor): Image patchify of the shape :math:`(B, H, W)`.
  90. """
  91. # Get input shape
  92. B, _N = input.shape[:2]
  93. # Unpatch to [B, H, W]
  94. output: Tensor = input.view(B, H // 8, W // 8, 8, 8).permute(0, 1, 3, 2, 4).reshape(B, H, W)
  95. return output
  96. def _dct_8x8(input: Tensor) -> Tensor:
  97. """Perform an 8 x 8 discrete cosine transform.
  98. Args:
  99. input (Tensor): Patched input tensor of the shape :math:`(B, N, 8, 8)`.
  100. Returns:
  101. output (Tensor): DCT output tensor of the shape :math:`(B, N, 8, 8)`.
  102. """
  103. # Get dtype and device
  104. dtype: Dtype = input.dtype
  105. device: Device = input.device
  106. # Make DCT tensor and scaling
  107. index: Tensor = torch.arange(8, dtype=dtype, device=device)
  108. x, y, u, v = torch.meshgrid(index, index, index, index)
  109. dct_tensor: Tensor = ((2.0 * x + 1.0) * u * pi / 16.0).cos() * ((2.0 * y + 1.0) * v * pi / 16.0).cos()
  110. alpha: Tensor = torch.ones(8, dtype=dtype, device=device)
  111. alpha[0] = 1.0 / (2**0.5)
  112. dct_scale: Tensor = torch.einsum("i, j -> ij", alpha, alpha) * 0.25
  113. # Apply DCT
  114. output: Tensor = dct_scale[None, None] * torch.tensordot(input - 128.0, dct_tensor)
  115. return output
  116. def _idct_8x8(input: Tensor) -> Tensor:
  117. """Perform an 8 x 8 discrete cosine transform.
  118. Args:
  119. input (Tensor): Patched input tensor of the shape :math:`(B, N, 8, 8)`.
  120. Returns:
  121. output (Tensor): DCT output tensor of the shape :math:`(B, N, 8, 8)`.
  122. """
  123. # Get dtype and device
  124. dtype: Dtype = input.dtype
  125. device: Device = input.device
  126. # Make and apply scaling
  127. alpha: Tensor = torch.ones(8, dtype=dtype, device=device)
  128. alpha[0] = 1.0 / (2**0.5)
  129. dct_scale: Tensor = torch.outer(alpha, alpha)
  130. input = input * dct_scale[None, None]
  131. # Make DCT tensor and scaling
  132. index: Tensor = torch.arange(8, dtype=dtype, device=device)
  133. x, y, u, v = torch.meshgrid(index, index, index, index)
  134. idct_tensor: Tensor = ((2.0 * u + 1.0) * x * pi / 16.0).cos() * ((2.0 * v + 1.0) * y * pi / 16.0).cos()
  135. # Apply DCT
  136. output: Tensor = 0.25 * torch.tensordot(input, idct_tensor, dims=2) + 128.0
  137. return output
  138. def _jpeg_quality_to_scale(
  139. compression_strength: Tensor,
  140. ) -> Tensor:
  141. """Convert a given JPEG quality to the scaling factor.
  142. Args:
  143. compression_strength (Tensor): Compression strength ranging from 0 to 100. Any shape is supported.
  144. Returns:
  145. scale (Tensor): Scaling factor to be applied to quantization matrix. Same shape as input.
  146. """
  147. # Get scale
  148. scale: Tensor = differentiable_polynomial_floor(
  149. torch.where(
  150. compression_strength < 50,
  151. 5000.0 / compression_strength,
  152. 200.0 - 2.0 * compression_strength,
  153. )
  154. )
  155. return scale
  156. def _quantize(
  157. input: Tensor,
  158. jpeg_quality: Tensor,
  159. quantization_table: Tensor,
  160. ) -> Tensor:
  161. """Perform quantization.
  162. Args:
  163. input (Tensor): Input tensor of the shape :math:`(B, N, 8, 8)`.
  164. jpeg_quality (Tensor): Compression strength to be applied, shape is :math:`(B)`.
  165. quantization_table (Tensor): Quantization table of the shape :math:`(1, 8, 8)` or :math:`(B, 8, 8)`.
  166. Returns:
  167. output (Tensor): Quantized output tensor of the shape :math:`(B, N, 8, 8)`.
  168. """
  169. # Scale quantization table
  170. quantization_table_scaled: Tensor = (
  171. quantization_table[:, None] * _jpeg_quality_to_scale(jpeg_quality)[:, None, None, None]
  172. )
  173. # Perform scaling
  174. quantization_table = differentiable_polynomial_floor(
  175. differentiable_clipping((quantization_table_scaled + 50.0) / 100.0, 1, 255)
  176. )
  177. output: Tensor = input / quantization_table
  178. # Perform rounding
  179. output = differentiable_polynomial_rounding(output)
  180. return output
  181. def _dequantize(
  182. input: Tensor,
  183. jpeg_quality: Tensor,
  184. quantization_table: Tensor,
  185. ) -> Tensor:
  186. """Perform dequantization.
  187. Args:
  188. input (Tensor): Input tensor of the shape :math:`(B, N, 8, 8)`.
  189. jpeg_quality (Tensor): Compression strength to be applied, shape is :math:`(B)`.
  190. quantization_table (Tensor): Quantization table of the shape :math:`(1, 8, 8)` or :math:`(B, 8, 8)`.
  191. Returns:
  192. output (Tensor): Quantized output tensor of the shape :math:`(B, N, 8, 8)`.
  193. """
  194. # Scale quantization table
  195. quantization_table_scaled: Tensor = (
  196. quantization_table[:, None] * _jpeg_quality_to_scale(jpeg_quality)[:, None, None, None]
  197. )
  198. # Perform scaling
  199. output: Tensor = input * differentiable_polynomial_floor(
  200. differentiable_clipping((quantization_table_scaled + 50.0) / 100.0, 1, 255)
  201. )
  202. return output
  203. def _chroma_subsampling(input_ycbcr: Tensor) -> tuple[Tensor, Tensor, Tensor]:
  204. """Implement chroma subsampling.
  205. Args:
  206. input_ycbcr (Tensor): YCbCr input tensor of the shape :math:`(B, 3, H, W)`.
  207. Returns:
  208. output_y (Tensor): Y component (not-subsampled), shape is :math:`(B, H, W)`.
  209. output_cb (Tensor): Cb component (subsampled), shape is :math:`(B, H // 2, W // 2)`.
  210. output_cr (Tensor): Cr component (subsampled), shape is :math:`(B, H // 2, W // 2)`.
  211. """
  212. # Get components
  213. output_y: Tensor = input_ycbcr[:, 0]
  214. output_cb: Tensor = input_ycbcr[:, 1]
  215. output_cr: Tensor = input_ycbcr[:, 2]
  216. # Perform average pooling of Cb and Cr channels
  217. output_cb = rescale(
  218. output_cb[:, None],
  219. factor=0.5,
  220. interpolation="bilinear",
  221. align_corners=False,
  222. antialias=True,
  223. )
  224. output_cr = rescale(
  225. output_cr[:, None],
  226. factor=0.5,
  227. interpolation="bilinear",
  228. align_corners=False,
  229. antialias=True,
  230. )
  231. return output_y, output_cb[:, 0], output_cr[:, 0]
  232. def _chroma_upsampling(input_c: Tensor) -> Tensor:
  233. """Perform chroma upsampling.
  234. Args:
  235. input_c (Tensor): Cb or Cr component to be upsampled of the shape :math:`(B, H, W)`.
  236. Returns:
  237. output_c (Tensor): Upsampled C(b or r) component of the shape :math:`(B, H * 2, W * 2)`.
  238. """
  239. # Upsample component
  240. output_c: Tensor = rescale(
  241. input_c[:, None],
  242. factor=2.0,
  243. interpolation="bilinear",
  244. align_corners=False,
  245. antialias=False,
  246. )
  247. return output_c[:, 0]
  248. def _jpeg_encode(
  249. image_rgb: Tensor,
  250. jpeg_quality: Tensor,
  251. quantization_table_y: Tensor,
  252. quantization_table_c: Tensor,
  253. ) -> tuple[Tensor, Tensor, Tensor]:
  254. """Perform JPEG encoding.
  255. Args:
  256. image_rgb (Tensor): RGB input images of the shape :math:`(B, 3, H, W)`.
  257. jpeg_quality (Tensor): Compression strength of the shape :math:`(B)`.
  258. quantization_table_y (Tensor): Quantization table for Y channel.
  259. quantization_table_c (Tensor): Quantization table for C channels.
  260. Returns:
  261. y_encoded (Tensor): Encoded Y component of the shape :math:`(B, N, 8, 8)`.
  262. cb_encoded (Tensor): Encoded Cb component of the shape :math:`(B, N, 8, 8)`.
  263. cr_encoded (Tensor): Encoded Cr component of the shape :math:`(B, N, 8, 8)`.
  264. """
  265. # Convert RGB image to YCbCr.
  266. image_ycbcr: Tensor = rgb_to_ycbcr(image_rgb)
  267. # Scale pixel-range to [0, 255]
  268. image_ycbcr = 255.0 * image_ycbcr
  269. # Perform chroma subsampling
  270. input_y, input_cb, input_cr = _chroma_subsampling(image_ycbcr)
  271. # Patchify, DCT, and rounding
  272. input_y, input_cb, input_cr = (
  273. _patchify_8x8(input_y),
  274. _patchify_8x8(input_cb),
  275. _patchify_8x8(input_cr),
  276. )
  277. dct_y = _dct_8x8(input_y)
  278. dct_cb_cr = _dct_8x8(torch.cat((input_cb, input_cr), dim=1))
  279. y_encoded: Tensor = _quantize(
  280. dct_y,
  281. jpeg_quality,
  282. quantization_table_y,
  283. )
  284. cb_encoded, cr_encoded = _quantize(
  285. dct_cb_cr,
  286. jpeg_quality,
  287. quantization_table_c,
  288. ).chunk(2, dim=1)
  289. return y_encoded, cb_encoded, cr_encoded
  290. def _jpeg_decode(
  291. input_y: Tensor,
  292. input_cb: Tensor,
  293. input_cr: Tensor,
  294. jpeg_quality: Tensor,
  295. H: int,
  296. W: int,
  297. quantization_table_y: Tensor,
  298. quantization_table_c: Tensor,
  299. ) -> Tensor:
  300. """Perform JPEG decoding.
  301. Args:
  302. input_y (Tensor): Compressed Y component of the shape :math:`(B, N, 8, 8)`.
  303. input_cb (Tensor): Compressed Cb component of the shape :math:`(B, N, 8, 8)`.
  304. input_cr (Tensor): Compressed Cr component of the shape :math:`(B, N, 8, 8)`.
  305. jpeg_quality (Tensor): Compression strength of the shape :math:`(B)`.
  306. H (int): Original image height.
  307. W (int): Original image width.
  308. quantization_table_y (Tensor): Quantization table for Y channel.
  309. quantization_table_c (Tensor): Quantization table for C channels.
  310. Returns:
  311. rgb_decoded (Tensor): Decompressed RGB image of the shape :math:`(B, 3, H, W)`.
  312. """
  313. # Dequantize inputs
  314. input_y = _dequantize(
  315. input_y,
  316. jpeg_quality,
  317. quantization_table_y,
  318. )
  319. input_cb_cr = _dequantize(
  320. torch.cat((input_cb, input_cr), dim=1),
  321. jpeg_quality,
  322. quantization_table_c,
  323. )
  324. # Perform inverse DCT
  325. idct_y: Tensor = _idct_8x8(input_y)
  326. idct_cb, idct_cr = _idct_8x8(input_cb_cr).chunk(2, dim=1)
  327. # Reverse patching
  328. image_y: Tensor = _unpatchify_8x8(idct_y, H, W)
  329. image_cb: Tensor = _unpatchify_8x8(idct_cb, H // 2, W // 2)
  330. image_cr: Tensor = _unpatchify_8x8(idct_cr, H // 2, W // 2)
  331. # Perform chroma upsampling
  332. image_cb = _chroma_upsampling(image_cb)
  333. image_cr = _chroma_upsampling(image_cr)
  334. # Back to [0, 1] pixel-range
  335. image_ycbcr: Tensor = torch.stack((image_y, image_cb, image_cr), dim=1) / 255.0
  336. # Convert back to RGB space.
  337. rgb_decoded: Tensor = ycbcr_to_rgb(image_ycbcr)
  338. return rgb_decoded
  339. def _perform_padding(image: Tensor) -> tuple[Tensor, int, int]:
  340. """Pad a given image to be dividable by 16.
  341. Args:
  342. image: Image of the shape :math:`(*, 3, H, W)`.
  343. Returns:
  344. image_padded: Padded image of the shape :math:`(*, 3, H_{new}, W_{new})`.
  345. h_pad: Padded pixels along the horizontal axis.
  346. w_pad: Padded pixels along the vertical axis.
  347. """
  348. # Get spatial dimensions of the image
  349. H, W = image.shape[-2:]
  350. # Compute horizontal and vertical padding
  351. h_pad: int = math.ceil(H / 16) * 16 - H
  352. w_pad: int = math.ceil(W / 16) * 16 - W
  353. # Perform padding (we follow JPEG and pad only the bottom and right side of the image)
  354. image_padded: Tensor = F.pad(image, (0, w_pad, 0, h_pad), "replicate")
  355. return image_padded, h_pad, w_pad
  356. @perform_keep_shape_image
  357. def jpeg_codec_differentiable(
  358. image_rgb: Tensor,
  359. jpeg_quality: Tensor,
  360. quantization_table_y: Tensor | None = None,
  361. quantization_table_c: Tensor | None = None,
  362. ) -> Tensor:
  363. r"""Differentiable JPEG encoding-decoding module.
  364. Based on :cite:`reich2024` :cite:`shin2017`, we perform differentiable JPEG encoding-decoding as follows:
  365. .. image:: _static/img/jpeg_codec_differentiable.png
  366. .. math::
  367. \text{JPEG}_{\text{diff}}(I, q, QT_{y}, QT_{c}) = \hat{I}
  368. Where:
  369. - :math:`I` is the original image to be coded.
  370. - :math:`q` is the JPEG quality controlling the compression strength.
  371. - :math:`QT_{y}` is the luma quantization table.
  372. - :math:`QT_{c}` is the chroma quantization table.
  373. - :math:`\hat{I}` is the resulting JPEG encoded-decoded image.
  374. .. note:::
  375. The input (and output) pixel range is :math:`[0, 1]`. In case you want to handle normalized images you are
  376. required to first perform denormalization followed by normalizing the output images again.
  377. Note, that this implementation models the encoding-decoding mapping of JPEG in a differentiable setting,
  378. however, does not allow the excess of the JPEG-coded byte file itself.
  379. For more details please refer to :cite:`reich2024`.
  380. This implementation is not meant for data loading. For loading JPEG images please refer to `kornia.io`.
  381. There we provide an optimized Rust implementation for fast JPEG loading.
  382. Args:
  383. image_rgb: the RGB image to be coded.
  384. jpeg_quality: JPEG quality in the range :math:`[0, 100]` controlling the compression strength.
  385. quantization_table_y: quantization table for Y channel. Default: `None`, which will load the standard
  386. quantization table.
  387. quantization_table_c: quantization table for C channels. Default: `None`, which will load the standard
  388. quantization table.
  389. Shape:
  390. - image_rgb: :math:`(*, 3, H, W)`.
  391. - jpeg_quality: :math:`(1)` or :math:`(B)` (if used batch dim. needs to match w/ image_rgb).
  392. - quantization_table_y: :math:`(8, 8)` or :math:`(B, 8, 8)` (if used batch dim. needs to match w/ image_rgb).
  393. - quantization_table_c: :math:`(8, 8)` or :math:`(B, 8, 8)` (if used batch dim. needs to match w/ image_rgb).
  394. Return:
  395. JPEG coded image of the shape :math:`(B, 3, H, W)`
  396. Example:
  397. To perform JPEG coding with the standard quantization tables just provide a JPEG quality
  398. >>> img = torch.rand(3, 3, 64, 64, requires_grad=True, dtype=torch.float)
  399. >>> jpeg_quality = torch.tensor((99.0, 25.0, 1.0), requires_grad=True)
  400. >>> img_jpeg = jpeg_codec_differentiable(img, jpeg_quality)
  401. >>> img_jpeg.sum().backward()
  402. You also have the option to provide custom quantization tables
  403. >>> img = torch.rand(3, 3, 64, 64, requires_grad=True, dtype=torch.float)
  404. >>> jpeg_quality = torch.tensor((99.0, 25.0, 1.0), requires_grad=True)
  405. >>> quantization_table_y = torch.randint(1, 256, size=(3, 8, 8), dtype=torch.float)
  406. >>> quantization_table_c = torch.randint(1, 256, size=(3, 8, 8), dtype=torch.float)
  407. >>> img_jpeg = jpeg_codec_differentiable(img, jpeg_quality, quantization_table_y, quantization_table_c)
  408. >>> img_jpeg.sum().backward()
  409. In case you want to control the quantization purly base on the quantization tables use a JPEG quality of 99.5.
  410. Setting the JPEG quality to 99.5 leads to a QT scaling of 1, see Eq. 2 of :cite:`reich2024` for details.
  411. >>> img = torch.rand(3, 3, 64, 64, requires_grad=True, dtype=torch.float)
  412. >>> jpeg_quality = torch.ones(3) * 99.5
  413. >>> quantization_table_y = torch.randint(1, 256, size=(3, 8, 8), dtype=torch.float)
  414. >>> quantization_table_c = torch.randint(1, 256, size=(3, 8, 8), dtype=torch.float)
  415. >>> img_jpeg = jpeg_codec_differentiable(img, jpeg_quality, quantization_table_y, quantization_table_c)
  416. >>> img_jpeg.sum().backward()
  417. """
  418. # Check that inputs are tensors
  419. KORNIA_CHECK_IS_TENSOR(image_rgb)
  420. KORNIA_CHECK_IS_TENSOR(jpeg_quality)
  421. # Get device and dtype
  422. dtype: Dtype = image_rgb.dtype
  423. device: Device = image_rgb.device
  424. # Use default QT if QT is not given
  425. quantization_table_y = _get_default_qt_y(device, dtype) if quantization_table_y is None else quantization_table_y
  426. quantization_table_c = _get_default_qt_c(device, dtype) if quantization_table_c is None else quantization_table_c
  427. KORNIA_CHECK_IS_TENSOR(quantization_table_y)
  428. KORNIA_CHECK_IS_TENSOR(quantization_table_c)
  429. # Check shape of inputs
  430. KORNIA_CHECK_SHAPE(image_rgb, ["*", "3", "H", "W"])
  431. KORNIA_CHECK_SHAPE(jpeg_quality, ["B"])
  432. # Add batch dimension to quantization tables if needed
  433. if quantization_table_y.ndim == 2:
  434. quantization_table_y = quantization_table_y.unsqueeze(dim=0)
  435. if quantization_table_c.ndim == 2:
  436. quantization_table_c = quantization_table_c.unsqueeze(dim=0)
  437. # Check resulting shape of quantization tables
  438. KORNIA_CHECK_SHAPE(quantization_table_y, ["B", "8", "8"])
  439. KORNIA_CHECK_SHAPE(quantization_table_c, ["B", "8", "8"])
  440. # Check value range of JPEG quality
  441. KORNIA_CHECK(
  442. (jpeg_quality.amin().item() >= 0.0) and (jpeg_quality.amax().item() <= 100.0),
  443. f"JPEG quality is out of range. Expected range is [0, 100], "
  444. f"got [{jpeg_quality.amin().item()}, {jpeg_quality.amax().item()}]. Consider clipping jpeg_quality.",
  445. )
  446. # Pad the image to a shape dividable by 16
  447. image_rgb, h_pad, w_pad = _perform_padding(image_rgb)
  448. # Get height and shape
  449. H, W = image_rgb.shape[-2:]
  450. # Check matching batch dimensions
  451. if quantization_table_y.shape[0] != 1:
  452. KORNIA_CHECK(
  453. quantization_table_y.shape[0] == image_rgb.shape[0],
  454. f"Batch dimensions do not match. "
  455. f"Got {image_rgb.shape[0]} images and {quantization_table_y.shape[0]} quantization tables (Y).",
  456. )
  457. if quantization_table_c.shape[0] != 1:
  458. KORNIA_CHECK(
  459. quantization_table_c.shape[0] == image_rgb.shape[0],
  460. f"Batch dimensions do not match. "
  461. f"Got {image_rgb.shape[0]} images and {quantization_table_c.shape[0]} quantization tables (C).",
  462. )
  463. if jpeg_quality.shape[0] != 1:
  464. KORNIA_CHECK(
  465. jpeg_quality.shape[0] == image_rgb.shape[0],
  466. f"Batch dimensions do not match. "
  467. f"Got {image_rgb.shape[0]} images and {jpeg_quality.shape[0]} JPEG qualities.",
  468. )
  469. # keep jpeg_quality same device as input tensor
  470. jpeg_quality = jpeg_quality.to(device, dtype)
  471. # Quantization tables to same device and dtype as input image
  472. quantization_table_y = quantization_table_y.to(device, dtype)
  473. quantization_table_c = quantization_table_c.to(device, dtype)
  474. # Perform encoding
  475. y_encoded, cb_encoded, cr_encoded = _jpeg_encode(
  476. image_rgb=image_rgb,
  477. jpeg_quality=jpeg_quality,
  478. quantization_table_c=quantization_table_c,
  479. quantization_table_y=quantization_table_y,
  480. )
  481. image_rgb_jpeg: Tensor = _jpeg_decode(
  482. input_y=y_encoded,
  483. input_cb=cb_encoded,
  484. input_cr=cr_encoded,
  485. jpeg_quality=jpeg_quality,
  486. H=H,
  487. W=W,
  488. quantization_table_c=quantization_table_c,
  489. quantization_table_y=quantization_table_y,
  490. )
  491. # Clip coded image
  492. image_rgb_jpeg = differentiable_clipping(input=image_rgb_jpeg, min_val=0.0, max_val=255.0)
  493. # Crop the image again to the original shape
  494. image_rgb_jpeg = image_rgb_jpeg[..., : H - h_pad, : W - w_pad]
  495. return image_rgb_jpeg
  496. class JPEGCodecDifferentiable(Module):
  497. r"""Differentiable JPEG encoding-decoding module.
  498. Based on :cite:`reich2024` :cite:`shin2017`, we perform differentiable JPEG encoding-decoding as follows:
  499. .. math::
  500. \text{JPEG}_{\text{diff}}(I, q, QT_{y}, QT_{c}) = \hat{I}
  501. Where:
  502. - :math:`I` is the original image to be coded.
  503. - :math:`q` is the JPEG quality controlling the compression strength.
  504. - :math:`QT_{y}` is the luma quantization table.
  505. - :math:`QT_{c}` is the chroma quantization table.
  506. - :math:`\hat{I}` is the resulting JPEG encoded-decoded image.
  507. .. image:: _static/img/jpeg_codec_differentiable.png
  508. .. note::
  509. The input (and output) pixel range is :math:`[0, 1]`. In case you want to handle normalized images you are
  510. required to first perform denormalization followed by normalizing the output images again.
  511. Note, that this implementation models the encoding-decoding mapping of JPEG in a differentiable setting,
  512. however, does not allow the excess of the JPEG-coded byte file itself.
  513. For more details please refer to :cite:`reich2024`.
  514. This implementation is not meant for data loading. For loading JPEG images please refer to `kornia.io`.
  515. There we provide an optimized Rust implementation for fast JPEG loading.
  516. Args:
  517. quantization_table_y: quantization table for Y channel. Default: `None`, which will load the standard
  518. quantization table.
  519. quantization_table_c: quantization table for C channels. Default: `None`, which will load the standard
  520. quantization table.
  521. Shape:
  522. - quantization_table_y: :math:`(8, 8)` or :math:`(B, 8, 8)` (if used batch dim. needs to match w/ image_rgb).
  523. - quantization_table_c: :math:`(8, 8)` or :math:`(B, 8, 8)` (if used batch dim. needs to match w/ image_rgb).
  524. - image_rgb: :math:`(*, 3, H, W)`.
  525. - jpeg_quality: :math:`(1)` or :math:`(B)` (if used batch dim. needs to match w/ image_rgb).
  526. Example:
  527. You can use the differentiable JPEG module with standard quantization tables by
  528. >>> diff_jpeg_module = JPEGCodecDifferentiable()
  529. >>> img = torch.rand(2, 3, 32, 32, requires_grad=True, dtype=torch.float)
  530. >>> jpeg_quality = torch.tensor((99.0, 1.0), requires_grad=True)
  531. >>> img_jpeg = diff_jpeg_module(img, jpeg_quality)
  532. >>> img_jpeg.sum().backward()
  533. You can also specify custom quantization tables to be used by
  534. >>> quantization_table_y = torch.randint(1, 256, size=(2, 8, 8), dtype=torch.float)
  535. >>> quantization_table_c = torch.randint(1, 256, size=(2, 8, 8), dtype=torch.float)
  536. >>> diff_jpeg_module = JPEGCodecDifferentiable(quantization_table_y, quantization_table_c)
  537. >>> img = torch.rand(2, 3, 32, 32, requires_grad=True, dtype=torch.float)
  538. >>> jpeg_quality = torch.tensor((99.0, 1.0), requires_grad=True)
  539. >>> img_jpeg = diff_jpeg_module(img, jpeg_quality)
  540. >>> img_jpeg.sum().backward()
  541. In case you want to learn the quantization tables just pass parameters `nn.Parameter`
  542. >>> quantization_table_y = torch.nn.Parameter(torch.randint(1, 256, size=(2, 8, 8), dtype=torch.float))
  543. >>> quantization_table_c = torch.nn.Parameter(torch.randint(1, 256, size=(2, 8, 8), dtype=torch.float))
  544. >>> diff_jpeg_module = JPEGCodecDifferentiable(quantization_table_y, quantization_table_c)
  545. >>> img = torch.rand(2, 3, 32, 32, requires_grad=True, dtype=torch.float)
  546. >>> jpeg_quality = torch.tensor((99.0, 1.0), requires_grad=True)
  547. >>> img_jpeg = diff_jpeg_module(img, jpeg_quality)
  548. >>> img_jpeg.sum().backward()
  549. """
  550. def __init__(
  551. self,
  552. quantization_table_y: Tensor | Parameter | None = None,
  553. quantization_table_c: Tensor | Parameter | None = None,
  554. ) -> None:
  555. super().__init__()
  556. # Get default quantization tables if needed
  557. quantization_table_y = _get_default_qt_y(None, None) if quantization_table_y is None else quantization_table_y
  558. quantization_table_c = _get_default_qt_c(None, None) if quantization_table_c is None else quantization_table_c
  559. if isinstance(quantization_table_y, Parameter):
  560. self.register_parameter("quantization_table_y", quantization_table_y)
  561. else:
  562. self.register_buffer("quantization_table_y", quantization_table_y)
  563. if isinstance(quantization_table_c, Parameter):
  564. self.register_parameter("quantization_table_c", quantization_table_c)
  565. else:
  566. self.register_buffer("quantization_table_c", quantization_table_c)
  567. def forward(
  568. self,
  569. image_rgb: Tensor,
  570. jpeg_quality: Tensor,
  571. ) -> Tensor:
  572. device = image_rgb.device
  573. dtype = image_rgb.dtype
  574. # Move quantization tables to the same device and dtype as input
  575. # and store it in the local variables created in init
  576. quantization_table_y = self.quantization_table_y.to(device, dtype)
  577. quantization_table_c = self.quantization_table_c.to(device, dtype)
  578. # Perform encoding-decoding
  579. image_rgb_jpeg: Tensor = jpeg_codec_differentiable(
  580. image_rgb,
  581. jpeg_quality=jpeg_quality,
  582. quantization_table_c=quantization_table_c,
  583. quantization_table_y=quantization_table_y,
  584. )
  585. return image_rgb_jpeg