# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

from kornia import metrics
from kornia.core import Module, Tensor


def ssim3d_loss(
    img1: Tensor,
    img2: Tensor,
    window_size: int,
    max_val: float = 1.0,
    eps: float = 1e-12,
    reduction: str = "mean",
    padding: str = "same",
) -> Tensor:
    r"""Compute a loss based on the SSIM measurement.

    The loss, or the Structural dissimilarity (DSSIM) is described as:

    .. math::

      \text{loss}(x, y) = \frac{1 - \text{SSIM}(x, y)}{2}

    See :func:`~kornia.metrics.ssim3d` for details about SSIM.

    Args:
        img1: the first input image with shape :math:`(B, C, D, H, W)`.
        img2: the second input image with shape :math:`(B, C, D, H, W)`.
        window_size: the size of the gaussian kernel to smooth the images.
        max_val: the dynamic range of the images.
        eps: Small value for numerical stability when dividing.
        reduction : Specifies the reduction to apply to the
         output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
         ``'mean'``: the sum of the output will be divided by the number of elements
         in the output, ``'sum'``: the output will be summed.
        padding: ``'same'`` | ``'valid'``. Whether to only use the "valid" convolution
         area to compute SSIM to match the MATLAB implementation of original SSIM paper.

    Returns:
        The loss based on the ssim index. With ``reduction='none'`` this is the
        element-wise dissimilarity map (same shape as the SSIM map returned by
        :func:`~kornia.metrics.ssim3d`); otherwise it is reduced to a scalar tensor.

    Raises:
        NotImplementedError: if ``reduction`` is not one of ``'none'``, ``'mean'`` or ``'sum'``.

    Examples:
        >>> input1 = torch.rand(1, 4, 5, 5, 5)
        >>> input2 = torch.rand(1, 4, 5, 5, 5)
        >>> loss = ssim3d_loss(input1, input2, 5)

    """
    # compute the ssim map
    ssim_map: Tensor = metrics.ssim3d(img1, img2, window_size, max_val, eps, padding)

    # Structural dissimilarity as documented above: (1 - SSIM) / 2.
    # The halving matters: for SSIM values in [-1, 1] it keeps the loss in [0, 1]
    # instead of [0, 2].
    loss = (1.0 - ssim_map) / 2

    # reduce the loss
    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    elif reduction == "none":
        pass
    else:
        raise NotImplementedError("Invalid reduction option.")

    return loss


class SSIM3DLoss(Module):
    r"""Create a criterion that computes a loss based on the SSIM measurement.

    The loss, or the Structural dissimilarity (DSSIM) is described as:

    .. math::

      \text{loss}(x, y) = \frac{1 - \text{SSIM}(x, y)}{2}

    See :func:`~kornia.losses.ssim3d_loss` for the functional form this module wraps.

    Args:
        window_size: the size of the gaussian kernel to smooth the images.
        max_val: the dynamic range of the images.
        eps: Small value for numerical stability when dividing.
        reduction : Specifies the reduction to apply to the
         output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
         ``'mean'``: the sum of the output will be divided by the number of elements
         in the output, ``'sum'``: the output will be summed.
        padding: ``'same'`` | ``'valid'``. Whether to only use the "valid" convolution
         area to compute SSIM to match the MATLAB implementation of original SSIM paper.

    Returns:
        The loss based on the ssim index.

    Examples:
        >>> input1 = torch.rand(1, 4, 5, 5, 5)
        >>> input2 = torch.rand(1, 4, 5, 5, 5)
        >>> criterion = SSIM3DLoss(5)
        >>> loss = criterion(input1, input2)

    """

    def __init__(
        self, window_size: int, max_val: float = 1.0, eps: float = 1e-12, reduction: str = "mean", padding: str = "same"
    ) -> None:
        super().__init__()
        # Configuration is only stored here; it is forwarded unchanged to
        # ``ssim3d_loss`` on every call to ``forward``.
        self.window_size: int = window_size
        self.max_val: float = max_val
        self.eps: float = eps
        self.reduction: str = reduction
        self.padding: str = padding

    def forward(self, img1: Tensor, img2: Tensor) -> Tensor:
        """Compute the SSIM-based loss between two volumes.

        Args:
            img1: the first input image with shape :math:`(B, C, D, H, W)`.
            img2: the second input image with shape :math:`(B, C, D, H, W)`.

        Returns:
            The result of :func:`ssim3d_loss` evaluated with the settings given at construction.

        """
        return ssim3d_loss(img1, img2, self.window_size, self.max_val, self.eps, self.reduction, self.padding)