| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from __future__ import annotations
- from typing import ClassVar, Optional, Union
- import torch
- import kornia
- from kornia.core import Tensor
- from kornia.core.external import PILImage as Image
- from kornia.models.base import ModelBase
__all__ = ["SemanticSegmentation"]


class SemanticSegmentation(ModelBase):
    """Semantic Segmentation is a module that wraps a semantic segmentation model.

    This module uses SegmentationModel library for semantic segmentation. Inference runs the wrapped
    ``pre_processor``, ``model`` and ``post_processor`` in sequence. The remaining methods turn the
    resulting per-class masks into colored images and save them.
    """

    # Default ONNX input/output shapes. -1 presumably marks a dynamic dimension -- confirm against the
    # ONNX export logic in ModelBase.
    ONNX_DEFAULT_INPUTSHAPE: ClassVar[list[int]] = [-1, 3, -1, -1]
    ONNX_DEFAULT_OUTPUTSHAPE: ClassVar[list[int]] = [-1, -1, -1, -1]

    @torch.inference_mode()
    def forward(self, images: Union[Tensor, list[Tensor]]) -> Union[Tensor, list[Tensor]]:
        """Forward pass of the semantic segmentation model.

        Args:
            images: If list of RGB images. Each image is a Tensor with shape :math:`(3, H, W)`.
                If Tensor, a Tensor with shape :math:`(B, 3, H, W)`.

        Returns:
            For a list/tuple input, a list holding one post-processed output per image. The batch
            dimension added for the model call is removed again. For a Tensor input, the
            post-processed output of the whole batch.
        """
        if isinstance(images, (list, tuple)):
            outputs: list[Tensor] = []
            for image in images:
                # Each image runs as a batch of one (presumably because list entries may differ in size).
                output = self.post_processor(self.model(self.pre_processor(image[None])))
                outputs.append(output[0])
            return outputs

        return self.post_processor(self.model(self.pre_processor(images)))

    def get_colormap(self, num_classes: int, colormap: str = "random", manual_seed: int = 2147) -> Tensor:
        """Get a color map of size num_classes.

        Args:
            num_classes: The number of colors in the color map.
            colormap: The colormap to use. Only ``"random"`` is currently supported.
            manual_seed: Seed for the dedicated generator, so the same seed always yields the same colors.

        Returns:
            A tensor of shape (num_classes, 3) representing the color map. Values are drawn uniformly
            from [0, 1) on the CPU.

        Raises:
            ValueError: If ``colormap`` is not ``"random"``.
        """
        if colormap != "random":
            raise ValueError(f"Unsupported colormap: {colormap}")

        # A locally seeded generator keeps the colors reproducible without touching the global RNG state.
        generator = torch.Generator()
        generator.manual_seed(manual_seed)
        return torch.rand(num_classes, 3, generator=generator)

    def visualize_output(self, semantic_mask: Tensor, colors: Tensor) -> Tensor:
        """Visualize the output of the segmentation model.

        Each pixel is painted with the color of its highest-scoring class.

        Args:
            semantic_mask: The output of the segmentation model. Shape should be (C, H, W) or (B, C, H, W),
                with the C channel values of every pixel summing to 1 (e.g. a softmax output).
            colors: The color map to use for visualizing the output of the segmentation model.
                Shape should be (num_classes, 3) with num_classes >= C. It is moved to the device of
                ``semantic_mask`` before use.

        Returns:
            A tensor of shape (3, H, W) or (B, 3, H, W) representing the visualized output of the segmentation model.

        Raises:
            ValueError: If the semantic mask is not of shape (C, H, W) or (B, C, H, W).
            ValueError: If the colors are not of shape (num_classes, 3) with num_classes >= C.
            ValueError: If the channel values do not sum to 1, since only multiclass segmentation is supported.
        """
        if semantic_mask.dim() == 3:
            channel_dim = 0
        elif semantic_mask.dim() == 4:
            channel_dim = 1
        else:
            raise ValueError(f"Semantic mask must be of shape (C, H, W) or (B, C, H, W), got {semantic_mask.shape}.")

        num_classes = semantic_mask.size(channel_dim)
        if colors.dim() != 2 or colors.size(1) != 3 or colors.size(0) < num_classes:
            # Catch this here: an undersized palette would otherwise fail during indexing
            # (a device-side assert on accelerators).
            raise ValueError(
                f"Colors must be of shape (num_classes, 3) with num_classes >= {num_classes}, got {colors.shape}."
            )

        # Only softmax-style outputs (channels summing to 1 per pixel) are supported.
        one = torch.tensor(1, dtype=semantic_mask.dtype, device=semantic_mask.device)
        if not torch.allclose(semantic_mask.sum(dim=channel_dim), one):
            raise ValueError(
                "Only muliclass segmentation is supported. Please ensure a softmax is used, or submit a PR."
            )

        # Index of the most likely class per pixel: (H, W) or (B, H, W).
        labels = semantic_mask.argmax(dim=channel_dim)
        # get_colormap creates its palette on the CPU. Indexing it with labels on another device fails,
        # so move the palette next to the labels first.
        palette = colors.to(device=labels.device)
        # Gather colors -> (..., H, W, 3), then move the RGB axis in front of the spatial axes.
        return palette[labels].movedim(-1, -3)

    def visualize(
        self,
        images: Union[Tensor, list[Tensor]],
        semantic_masks: Optional[Union[Tensor, list[Tensor]]] = None,
        output_type: str = "torch",
        colormap: str = "random",
        manual_seed: int = 2147,
    ) -> Union[Tensor, list[Tensor], list[Image.Image]]:  # type: ignore
        """Visualize the segmentation masks.

        Args:
            images: If list of RGB images. Each image is a Tensor with shape :math:`(3, H, W)`.
                If Tensor, a Tensor with shape :math:`(B, 3, H, W)`. Only used when ``semantic_masks``
                is None, in which case the masks are predicted from it.
            semantic_masks: If list of segmentation masks. Each mask is a Tensor with shape :math:`(C, H, W)`.
                If Tensor, a Tensor with shape :math:`(B, C, H, W)` (a single :math:`(C, H, W)` mask is
                also accepted).
            output_type: The type of output, can be "torch" or "PIL".
            colormap: The colormap to use. Only ``"random"`` is currently supported.
            manual_seed: The manual seed to use for the colormap.

        Returns:
            The colored masks, converted by ``_tensor_to_type`` according to ``output_type``.

        Raises:
            ValueError: If a mask has an unsupported shape (see ``visualize_output``).
        """
        if semantic_masks is None:
            semantic_masks = self.forward(images)

        outputs: Union[Tensor, list[Tensor]]
        if isinstance(semantic_masks, (list, tuple)):
            outputs = []
            for semantic_mask in semantic_masks:
                if semantic_mask.ndim != 3:
                    raise ValueError(f"Semantic mask must be of shape (C, H, W), got {semantic_mask.shape}.")
                # Generate a color for each class of this mask.
                colors = self.get_colormap(semantic_mask.size(0), colormap, manual_seed=manual_seed)
                outputs.append(self.visualize_output(semantic_mask, colors))
        else:
            if semantic_masks.ndim not in (3, 4):
                raise ValueError(
                    f"Semantic mask must be of shape (C, H, W) or (B, C, H, W), got {semantic_masks.shape}."
                )
            # The class axis is third from the end for both (C, H, W) and (B, C, H, W).
            colors = self.get_colormap(semantic_masks.size(-3), colormap, manual_seed=manual_seed)
            outputs = self.visualize_output(semantic_masks, colors)

        return self._tensor_to_type(outputs, output_type, is_batch=isinstance(outputs, Tensor))

    def save(
        self,
        images: Union[Tensor, list[Tensor]],
        semantic_masks: Optional[Union[Tensor, list[Tensor]]] = None,
        directory: Optional[str] = None,
        output_type: str = "torch",
        colormap: str = "random",
        manual_seed: int = 2147,
    ) -> None:
        """Save the source images, the colored masks and their equal-weight overlays.

        Args:
            images: If list of RGB images. Each image is a Tensor with shape :math:`(3, H, W)`.
                If Tensor, a Tensor with shape :math:`(B, 3, H, W)`.
            semantic_masks: If list of segmentation masks. Each mask is a Tensor with shape :math:`(C, H, W)`.
                If Tensor, a Tensor with shape :math:`(B, C, H, W)`. If None, they are predicted from ``images``.
            directory: The directory to save the results.
            output_type: Kept for backward compatibility and ignored. The colored masks are always built
                as tensors here, because they are blended with ``images`` before saving.
            colormap: The colormap to use. Only ``"random"`` is currently supported.
            manual_seed: The manual seed to use for the colormap.

        Raises:
            ValueError: If ``images`` and the colored masks are not both Tensors or both lists/tuples,
                or if a list of images and a list of masks differ in length.
        """
        # Blending below needs tensors, so request tensor masks regardless of ``output_type``.
        colored_masks = self.visualize(images, semantic_masks, "torch", colormap=colormap, manual_seed=manual_seed)

        # Equal-weight blend with no offset. Both inputs are assumed to lie in [0, 1] (the colormap
        # values do), so the overlay stays in that range; a non-zero gamma would shift every pixel.
        overlaid: Union[Tensor, list[Tensor]]
        if isinstance(images, Tensor) and isinstance(colored_masks, Tensor):
            overlaid = kornia.enhance.add_weighted(images, 0.5, colored_masks, 0.5, 0.0)
        elif isinstance(images, (list, tuple)) and isinstance(colored_masks, (list, tuple)):
            if len(images) != len(colored_masks):
                raise ValueError(f"Got {len(images)} images but {len(colored_masks)} semantic masks.")
            overlaid = [
                kornia.enhance.add_weighted(image[None], 0.5, mask[None], 0.5, 0.0)[0]
                for image, mask in zip(images, colored_masks)
            ]
        else:
            raise ValueError(f"`images` should be a Tensor or a list of Tensors. Got {type(images)}")

        self._save_outputs(images, directory, suffix="_src")
        self._save_outputs(colored_masks, directory, suffix="_mask")
        self._save_outputs(overlaid, directory, suffix="_overlay")
|