| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from __future__ import annotations
- from typing import Callable
- import torch
- from torch import nn
- from kornia.core import Module, Tensor, as_tensor, stack, tensor, where, zeros_like
class _HausdorffERLossBase(Module):
    """Base class for binary Hausdorff loss based on morphological erosion.

    This is a Hausdorff Distance (HD) loss based on morphological erosion, which provides
    a differentiable approximation of the Hausdorff distance as stated in :cite:`karimi2019reducing`.
    The code is refactored on top of `here <https://github.com/PatRyg99/HausdorffLoss/
    blob/master/hausdorff_loss.py>`__.

    Subclasses must provide ``conv`` (a functional convolution called as
    ``conv(x, weight=..., padding=..., groups=1)``), ``max_pool`` (a callable reducing each
    image to its maximum) and :meth:`get_kernel`.

    Args:
        alpha: controls the erosion rate in each iteration.
        k: the number of iterations of erosion.
        reduction: Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``.
            ``'none'``: no reduction will be applied, ``'mean'``: the mean of the output is taken,
            ``'sum'``: the output will be summed.

    Returns:
        Estimated Hausdorff Loss.
    """

    conv: Callable[..., Tensor]
    max_pool: Callable[..., Tensor]

    def __init__(self, alpha: float = 2.0, k: int = 10, reduction: str = "mean") -> None:
        super().__init__()
        self.alpha = alpha
        self.k = k
        self.reduction = reduction
        # Stored as a buffer so the kernel is part of the module's state.
        self.register_buffer("kernel", self.get_kernel())

    def get_kernel(self) -> Tensor:
        """Get kernel for image morphology convolution."""
        raise NotImplementedError

    def perform_erosion(self, pred: Tensor, target: Tensor) -> Tensor:
        """Accumulate the iteratively eroded squared error between ``pred`` and ``target``.

        Args:
            pred: single-channel prediction, e.g. :math:`(B, 1, H, W)` or :math:`(B, 1, D, H, W)`.
            target: binary target broadcastable against ``pred``.

        Returns:
            Tensor shaped like ``(pred - target) ** 2``: the sum over iterations ``k`` of the
            eroded error map weighted by ``(k + 1) ** alpha``.
        """
        bound = (pred - target) ** 2
        kernel = as_tensor(self.kernel, device=pred.device, dtype=pred.dtype)
        eroded = zeros_like(bound, device=pred.device, dtype=pred.dtype)
        # Same padding, assuming kernel is odd and square (cube) shaped.
        padding = (kernel.size(-1) - 1) // 2

        for k in range(self.k):
            # compute convolution with kernel
            dilation = self.conv(bound, weight=kernel, padding=padding, groups=1)
            # soft thresholding at 0.5; out-of-place equivalent of zeroing negative entries
            erosion = (dilation - 0.5).clamp(min=0)

            # per-image extrema over all spatial positions
            erosion_max = self.max_pool(erosion)
            erosion_min = -self.max_pool(-erosion)

            # No normalization needed if `max - min = 0`
            to_norm = (erosion_max - erosion_min) != 0
            if to_norm.any():
                # NOTE: avoid in-place ops like
                #   erosion[to_norm] = (erosion[to_norm] - erosion_min[to_norm]) / (...)
                # which will not pass gradcheck.
                # The denominator is replaced by 1 where no normalization happens: otherwise
                # 0 / 0 = NaN is computed in the branch `where` discards, and its backward
                # (grad / 0) leaks NaN into the gradient of `pred`.
                span = erosion_max - erosion_min
                safe_span = where(to_norm, span, torch.ones_like(span))
                erosion = where(to_norm, (erosion - erosion_min) / safe_span, erosion)

            # save erosion and add to loss
            eroded = eroded + erosion * (k + 1) ** self.alpha
            bound = erosion

        return eroded

    def forward(self, pred: Tensor, target: Tensor) -> Tensor:
        """Compute Hausdorff loss.

        Args:
            pred: predicted tensor with a shape of :math:`(B, C, H, W)` or :math:`(B, C, D, H, W)`.
                Each channel is as binary as: 1 -> fg, 0 -> bg.
            target: target tensor of class indices with a shape of :math:`(B, 1, H, W)` or
                :math:`(B, 1, D, H, W)`; values must lie in ``[0, C)``.

        Returns:
            Estimated Hausdorff Loss: a scalar for ``'mean'`` / ``'sum'``, or a tensor of shape
            :math:`(C, B, 1, *)` for ``'none'``.

        Raises:
            ValueError: if the shapes are incompatible or ``target`` holds a value outside ``[0, C)``.
            NotImplementedError: if ``reduction`` is not ``'none'``, ``'mean'`` or ``'sum'``.
        """
        if not (pred.shape[2:] == target.shape[2:] and pred.size(0) == target.size(0) and target.size(1) == 1):
            raise ValueError(
                "Prediction and target need to be of same size, and target should not be one-hot. "
                f"Got {pred.shape} and {target.shape}."
            )
        # Valid class indices are 0 .. C - 1. A label equal to C (or negative) would match no
        # channel below and be silently ignored.
        if target.min().item() < 0 or target.max().item() >= pred.size(1):
            raise ValueError("Invalid target value.")

        # One erosion map per class channel, against the binary mask of that class.
        out = stack(
            [
                self.perform_erosion(pred[:, i : i + 1], (target == i).to(target.dtype))
                for i in range(pred.size(1))
            ]
        )

        if self.reduction == "mean":
            out = out.mean()
        elif self.reduction == "sum":
            out = out.sum()
        elif self.reduction == "none":
            pass
        else:
            raise NotImplementedError(f"reduction `{self.reduction}` has not been implemented yet.")
        return out
