base.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. from typing import ClassVar, Optional, Union
  19. import torch
  20. import kornia
  21. from kornia.core import Tensor
  22. from kornia.core.external import PILImage as Image
  23. from kornia.models.base import ModelBase
  24. __all__ = ["SemanticSegmentation"]
  25. class SemanticSegmentation(ModelBase):
  26. """Semantic Segmentation is a module that wraps a semantic segmentation model.
  27. This module uses SegmentationModel library for semantic segmentation.
  28. """
  29. ONNX_DEFAULT_INPUTSHAPE: ClassVar[list[int]] = [-1, 3, -1, -1]
  30. ONNX_DEFAULT_OUTPUTSHAPE: ClassVar[list[int]] = [-1, -1, -1, -1]
  31. @torch.inference_mode()
  32. def forward(self, images: Union[Tensor, list[Tensor]]) -> Union[Tensor, list[Tensor]]:
  33. """Forward pass of the semantic segmentation model.
  34. Args:
  35. images: If list of RGB images. Each image is a Tensor with shape :math:`(3, H, W)`.
  36. If Tensor, a Tensor with shape :math:`(B, 3, H, W)`.
  37. Returns:
  38. output tensor.
  39. """
  40. outputs: Union[Tensor, list[Tensor]]
  41. if isinstance(
  42. images,
  43. (
  44. list,
  45. tuple,
  46. ),
  47. ):
  48. outputs = []
  49. for image in images:
  50. image = self.pre_processor(image[None])
  51. output = self.model(image)
  52. output = self.post_processor(output)
  53. outputs.append(output[0])
  54. else:
  55. images = self.pre_processor(images)
  56. outputs = self.model(images)
  57. outputs = self.post_processor(outputs)
  58. return outputs
  59. def get_colormap(self, num_classes: int, colormap: str = "random", manual_seed: int = 2147) -> Tensor:
  60. """Get a color map of size num_classes.
  61. Args:
  62. num_classes: The number of colors in the color map.
  63. colormap: The colormap to use, can be "random" or a custom color map.
  64. manual_seed: The manual seed to use for the colormap.
  65. Returns:
  66. A tensor of shape (num_classes, 3) representing the color map.
  67. """
  68. if colormap == "random":
  69. # Generate a color for each class
  70. g_cpu = torch.Generator()
  71. g_cpu.manual_seed(manual_seed)
  72. colors = torch.rand(num_classes, 3, generator=g_cpu)
  73. else:
  74. raise ValueError(f"Unsupported colormap: {colormap}")
  75. return colors
  76. def visualize_output(self, semantic_mask: Tensor, colors: Tensor) -> Tensor:
  77. """Visualize the output of the segmentation model.
  78. Args:
  79. semantic_mask: The output of the segmentation model. Shape should be (C, H, W) or (B, C, H, W).
  80. colors: The color map to use for visualizing the output of the segmentation model.
  81. Shape should be (num_classes, 3).
  82. Returns:
  83. A tensor of shape (3, H, W) or (B, 3, H, W) representing the visualized output of the segmentation model.
  84. Raises:
  85. ValueError: If the shape of the semantic mask is not of shape (C, H, W) or (B, C, H, W).
  86. ValueError: If the shape of the colors is not of shape (num_classes, 3).
  87. ValueError: If only muliclass segmentation is supported. Please ensure a softmax is used, or submit a PR.
  88. """
  89. if semantic_mask.dim() == 3:
  90. channel_dim = 0
  91. elif semantic_mask.dim() == 4:
  92. channel_dim = 1
  93. else:
  94. raise ValueError(f"Semantic mask must be of shape (C, H, W) or (B, C, H, W), got {semantic_mask.shape}.")
  95. if torch.allclose(
  96. semantic_mask.sum(dim=channel_dim), torch.tensor(1, dtype=semantic_mask.dtype, device=semantic_mask.device)
  97. ):
  98. # Softmax is used, thus, muliclass segmentation
  99. semantic_mask = semantic_mask.argmax(dim=channel_dim, keepdim=True)
  100. # Create a colormap for each pixel based on the class with the highest probability
  101. output = colors[semantic_mask.squeeze(channel_dim)]
  102. if semantic_mask.dim() == 3:
  103. output = output.permute(2, 0, 1)
  104. elif semantic_mask.dim() == 4:
  105. output = output.permute(0, 3, 1, 2)
  106. else:
  107. raise ValueError(
  108. f"Semantic mask must be of shape (C, H, W) or (B, C, H, W), got {semantic_mask.shape}."
  109. )
  110. else:
  111. raise ValueError(
  112. "Only muliclass segmentation is supported. Please ensure a softmax is used, or submit a PR."
  113. )
  114. return output
  115. def visualize(
  116. self,
  117. images: Union[Tensor, list[Tensor]],
  118. semantic_masks: Optional[Union[Tensor, list[Tensor]]] = None,
  119. output_type: str = "torch",
  120. colormap: str = "random",
  121. manual_seed: int = 2147,
  122. ) -> Union[Tensor, list[Tensor], list[Image.Image]]: # type: ignore
  123. """Visualize the segmentation masks.
  124. Args:
  125. images: If list of RGB images. Each image is a Tensor with shape :math:`(3, H, W)`.
  126. If Tensor, a Tensor with shape :math:`(B, 3, H, W)`.
  127. semantic_masks: If list of segmentation masks. Each mask is a Tensor with shape :math:`(C, H, W)`.
  128. If Tensor, a Tensor with shape :math:`(B, C, H, W)`.
  129. output_type: The type of output, can be "torch" or "PIL".
  130. colormap: The colormap to use, can be "random" or a custom color map.
  131. manual_seed: The manual seed to use for the colormap.
  132. """
  133. if semantic_masks is None:
  134. semantic_masks = self.forward(images)
  135. outputs: Union[Tensor, list[Tensor]]
  136. if isinstance(
  137. semantic_masks,
  138. (
  139. list,
  140. tuple,
  141. ),
  142. ):
  143. outputs = []
  144. for semantic_mask in semantic_masks:
  145. if semantic_mask.ndim != 3:
  146. raise ValueError(f"Semantic mask must be of shape (C, H, W), got {semantic_mask.shape}.")
  147. # Generate a color for each class
  148. colors = self.get_colormap(semantic_mask.size(0), colormap, manual_seed=manual_seed)
  149. outputs.append(self.visualize_output(semantic_mask, colors))
  150. else:
  151. # Generate a color for each class
  152. colors = self.get_colormap(semantic_masks.size(1), colormap, manual_seed=manual_seed)
  153. outputs = self.visualize_output(semantic_masks, colors)
  154. return self._tensor_to_type(outputs, output_type, is_batch=True if isinstance(outputs, Tensor) else False)
  155. def save(
  156. self,
  157. images: Union[Tensor, list[Tensor]],
  158. semantic_masks: Optional[Union[Tensor, list[Tensor]]] = None,
  159. directory: Optional[str] = None,
  160. output_type: str = "torch",
  161. colormap: str = "random",
  162. manual_seed: int = 2147,
  163. ) -> None:
  164. """Save the segmentation results.
  165. Args:
  166. images: If list of RGB images. Each image is a Tensor with shape :math:`(3, H, W)`.
  167. If Tensor, a Tensor with shape :math:`(B, 3, H, W)`.
  168. semantic_masks: If list of segmentation masks. Each mask is a Tensor with shape :math:`(C, H, W)`.
  169. If Tensor, a Tensor with shape :math:`(B, C, H, W)`.
  170. directory: The directory to save the results.
  171. output_type: The type of output, can be "torch" or "PIL".
  172. colormap: The colormap to use, can be "random" or a custom color map.
  173. manual_seed: The manual seed to use for the colormap.
  174. """
  175. colored_masks = self.visualize(images, semantic_masks, output_type, colormap=colormap, manual_seed=manual_seed)
  176. overlaid: Union[Tensor, list[Tensor]]
  177. if isinstance(images, Tensor) and isinstance(colored_masks, Tensor):
  178. overlaid = kornia.enhance.add_weighted(images, 0.5, colored_masks, 0.5, 1.0)
  179. elif isinstance(
  180. images,
  181. (
  182. list,
  183. tuple,
  184. ),
  185. ) and isinstance(
  186. colored_masks,
  187. (
  188. list,
  189. tuple,
  190. ),
  191. ):
  192. overlaid = []
  193. for i in range(len(images)):
  194. overlaid.append(kornia.enhance.add_weighted(images[i][None], 0.5, colored_masks[i][None], 0.5, 1.0)[0])
  195. else:
  196. raise ValueError(f"`images` should be a Tensor or a list of Tensors. Got {type(images)}")
  197. self._save_outputs(images, directory, suffix="_src")
  198. self._save_outputs(colored_masks, directory, suffix="_mask")
  199. self._save_outputs(overlaid, directory, suffix="_overlay")