hausdorff.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. from typing import Callable
  19. import torch
  20. from torch import nn
  21. from kornia.core import Module, Tensor, as_tensor, stack, tensor, where, zeros_like
  22. class _HausdorffERLossBase(Module):
  23. """Base class for binary Hausdorff loss based on morphological erosion.
  24. This is an Hausdorff Distance (HD) Loss that based on morphological erosion,which provided
  25. a differentiable approximation of Hausdorff distance as stated in :cite:`karimi2019reducing`.
  26. The code is refactored on top of `here <https://github.com/PatRyg99/HausdorffLoss/
  27. blob/master/hausdorff_loss.py>`__.
  28. Args:
  29. alpha: controls the erosion rate in each iteration.
  30. k: the number of iterations of erosion.
  31. reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
  32. 'none': no reduction will be applied, 'mean': the weighted mean of the output is taken,
  33. 'sum': the output will be summed.
  34. Returns:
  35. Estimated Hausdorff Loss.
  36. """
  37. conv: Callable[..., Tensor]
  38. max_pool: Callable[..., Tensor]
  39. def __init__(self, alpha: float = 2.0, k: int = 10, reduction: str = "mean") -> None:
  40. super().__init__()
  41. self.alpha = alpha
  42. self.k = k
  43. self.reduction = reduction
  44. self.register_buffer("kernel", self.get_kernel())
  45. def get_kernel(self) -> Tensor:
  46. """Get kernel for image morphology convolution."""
  47. raise NotImplementedError
  48. def perform_erosion(self, pred: Tensor, target: Tensor) -> Tensor:
  49. bound = (pred - target) ** 2
  50. kernel = as_tensor(self.kernel, device=pred.device, dtype=pred.dtype)
  51. eroded = zeros_like(bound, device=pred.device, dtype=pred.dtype)
  52. mask = torch.ones_like(bound, device=pred.device, dtype=torch.bool)
  53. # Same padding, assuming kernel is odd and square (cube) shaped.
  54. padding = (kernel.size(-1) - 1) // 2
  55. for k in range(self.k):
  56. # compute convolution with kernel
  57. dilation = self.conv(bound, weight=kernel, padding=padding, groups=1)
  58. # apply soft thresholding at 0.5 and normalize
  59. erosion = dilation - 0.5
  60. erosion[erosion < 0] = 0
  61. # image-wise differences for 2D images
  62. erosion_max = self.max_pool(erosion)
  63. erosion_min = -self.max_pool(-erosion)
  64. # No normalization needed if `max - min = 0`
  65. _to_norm = (erosion_max - erosion_min) != 0
  66. to_norm = _to_norm.squeeze()
  67. if to_norm.any():
  68. # NOTE: avoid in-place ops like below, which will not pass gradcheck:
  69. # erosion[to_norm] = (erosion[to_norm] - erosion_min[to_norm]) / (
  70. # erosion_max[to_norm] - erosion_min[to_norm])
  71. _erosion_to_fill = (erosion - erosion_min) / (erosion_max - erosion_min)
  72. erosion = where(mask * _to_norm, _erosion_to_fill, erosion)
  73. # save erosion and add to loss
  74. eroded = eroded + erosion * (k + 1) ** self.alpha
  75. bound = erosion
  76. return eroded
  77. def forward(self, pred: Tensor, target: Tensor) -> Tensor:
  78. """Compute Hausdorff loss.
  79. Args:
  80. pred: predicted tensor with a shape of :math:`(B, C, H, W)` or :math:`(B, C, D, H, W)`.
  81. Each channel is as binary as: 1 -> fg, 0 -> bg.
  82. target: target tensor with a shape of :math:`(B, 1, H, W)` or :math:`(B, C, D, H, W)`.
  83. Returns:
  84. Estimated Hausdorff Loss.
  85. """
  86. if not (pred.shape[2:] == target.shape[2:] and pred.size(0) == target.size(0) and target.size(1) == 1):
  87. raise ValueError(
  88. "Prediction and target need to be of same size, and target should not be one-hot."
  89. f"Got {pred.shape} and {target.shape}."
  90. )
  91. if pred.size(1) < target.max().item():
  92. raise ValueError("Invalid target value.")
  93. out = stack(
  94. [
  95. self.perform_erosion(
  96. pred[:, i : i + 1],
  97. where(
  98. target == i,
  99. tensor(1, device=target.device, dtype=target.dtype),
  100. tensor(0, device=target.device, dtype=target.dtype),
  101. ),
  102. )
  103. for i in range(pred.size(1))
  104. ]
  105. )
  106. if self.reduction == "mean":
  107. out = out.mean()
  108. elif self.reduction == "sum":
  109. out = out.sum()
  110. elif self.reduction == "none":
  111. pass
  112. else:
  113. raise NotImplementedError(f"reduction `{self.reduction}` has not been implemented yet.")
  114. return out
  115. class HausdorffERLoss(_HausdorffERLossBase):
  116. r"""Binary Hausdorff loss based on morphological erosion.
  117. Hausdorff Distance loss measures the maximum distance of a predicted segmentation boundary to
  118. the nearest ground-truth edge pixel. For two segmentation point sets X and Y ,
  119. the one-sided HD from X to Y is defined as:
  120. .. math::
  121. hd(X,Y) = \max_{x \in X} \min_{y \in Y}||x - y||_2
  122. Furthermore, the bidirectional HD is:
  123. .. math::
  124. HD(X,Y) = max(hd(X, Y), hd(Y, X))
  125. This is an Hausdorff Distance (HD) Loss that based on morphological erosion, which provided
  126. a differentiable approximation of Hausdorff distance as stated in :cite:`karimi2019reducing`.
  127. The code is refactored on top of `here <https://github.com/PatRyg99/HausdorffLoss/
  128. blob/master/hausdorff_loss.py>`__.
  129. Args:
  130. alpha: controls the erosion rate in each iteration.
  131. k: the number of iterations of erosion.
  132. reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
  133. 'none': no reduction will be applied, 'mean': the weighted mean of the output is taken,
  134. 'sum': the output will be summed.
  135. Examples:
  136. >>> hdloss = HausdorffERLoss()
  137. >>> input = torch.randn(5, 3, 20, 20)
  138. >>> target = (torch.rand(5, 1, 20, 20) * 2).long()
  139. >>> res = hdloss(input, target)
  140. """
  141. conv = torch.conv2d
  142. max_pool = nn.AdaptiveMaxPool2d(1)
  143. def get_kernel(self) -> Tensor:
  144. """Get kernel for image morphology convolution."""
  145. cross = tensor([[[0, 1, 0], [1, 1, 1], [0, 1, 0]]])
  146. kernel = cross * 0.2
  147. return kernel[None]
  148. def forward(self, pred: Tensor, target: Tensor) -> Tensor:
  149. """Compute Hausdorff loss.
  150. Args:
  151. pred: predicted tensor with a shape of :math:`(B, C, H, W)`.
  152. Each channel is as binary as: 1 -> fg, 0 -> bg.
  153. target: target tensor with a shape of :math:`(B, 1, H, W)`.
  154. Returns:
  155. Estimated Hausdorff Loss.
  156. """
  157. if pred.dim() != 4:
  158. raise ValueError(f"Only 2D images supported. Got {pred.dim()}.")
  159. if not (target.max() < pred.size(1) and target.min() >= 0 and target.dtype == torch.long):
  160. raise ValueError(
  161. f"Expect long type target value in range (0, {pred.size(1)}). ({target.min()}, {target.max()})"
  162. )
  163. return super().forward(pred, target)
  164. class HausdorffERLoss3D(_HausdorffERLossBase):
  165. r"""Binary 3D Hausdorff loss based on morphological erosion.
  166. Hausdorff Distance loss measures the maximum distance of a predicted segmentation boundary to
  167. the nearest ground-truth edge pixel. For two segmentation point sets X and Y ,
  168. the one-sided HD from X to Y is defined as:
  169. .. math::
  170. hd(X,Y) = \max_{x \in X} \min_{y \in Y}||x - y||_2
  171. Furthermore, the bidirectional HD is:
  172. .. math::
  173. HD(X,Y) = max(hd(X, Y), hd(Y, X))
  174. This is a 3D Hausdorff Distance (HD) Loss that based on morphological erosion, which provided
  175. a differentiable approximation of Hausdorff distance as stated in :cite:`karimi2019reducing`.
  176. The code is refactored on top of `here <https://github.com/PatRyg99/HausdorffLoss/
  177. blob/master/hausdorff_loss.py>`__.
  178. Args:
  179. alpha: controls the erosion rate in each iteration.
  180. k: the number of iterations of erosion.
  181. reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
  182. 'none': no reduction will be applied, 'mean': the weighted mean of the output is taken,
  183. 'sum': the output will be summed.
  184. Examples:
  185. >>> hdloss = HausdorffERLoss3D()
  186. >>> input = torch.randn(5, 3, 20, 20, 20)
  187. >>> target = (torch.rand(5, 1, 20, 20, 20) * 2).long()
  188. >>> res = hdloss(input, target)
  189. """
  190. conv = torch.conv3d
  191. max_pool = nn.AdaptiveMaxPool3d(1)
  192. def get_kernel(self) -> Tensor:
  193. """Get kernel for image morphology convolution."""
  194. cross = tensor([[[0, 1, 0], [1, 1, 1], [0, 1, 0]]])
  195. bound = tensor([[[0, 0, 0], [0, 1, 0], [0, 0, 0]]])
  196. # NOTE: The original repo claimed it shaped as (3, 1, 3, 3)
  197. # which Jian suspect it is wrongly implemented.
  198. # https://github.com/PatRyg99/HausdorffLoss/blob/9f580acd421af648e74b45d46555ccb7a876c27c/hausdorff_loss.py#L94
  199. kernel = stack([bound, cross, bound], 1) * (1 / 7)
  200. return kernel[None]
  201. def forward(self, pred: Tensor, target: Tensor) -> Tensor:
  202. """Compute 3D Hausdorff loss.
  203. Args:
  204. pred: predicted tensor with a shape of :math:`(B, C, D, H, W)`.
  205. Each channel is as binary as: 1 -> fg, 0 -> bg.
  206. target: target tensor with a shape of :math:`(B, 1, D, H, W)`.
  207. Returns:
  208. Estimated Hausdorff Loss.
  209. """
  210. if pred.dim() != 5:
  211. raise ValueError(f"Only 3D images supported. Got {pred.dim()}.")
  212. return super().forward(pred, target)