class HausdorffERLoss(_HausdorffERLossBase):
    r"""Binary Hausdorff loss based on morphological erosion.

    The Hausdorff Distance (HD) measures how far a predicted segmentation boundary lies from
    the nearest ground-truth edge pixel. For two segmentation point sets :math:`X` and :math:`Y`,
    the one-sided HD from :math:`X` to :math:`Y` is

    .. math::

        hd(X,Y) = \max_{x \in X} \min_{y \in Y}||x - y||_2

    and the bidirectional HD is

    .. math::

        HD(X,Y) = \max(hd(X, Y), hd(Y, X))

    This loss is a differentiable approximation of the HD obtained through repeated
    morphological erosion, as described in :cite:`karimi2019reducing`. The implementation is
    refactored from `here <https://github.com/PatRyg99/HausdorffLoss/
    blob/master/hausdorff_loss.py>`__.

    Args:
        alpha: controls the erosion rate in each iteration.
        k: the number of iterations of erosion.
        reduction: Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``.
            ``'none'``: no reduction will be applied, ``'mean'``: the mean of the output is taken,
            ``'sum'``: the output will be summed.

    Examples:
        >>> hdloss = HausdorffERLoss()
        >>> input = torch.randn(5, 3, 20, 20)
        >>> target = (torch.rand(5, 1, 20, 20) * 2).long()
        >>> res = hdloss(input, target)
    """

    conv = torch.conv2d
    max_pool = nn.AdaptiveMaxPool2d(1)

    def get_kernel(self) -> Tensor:
        """Return the 2D cross-shaped structuring element, shaped :math:`(1, 1, 3, 3)`."""
        # Five non-zero taps of 0.2 each, so the kernel sums to one.
        cross = tensor([[0, 1, 0], [1, 1, 1], [0, 1, 0]])
        return (cross * 0.2)[None, None]

    def forward(self, pred: Tensor, target: Tensor) -> Tensor:
        """Compute the 2D Hausdorff loss.

        Args:
            pred: predicted tensor with a shape of :math:`(B, C, H, W)`.
                Each channel is as binary as: 1 -> fg, 0 -> bg.
            target: ``torch.long`` tensor of class indices with a shape of :math:`(B, 1, H, W)`;
                values must lie in ``[0, C)``.

        Returns:
            Estimated Hausdorff Loss.

        Raises:
            ValueError: if ``pred`` is not 4-dimensional, or ``target`` is not ``torch.long``
                or holds a value outside ``[0, C)``.
        """
        ndim = pred.dim()
        if ndim != 4:
            raise ValueError(f"Only 2D images supported. Got {ndim}.")

        num_classes = pred.size(1)
        in_range = target.max() < num_classes and target.min() >= 0
        if not (in_range and target.dtype == torch.long):
            raise ValueError(
                f"Expect long type target value in range (0, {num_classes}). ({target.min()}, {target.max()})"
            )

        return super().forward(pred, target)
