connected_components.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import torch
  18. import torch.nn.functional as F
  19. from kornia.core import Tensor
  20. def connected_components(image: Tensor, num_iterations: int = 100) -> Tensor:
  21. r"""Compute the Connected-component labelling (CCL) algorithm.
  22. .. image:: https://github.com/kornia/data/raw/main/cells_segmented.png
  23. The implementation is an adaptation of the following repository:
  24. https://gist.github.com/efirdc/5d8bd66859e574c683a504a4690ae8bc
  25. .. warning::
  26. This is an experimental API subject to changes and optimization improvements.
  27. .. note::
  28. See a working example `here <https://kornia.github.io/tutorials/nbs/connected_components.html>`__.
  29. Args:
  30. image: the binarized input image with shape :math:`(*, 1, H, W)`.
  31. The image must be in floating point with range [0, 1].
  32. num_iterations: the number of iterations to make the algorithm to converge.
  33. Return:
  34. The labels image with the same shape of the input image.
  35. Example:
  36. >>> img = torch.rand(2, 1, 4, 5)
  37. >>> img_labels = connected_components(img, num_iterations=100)
  38. """
  39. if not isinstance(image, Tensor):
  40. raise TypeError(f"Input imagetype is not a Tensor. Got: {type(image)}")
  41. if not isinstance(num_iterations, int) or num_iterations < 1:
  42. raise TypeError("Input num_iterations must be a positive integer.")
  43. if len(image.shape) < 3 or image.shape[-3] != 1:
  44. raise ValueError(f"Input image shape must be (*,1,H,W). Got: {image.shape}")
  45. H, W = image.shape[-2:]
  46. image_view = image.view(-1, 1, H, W)
  47. # precompute a mask with the valid values
  48. mask = image_view == 1
  49. # allocate the output tensors for labels
  50. B, _, _, _ = image_view.shape
  51. out = torch.arange(1, B * H * W + 1, device=image.device, dtype=image.dtype).view((-1, 1, H, W))
  52. out[~mask] = 0
  53. for _ in range(num_iterations):
  54. out = F.max_pool2d(out, kernel_size=3, stride=1, padding=1)
  55. out = torch.mul(out, mask) # mask using element-wise multiplication
  56. return out.view_as(image)