| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from __future__ import annotations
- import math
- import torch
- import torch.nn.functional as F
- from kornia.color import rgb_to_ycbcr, ycbcr_to_rgb
- from kornia.constants import pi
- from kornia.core import Device, Dtype, Parameter, Tensor
- from kornia.core import ImageModule as Module
- from kornia.core.check import (
- KORNIA_CHECK,
- KORNIA_CHECK_IS_TENSOR,
- KORNIA_CHECK_SHAPE,
- )
- from kornia.geometry.transform.affwarp import rescale
- from kornia.utils.image import perform_keep_shape_image
- from kornia.utils.misc import (
- differentiable_clipping,
- differentiable_polynomial_floor,
- differentiable_polynomial_rounding,
- )
- __all__ = ["JPEGCodecDifferentiable", "jpeg_codec_differentiable"]
- def _get_default_qt_y(device: Device, dtype: Dtype) -> Tensor:
- """Generate default Quantization table of Y channel."""
- return torch.tensor(
- [
- [16, 11, 10, 16, 24, 40, 51, 61],
- [12, 12, 14, 19, 26, 58, 60, 55],
- [14, 13, 16, 24, 40, 57, 69, 56],
- [14, 17, 22, 29, 51, 87, 80, 62],
- [18, 22, 37, 56, 68, 109, 103, 77],
- [24, 35, 55, 64, 81, 104, 113, 92],
- [49, 64, 78, 87, 103, 121, 120, 101],
- [72, 92, 95, 98, 112, 100, 103, 99],
- ],
- device=device,
- dtype=dtype,
- )
- def _get_default_qt_c(device: Device, dtype: Dtype) -> Tensor:
- """Generate default Quantization table of C channels."""
- return torch.tensor(
- [
- [17, 18, 24, 47, 99, 99, 99, 99],
- [18, 21, 26, 66, 99, 99, 99, 99],
- [24, 26, 56, 99, 99, 99, 99, 99],
- [47, 66, 99, 99, 99, 99, 99, 99],
- [99, 99, 99, 99, 99, 99, 99, 99],
- [99, 99, 99, 99, 99, 99, 99, 99],
- [99, 99, 99, 99, 99, 99, 99, 99],
- [99, 99, 99, 99, 99, 99, 99, 99],
- ],
- device=device,
- dtype=dtype,
- )
- def _patchify_8x8(input: Tensor) -> Tensor:
- """Extract non-overlapping 8 x 8 patches from the given input image.
- Args:
- input (Tensor): Input image of the shape :math:`(B, H, W)`.
- Returns:
- output (Tensor): Image patchify of the shape :math:`(B, N, 8, 8)`.
- """
- # Get input shape
- B, H, W = input.shape
- # Patchify to shape [B, N, H // 8, W // 8]
- output: Tensor = input.view(B, H // 8, 8, W // 8, 8).permute(0, 1, 3, 2, 4).reshape(B, -1, 8, 8)
- return output
- def _unpatchify_8x8(input: Tensor, H: int, W: int) -> Tensor:
- """Reverse non-overlapping 8 x 8 patching.
- Args:
- input (Tensor): Input image of the shape :math:`(B, N, 8, 8)`.
- H: height of resulting tensor.
- W: width of resulting tensor.
- Returns:
- output (Tensor): Image patchify of the shape :math:`(B, H, W)`.
- """
- # Get input shape
- B, _N = input.shape[:2]
- # Unpatch to [B, H, W]
- output: Tensor = input.view(B, H // 8, W // 8, 8, 8).permute(0, 1, 3, 2, 4).reshape(B, H, W)
- return output
- def _dct_8x8(input: Tensor) -> Tensor:
- """Perform an 8 x 8 discrete cosine transform.
- Args:
- input (Tensor): Patched input tensor of the shape :math:`(B, N, 8, 8)`.
- Returns:
- output (Tensor): DCT output tensor of the shape :math:`(B, N, 8, 8)`.
- """
- # Get dtype and device
- dtype: Dtype = input.dtype
- device: Device = input.device
- # Make DCT tensor and scaling
- index: Tensor = torch.arange(8, dtype=dtype, device=device)
- x, y, u, v = torch.meshgrid(index, index, index, index)
- dct_tensor: Tensor = ((2.0 * x + 1.0) * u * pi / 16.0).cos() * ((2.0 * y + 1.0) * v * pi / 16.0).cos()
- alpha: Tensor = torch.ones(8, dtype=dtype, device=device)
- alpha[0] = 1.0 / (2**0.5)
- dct_scale: Tensor = torch.einsum("i, j -> ij", alpha, alpha) * 0.25
- # Apply DCT
- output: Tensor = dct_scale[None, None] * torch.tensordot(input - 128.0, dct_tensor)
- return output
- def _idct_8x8(input: Tensor) -> Tensor:
- """Perform an 8 x 8 discrete cosine transform.
- Args:
- input (Tensor): Patched input tensor of the shape :math:`(B, N, 8, 8)`.
- Returns:
- output (Tensor): DCT output tensor of the shape :math:`(B, N, 8, 8)`.
- """
- # Get dtype and device
- dtype: Dtype = input.dtype
- device: Device = input.device
- # Make and apply scaling
- alpha: Tensor = torch.ones(8, dtype=dtype, device=device)
- alpha[0] = 1.0 / (2**0.5)
- dct_scale: Tensor = torch.outer(alpha, alpha)
- input = input * dct_scale[None, None]
- # Make DCT tensor and scaling
- index: Tensor = torch.arange(8, dtype=dtype, device=device)
- x, y, u, v = torch.meshgrid(index, index, index, index)
- idct_tensor: Tensor = ((2.0 * u + 1.0) * x * pi / 16.0).cos() * ((2.0 * v + 1.0) * y * pi / 16.0).cos()
- # Apply DCT
- output: Tensor = 0.25 * torch.tensordot(input, idct_tensor, dims=2) + 128.0
- return output
def _jpeg_quality_to_scale(
    compression_strength: Tensor,
) -> Tensor:
    """Convert a given JPEG quality to the scaling factor.

    The mapping is ``5000 / q`` for ``q < 50`` and ``200 - 2 * q`` otherwise, followed by a differentiable floor.
    Non-positive qualities are treated as a quality of 1 in the ``5000 / q`` branch, so that a quality of
    exactly 0 does not divide by zero.

    Args:
        compression_strength (Tensor): Compression strength ranging from 0 to 100. Any shape is supported.

    Returns:
        scale (Tensor): Scaling factor to be applied to quantization matrix. Same shape as input.
    """
    # Guard the division: q <= 0 would otherwise produce an infinite (or negative) scale.
    # Positive qualities pass through unchanged, so results for q > 0 are unaffected.
    safe_strength = torch.where(
        compression_strength > 0,
        compression_strength,
        torch.ones_like(compression_strength),
    )
    # Get scale
    scale: Tensor = differentiable_polynomial_floor(
        torch.where(
            compression_strength < 50,
            5000.0 / safe_strength,
            200.0 - 2.0 * compression_strength,
        )
    )
    return scale
def _quantize(
    input: Tensor,
    jpeg_quality: Tensor,
    quantization_table: Tensor,
) -> Tensor:
    """Quantize DCT coefficients with a quality-scaled quantization table.

    Args:
        input (Tensor): Input tensor of the shape :math:`(B, N, 8, 8)`.
        jpeg_quality (Tensor): Compression strength to be applied, shape is :math:`(B)`.
        quantization_table (Tensor): Quantization table of the shape :math:`(1, 8, 8)` or :math:`(B, 8, 8)`.

    Returns:
        output (Tensor): Quantized output tensor of the shape :math:`(B, N, 8, 8)`.
    """
    # Per-sample scaling factor, broadcastable against [B, 1, 8, 8]
    scale = _jpeg_quality_to_scale(jpeg_quality)[:, None, None, None]
    scaled_table = quantization_table[:, None] * scale
    # Effective quantization steps: floor(clip((QT * scale + 50) / 100, 1, 255))
    steps = differentiable_polynomial_floor(differentiable_clipping((scaled_table + 50.0) / 100.0, 1, 255))
    # Divide by the steps and round (differentiable surrogate)
    return differentiable_polynomial_rounding(input / steps)
def _dequantize(
    input: Tensor,
    jpeg_quality: Tensor,
    quantization_table: Tensor,
) -> Tensor:
    """Dequantize coefficients with a quality-scaled quantization table.

    Args:
        input (Tensor): Quantized input tensor of the shape :math:`(B, N, 8, 8)`.
        jpeg_quality (Tensor): Compression strength to be applied, shape is :math:`(B)`.
        quantization_table (Tensor): Quantization table of the shape :math:`(1, 8, 8)` or :math:`(B, 8, 8)`.

    Returns:
        output (Tensor): Dequantized output tensor of the shape :math:`(B, N, 8, 8)`.
    """
    # Per-sample scaling factor, broadcastable against [B, 1, 8, 8]
    scale = _jpeg_quality_to_scale(jpeg_quality)[:, None, None, None]
    scaled_table = quantization_table[:, None] * scale
    # Effective quantization steps: floor(clip((QT * scale + 50) / 100, 1, 255))
    steps = differentiable_polynomial_floor(differentiable_clipping((scaled_table + 50.0) / 100.0, 1, 255))
    # Multiply the coefficients back by the steps
    return input * steps
def _chroma_subsampling(input_ycbcr: Tensor) -> tuple[Tensor, Tensor, Tensor]:
    """Split a YCbCr image into its components and downscale the two chroma components.

    Args:
        input_ycbcr (Tensor): YCbCr input tensor of the shape :math:`(B, 3, H, W)`.

    Returns:
        output_y (Tensor): Y component (not-subsampled), shape is :math:`(B, H, W)`.
        output_cb (Tensor): Cb component (subsampled), shape is :math:`(B, H // 2, W // 2)`.
        output_cr (Tensor): Cr component (subsampled), shape is :math:`(B, H // 2, W // 2)`.
    """

    def _halve(channel: Tensor) -> Tensor:
        # Anti-aliased bilinear rescale by 0.5; a channel axis is added for the call and dropped again
        halved = rescale(
            channel[:, None],
            factor=0.5,
            interpolation="bilinear",
            align_corners=False,
            antialias=True,
        )
        return halved[:, 0]

    luma = input_ycbcr[:, 0]
    chroma_blue = _halve(input_ycbcr[:, 1])
    chroma_red = _halve(input_ycbcr[:, 2])
    return luma, chroma_blue, chroma_red
def _chroma_upsampling(input_c: Tensor) -> Tensor:
    """Upscale a chroma component by a factor of two.

    Args:
        input_c (Tensor): Cb or Cr component to be upsampled of the shape :math:`(B, H, W)`.

    Returns:
        output_c (Tensor): Upsampled C(b or r) component of the shape :math:`(B, H * 2, W * 2)`.
    """
    # Bilinear rescale by 2 without anti-aliasing; a channel axis is added for the call
    with_channel = input_c[:, None]
    upscaled = rescale(
        with_channel,
        factor=2.0,
        interpolation="bilinear",
        align_corners=False,
        antialias=False,
    )
    # Drop the channel axis again
    return upscaled[:, 0]
def _jpeg_encode(
    image_rgb: Tensor,
    jpeg_quality: Tensor,
    quantization_table_y: Tensor,
    quantization_table_c: Tensor,
) -> tuple[Tensor, Tensor, Tensor]:
    """Run the encoder half of the codec: color conversion, chroma subsampling, DCT and quantization.

    Args:
        image_rgb (Tensor): RGB input images of the shape :math:`(B, 3, H, W)`.
        jpeg_quality (Tensor): Compression strength of the shape :math:`(B)`.
        quantization_table_y (Tensor): Quantization table for Y channel.
        quantization_table_c (Tensor): Quantization table for C channels.

    Returns:
        y_encoded (Tensor): Encoded Y component of the shape :math:`(B, N, 8, 8)`.
        cb_encoded (Tensor): Encoded Cb component of the shape :math:`(B, N, 8, 8)`.
        cr_encoded (Tensor): Encoded Cr component of the shape :math:`(B, N, 8, 8)`.
    """
    # RGB -> YCbCr, then rescale the pixel range by 255
    ycbcr = 255.0 * rgb_to_ycbcr(image_rgb)
    luma, chroma_blue, chroma_red = _chroma_subsampling(ycbcr)
    # Cut every component into 8 x 8 patches
    luma_patches = _patchify_8x8(luma)
    blue_patches = _patchify_8x8(chroma_blue)
    red_patches = _patchify_8x8(chroma_red)
    # Both chroma components share one DCT / quantization pass; they are stacked along the patch axis
    chroma_patches = torch.cat((blue_patches, red_patches), dim=1)
    luma_coefficients = _dct_8x8(luma_patches)
    chroma_coefficients = _dct_8x8(chroma_patches)
    y_encoded = _quantize(luma_coefficients, jpeg_quality, quantization_table_y)
    chroma_encoded = _quantize(chroma_coefficients, jpeg_quality, quantization_table_c)
    # Split the stacked chroma result back into Cb and Cr
    cb_encoded, cr_encoded = chroma_encoded.chunk(2, dim=1)
    return y_encoded, cb_encoded, cr_encoded
def _jpeg_decode(
    input_y: Tensor,
    input_cb: Tensor,
    input_cr: Tensor,
    jpeg_quality: Tensor,
    H: int,
    W: int,
    quantization_table_y: Tensor,
    quantization_table_c: Tensor,
) -> Tensor:
    """Run the decoder half of the codec: dequantization, inverse DCT, chroma upsampling and color conversion.

    Args:
        input_y (Tensor): Compressed Y component of the shape :math:`(B, N, 8, 8)`.
        input_cb (Tensor): Compressed Cb component of the shape :math:`(B, N, 8, 8)`.
        input_cr (Tensor): Compressed Cr component of the shape :math:`(B, N, 8, 8)`.
        jpeg_quality (Tensor): Compression strength of the shape :math:`(B)`.
        H (int): Original image height.
        W (int): Original image width.
        quantization_table_y (Tensor): Quantization table for Y channel.
        quantization_table_c (Tensor): Quantization table for C channels.

    Returns:
        rgb_decoded (Tensor): Decompressed RGB image of the shape :math:`(B, 3, H, W)`.
    """
    # Both chroma components share one dequantization / IDCT pass; they are stacked along the patch axis
    chroma_stacked = torch.cat((input_cb, input_cr), dim=1)
    luma_coefficients = _dequantize(input_y, jpeg_quality, quantization_table_y)
    chroma_coefficients = _dequantize(chroma_stacked, jpeg_quality, quantization_table_c)
    # Back to the pixel domain
    luma_patches = _idct_8x8(luma_coefficients)
    blue_patches, red_patches = _idct_8x8(chroma_coefficients).chunk(2, dim=1)
    # Reassemble the planes; the chroma planes are at half resolution until they are upsampled
    half_height, half_width = H // 2, W // 2
    plane_y = _unpatchify_8x8(luma_patches, H, W)
    plane_cb = _chroma_upsampling(_unpatchify_8x8(blue_patches, half_height, half_width))
    plane_cr = _chroma_upsampling(_unpatchify_8x8(red_patches, half_height, half_width))
    # Stack to [B, 3, H, W] and go back to the [0, 1] pixel range
    ycbcr = torch.stack((plane_y, plane_cb, plane_cr), dim=1) / 255.0
    return ycbcr_to_rgb(ycbcr)
- def _perform_padding(image: Tensor) -> tuple[Tensor, int, int]:
- """Pad a given image to be dividable by 16.
- Args:
- image: Image of the shape :math:`(*, 3, H, W)`.
- Returns:
- image_padded: Padded image of the shape :math:`(*, 3, H_{new}, W_{new})`.
- h_pad: Padded pixels along the horizontal axis.
- w_pad: Padded pixels along the vertical axis.
- """
- # Get spatial dimensions of the image
- H, W = image.shape[-2:]
- # Compute horizontal and vertical padding
- h_pad: int = math.ceil(H / 16) * 16 - H
- w_pad: int = math.ceil(W / 16) * 16 - W
- # Perform padding (we follow JPEG and pad only the bottom and right side of the image)
- image_padded: Tensor = F.pad(image, (0, w_pad, 0, h_pad), "replicate")
- return image_padded, h_pad, w_pad
@perform_keep_shape_image
def jpeg_codec_differentiable(
    image_rgb: Tensor,
    jpeg_quality: Tensor,
    quantization_table_y: Tensor | None = None,
    quantization_table_c: Tensor | None = None,
) -> Tensor:
    r"""Differentiable JPEG encoding-decoding module.

    Based on :cite:`reich2024` :cite:`shin2017`, we perform differentiable JPEG encoding-decoding as follows:

    .. image:: _static/img/jpeg_codec_differentiable.png

    .. math::

        \text{JPEG}_{\text{diff}}(I, q, QT_{y}, QT_{c}) = \hat{I}

    Where:
       - :math:`I` is the original image to be coded.
       - :math:`q` is the JPEG quality controlling the compression strength.
       - :math:`QT_{y}` is the luma quantization table.
       - :math:`QT_{c}` is the chroma quantization table.
       - :math:`\hat{I}` is the resulting JPEG encoded-decoded image.

    .. note::
        The input (and output) pixel range is :math:`[0, 1]`. In case you want to handle normalized images you are
        required to first perform denormalization followed by normalizing the output images again.

        Note, that this implementation models the encoding-decoding mapping of JPEG in a differentiable setting,
        however, does not allow access to the JPEG-coded byte file itself.
        For more details please refer to :cite:`reich2024`.

        This implementation is not meant for data loading. For loading JPEG images please refer to `kornia.io`.
        There we provide an optimized Rust implementation for fast JPEG loading.

    Args:
        image_rgb: the RGB image to be coded.
        jpeg_quality: JPEG quality in the range :math:`[0, 100]` controlling the compression strength.
        quantization_table_y: quantization table for Y channel. Default: `None`, which will load the standard
          quantization table.
        quantization_table_c: quantization table for C channels. Default: `None`, which will load the standard
          quantization table.

    Shape:
        - image_rgb: :math:`(*, 3, H, W)`.
        - jpeg_quality: :math:`(1)` or :math:`(B)` (if used batch dim. needs to match w/ image_rgb).
        - quantization_table_y: :math:`(8, 8)` or :math:`(B, 8, 8)` (if used batch dim. needs to match w/ image_rgb).
        - quantization_table_c: :math:`(8, 8)` or :math:`(B, 8, 8)` (if used batch dim. needs to match w/ image_rgb).

    Return:
        JPEG coded image of the shape :math:`(B, 3, H, W)`

    Example:
        To perform JPEG coding with the standard quantization tables just provide a JPEG quality

        >>> img = torch.rand(3, 3, 64, 64, requires_grad=True, dtype=torch.float)
        >>> jpeg_quality = torch.tensor((99.0, 25.0, 1.0), requires_grad=True)
        >>> img_jpeg = jpeg_codec_differentiable(img, jpeg_quality)
        >>> img_jpeg.sum().backward()

        You also have the option to provide custom quantization tables

        >>> img = torch.rand(3, 3, 64, 64, requires_grad=True, dtype=torch.float)
        >>> jpeg_quality = torch.tensor((99.0, 25.0, 1.0), requires_grad=True)
        >>> quantization_table_y = torch.randint(1, 256, size=(3, 8, 8), dtype=torch.float)
        >>> quantization_table_c = torch.randint(1, 256, size=(3, 8, 8), dtype=torch.float)
        >>> img_jpeg = jpeg_codec_differentiable(img, jpeg_quality, quantization_table_y, quantization_table_c)
        >>> img_jpeg.sum().backward()

        In case you want to control the quantization purely based on the quantization tables use a JPEG quality
        of 99.5. Setting the JPEG quality to 99.5 leads to a QT scaling of 1, see Eq. 2 of :cite:`reich2024`
        for details.

        >>> img = torch.rand(3, 3, 64, 64, requires_grad=True, dtype=torch.float)
        >>> jpeg_quality = torch.ones(3) * 99.5
        >>> quantization_table_y = torch.randint(1, 256, size=(3, 8, 8), dtype=torch.float)
        >>> quantization_table_c = torch.randint(1, 256, size=(3, 8, 8), dtype=torch.float)
        >>> img_jpeg = jpeg_codec_differentiable(img, jpeg_quality, quantization_table_y, quantization_table_c)
        >>> img_jpeg.sum().backward()

    """
    # Check that inputs are tensors
    KORNIA_CHECK_IS_TENSOR(image_rgb)
    KORNIA_CHECK_IS_TENSOR(jpeg_quality)
    # Get device and dtype
    dtype: Dtype = image_rgb.dtype
    device: Device = image_rgb.device
    # Use default QT if QT is not given
    if quantization_table_y is None:
        quantization_table_y = _get_default_qt_y(device, dtype)
    if quantization_table_c is None:
        quantization_table_c = _get_default_qt_c(device, dtype)
    KORNIA_CHECK_IS_TENSOR(quantization_table_y)
    KORNIA_CHECK_IS_TENSOR(quantization_table_c)
    # Check shape of inputs
    KORNIA_CHECK_SHAPE(image_rgb, ["*", "3", "H", "W"])
    KORNIA_CHECK_SHAPE(jpeg_quality, ["B"])
    # Add batch dimension to quantization tables if needed
    if quantization_table_y.ndim == 2:
        quantization_table_y = quantization_table_y.unsqueeze(dim=0)
    if quantization_table_c.ndim == 2:
        quantization_table_c = quantization_table_c.unsqueeze(dim=0)
    # Check resulting shape of quantization tables
    KORNIA_CHECK_SHAPE(quantization_table_y, ["B", "8", "8"])
    KORNIA_CHECK_SHAPE(quantization_table_c, ["B", "8", "8"])
    # Check value range of JPEG quality (the extrema are read back once and reused for the message)
    quality_min = jpeg_quality.amin().item()
    quality_max = jpeg_quality.amax().item()
    KORNIA_CHECK(
        (quality_min >= 0.0) and (quality_max <= 100.0),
        f"JPEG quality is out of range. Expected range is [0, 100], "
        f"got [{quality_min}, {quality_max}]. Consider clipping jpeg_quality.",
    )
    # Pad the image to a shape dividable by 16
    image_rgb, h_pad, w_pad = _perform_padding(image_rgb)
    # Get height and width of the padded image
    H, W = image_rgb.shape[-2:]
    # Check matching batch dimensions (a leading dimension of 1 is broadcast over the batch)
    if quantization_table_y.shape[0] != 1:
        KORNIA_CHECK(
            quantization_table_y.shape[0] == image_rgb.shape[0],
            f"Batch dimensions do not match. "
            f"Got {image_rgb.shape[0]} images and {quantization_table_y.shape[0]} quantization tables (Y).",
        )
    if quantization_table_c.shape[0] != 1:
        KORNIA_CHECK(
            quantization_table_c.shape[0] == image_rgb.shape[0],
            f"Batch dimensions do not match. "
            f"Got {image_rgb.shape[0]} images and {quantization_table_c.shape[0]} quantization tables (C).",
        )
    if jpeg_quality.shape[0] != 1:
        KORNIA_CHECK(
            jpeg_quality.shape[0] == image_rgb.shape[0],
            f"Batch dimensions do not match. "
            f"Got {image_rgb.shape[0]} images and {jpeg_quality.shape[0]} JPEG qualities.",
        )
    # Keep jpeg_quality on the same device and dtype as the input image
    jpeg_quality = jpeg_quality.to(device, dtype)
    # Quantization tables to same device and dtype as input image
    quantization_table_y = quantization_table_y.to(device, dtype)
    quantization_table_c = quantization_table_c.to(device, dtype)
    # Perform encoding
    y_encoded, cb_encoded, cr_encoded = _jpeg_encode(
        image_rgb=image_rgb,
        jpeg_quality=jpeg_quality,
        quantization_table_c=quantization_table_c,
        quantization_table_y=quantization_table_y,
    )
    # Perform decoding
    image_rgb_jpeg: Tensor = _jpeg_decode(
        input_y=y_encoded,
        input_cb=cb_encoded,
        input_cr=cr_encoded,
        jpeg_quality=jpeg_quality,
        H=H,
        W=W,
        quantization_table_c=quantization_table_c,
        quantization_table_y=quantization_table_y,
    )
    # Clip coded image. The decoder already maps back to the [0, 1] pixel range (it divides by 255),
    # so the upper bound has to be 1.0; a bound of 255.0 would never take effect.
    image_rgb_jpeg = differentiable_clipping(input=image_rgb_jpeg, min_val=0.0, max_val=1.0)
    # Crop the image again to the original shape
    image_rgb_jpeg = image_rgb_jpeg[..., : H - h_pad, : W - w_pad]
    return image_rgb_jpeg
class JPEGCodecDifferentiable(Module):
    r"""Differentiable JPEG encoding-decoding module.

    Based on :cite:`reich2024` :cite:`shin2017`, we perform differentiable JPEG encoding-decoding as follows:

    .. math::

        \text{JPEG}_{\text{diff}}(I, q, QT_{y}, QT_{c}) = \hat{I}

    Where:
       - :math:`I` is the original image to be coded.
       - :math:`q` is the JPEG quality controlling the compression strength.
       - :math:`QT_{y}` is the luma quantization table.
       - :math:`QT_{c}` is the chroma quantization table.
       - :math:`\hat{I}` is the resulting JPEG encoded-decoded image.

    .. image:: _static/img/jpeg_codec_differentiable.png

    .. note::
        The input (and output) pixel range is :math:`[0, 1]`. In case you want to handle normalized images you are
        required to first perform denormalization followed by normalizing the output images again.

        Note, that this implementation models the encoding-decoding mapping of JPEG in a differentiable setting,
        however, does not allow the excess of the JPEG-coded byte file itself.
        For more details please refer to :cite:`reich2024`.

        This implementation is not meant for data loading. For loading JPEG images please refer to `kornia.io`.
        There we provide an optimized Rust implementation for fast JPEG loading.

    Args:
        quantization_table_y: quantization table for Y channel. Default: `None`, which will load the standard
          quantization table.
        quantization_table_c: quantization table for C channels. Default: `None`, which will load the standard
          quantization table.

    Shape:
        - quantization_table_y: :math:`(8, 8)` or :math:`(B, 8, 8)` (if used batch dim. needs to match w/ image_rgb).
        - quantization_table_c: :math:`(8, 8)` or :math:`(B, 8, 8)` (if used batch dim. needs to match w/ image_rgb).
        - image_rgb: :math:`(*, 3, H, W)`.
        - jpeg_quality: :math:`(1)` or :math:`(B)` (if used batch dim. needs to match w/ image_rgb).

    Example:
        You can use the differentiable JPEG module with standard quantization tables by

        >>> diff_jpeg_module = JPEGCodecDifferentiable()
        >>> img = torch.rand(2, 3, 32, 32, requires_grad=True, dtype=torch.float)
        >>> jpeg_quality = torch.tensor((99.0, 1.0), requires_grad=True)
        >>> img_jpeg = diff_jpeg_module(img, jpeg_quality)
        >>> img_jpeg.sum().backward()

        You can also specify custom quantization tables to be used by

        >>> quantization_table_y = torch.randint(1, 256, size=(2, 8, 8), dtype=torch.float)
        >>> quantization_table_c = torch.randint(1, 256, size=(2, 8, 8), dtype=torch.float)
        >>> diff_jpeg_module = JPEGCodecDifferentiable(quantization_table_y, quantization_table_c)
        >>> img = torch.rand(2, 3, 32, 32, requires_grad=True, dtype=torch.float)
        >>> jpeg_quality = torch.tensor((99.0, 1.0), requires_grad=True)
        >>> img_jpeg = diff_jpeg_module(img, jpeg_quality)
        >>> img_jpeg.sum().backward()

        In case you want to learn the quantization tables just pass parameters `nn.Parameter`

        >>> quantization_table_y = torch.nn.Parameter(torch.randint(1, 256, size=(2, 8, 8), dtype=torch.float))
        >>> quantization_table_c = torch.nn.Parameter(torch.randint(1, 256, size=(2, 8, 8), dtype=torch.float))
        >>> diff_jpeg_module = JPEGCodecDifferentiable(quantization_table_y, quantization_table_c)
        >>> img = torch.rand(2, 3, 32, 32, requires_grad=True, dtype=torch.float)
        >>> jpeg_quality = torch.tensor((99.0, 1.0), requires_grad=True)
        >>> img_jpeg = diff_jpeg_module(img, jpeg_quality)
        >>> img_jpeg.sum().backward()

    """

    def __init__(
        self,
        quantization_table_y: Tensor | Parameter | None = None,
        quantization_table_c: Tensor | Parameter | None = None,
    ) -> None:
        super().__init__()
        # (attribute name, user-supplied table, factory for the default table)
        table_specs = (
            ("quantization_table_y", quantization_table_y, _get_default_qt_y),
            ("quantization_table_c", quantization_table_c, _get_default_qt_c),
        )
        for attribute, table, make_default in table_specs:
            if table is None:
                table = make_default(None, None)
            # Parameters are registered as parameters; any other tensor is registered as a buffer
            if isinstance(table, Parameter):
                self.register_parameter(attribute, table)
            else:
                self.register_buffer(attribute, table)

    def forward(
        self,
        image_rgb: Tensor,
        jpeg_quality: Tensor,
    ) -> Tensor:
        """Apply the differentiable JPEG encoding-decoding to ``image_rgb``.

        Args:
            image_rgb: the RGB image to be coded.
            jpeg_quality: JPEG quality controlling the compression strength.

        Returns:
            The result of :func:`jpeg_codec_differentiable` called with the tables held by this module.
        """
        target_device = image_rgb.device
        target_dtype = image_rgb.dtype
        # Use copies of the stored tables that match the device and dtype of the input
        table_y = self.quantization_table_y.to(target_device, target_dtype)
        table_c = self.quantization_table_c.to(target_device, target_dtype)
        return jpeg_codec_differentiable(
            image_rgb,
            jpeg_quality=jpeg_quality,
            quantization_table_c=table_c,
            quantization_table_y=table_y,
        )
|