class HausdorffERLoss3D(_HausdorffERLossBase):
    r"""Binary 3D Hausdorff loss based on morphological erosion.

    The Hausdorff Distance (HD) measures how far a predicted segmentation boundary lies from
    the nearest ground-truth edge voxel. For two segmentation point sets :math:`X` and :math:`Y`,
    the one-sided HD from :math:`X` to :math:`Y` is

    .. math::

        hd(X,Y) = \max_{x \in X} \min_{y \in Y}||x - y||_2

    and the bidirectional HD is

    .. math::

        HD(X,Y) = \max(hd(X, Y), hd(Y, X))

    This is the volumetric variant of the differentiable, erosion-based HD approximation
    described in :cite:`karimi2019reducing`. The implementation is refactored from
    `here <https://github.com/PatRyg99/HausdorffLoss/
    blob/master/hausdorff_loss.py>`__.

    Args:
        alpha: controls the erosion rate in each iteration.
        k: the number of iterations of erosion.
        reduction: Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``.
            ``'none'``: no reduction will be applied, ``'mean'``: the mean of the output is taken,
            ``'sum'``: the output will be summed.

    Examples:
        >>> hdloss = HausdorffERLoss3D()
        >>> input = torch.randn(5, 3, 20, 20, 20)
        >>> target = (torch.rand(5, 1, 20, 20, 20) * 2).long()
        >>> res = hdloss(input, target)
    """

    conv = torch.conv3d
    max_pool = nn.AdaptiveMaxPool3d(1)

    def get_kernel(self) -> Tensor:
        """Return the 3D cross-shaped structuring element, shaped :math:`(1, 1, 3, 3, 3)`."""
        plus = tensor([[[0, 1, 0], [1, 1, 1], [0, 1, 0]]])
        center_only = tensor([[[0, 0, 0], [0, 1, 0], [0, 0, 0]]])
        # Middle slice holds the 2D cross; the neighbouring slices hold only the centre tap.
        # Seven taps of 1/7 each, so the kernel sums to one.
        # NOTE: The original repo claimed it shaped as (3, 1, 3, 3)
        # which Jian suspect it is wrongly implemented.
        # https://github.com/PatRyg99/HausdorffLoss/blob/9f580acd421af648e74b45d46555ccb7a876c27c/hausdorff_loss.py#L94
        weights = stack([center_only, plus, center_only], 1) * (1 / 7)
        return weights[None]

    def forward(self, pred: Tensor, target: Tensor) -> Tensor:
        """Compute the 3D Hausdorff loss.

        Args:
            pred: predicted tensor with a shape of :math:`(B, C, D, H, W)`.
                Each channel is as binary as: 1 -> fg, 0 -> bg.
            target: target tensor of class indices with a shape of :math:`(B, 1, D, H, W)`.

        Returns:
            Estimated Hausdorff Loss.

        Raises:
            ValueError: if ``pred`` is not 5-dimensional.
        """
        if pred.dim() == 5:
            return super().forward(pred, target)
        raise ValueError(f"Only 3D images supported. Got {pred.dim()}.")
|