| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from __future__ import annotations
- from typing import Any, Optional
- from torch import nn
- import kornia
- from kornia.core import Module, tensor
- from kornia.core.external import segmentation_models_pytorch as smp
- from .base import SemanticSegmentation
- __all__ = ["SegmentationModelsBuilder"]
class SegmentationModelsBuilder:
    """Factory for semantic segmentation models backed by ``segmentation_models_pytorch``.

    The builder instantiates an smp architecture and pairs it with a kornia
    preprocessing pipeline derived from the encoder's pretrained settings.
    """

    @staticmethod
    def build(
        model_name: str = "Unet",
        encoder_name: str = "resnet34",
        encoder_weights: Optional[str] = "imagenet",
        in_channels: int = 3,
        classes: int = 1,
        activation: str = "softmax",
        **kwargs: Any,
    ) -> SemanticSegmentation:
        """Build a :class:`SemanticSegmentation` wrapper around an smp model.

        Args:
            model_name: Name of the architecture class to look up on ``segmentation_models_pytorch``,
                e.g. "Unet", "UnetPlusPlus", "MAnet", "Linknet", "FPN", "PSPNet", "PAN", "DeepLabV3",
                "DeepLabV3Plus". The name is resolved with ``getattr`` and is therefore case-sensitive.
            encoder_name: Name of the encoder to use.
            encoder_weights: Name of the pretrained encoder weights, or ``None`` for no pretrained weights.
                When given, the preprocessing parameters are looked up for these weights.
            in_channels: Number of channels in the input.
            classes: Number of classes to predict.
            activation: Type of activation layer.
            **kwargs: Additional arguments forwarded unchanged to the model constructor
                (for instance ``encoder_depth`` or ``decoder_channels`` on architectures that accept them).
                Detailed arguments can be found at:
                https://github.com/qubvel-org/segmentation_models.pytorch/tree/main/segmentation_models_pytorch/decoders

        Returns:
            A ``SemanticSegmentation`` instance named ``"{model_name}_{encoder_name}"`` whose pre-processor is
            the pipeline from :meth:`get_preprocessing_pipeline` and whose post-processor is ``nn.Identity``.

        Raises:
            AttributeError: If ``model_name`` is not an attribute of ``segmentation_models_pytorch``.
            ValueError: If the encoder's preprocessing settings use an unsupported input space or range.

        Note:
            Only encoder weights are available.
            Pretrained weights for the whole model are not available.
        """
        if encoder_weights is None:
            # No pretrained weights requested: keep the library's default preprocessing settings.
            preproc_params = smp.encoders.get_preprocessing_params(encoder_name)  # type: ignore
        else:
            # Use the statistics that belong to the requested weights rather than always the default set,
            # so the normalization matches what the encoder was trained with.
            preproc_params = smp.encoders.get_preprocessing_params(  # type: ignore
                encoder_name, pretrained=encoder_weights
            )
        preprocessor = SegmentationModelsBuilder.get_preprocessing_pipeline(preproc_params)

        model_cls = getattr(smp, model_name)
        segmentation_model = model_cls(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            activation=activation,
            **kwargs,
        )

        return SemanticSegmentation(
            model=segmentation_model,
            pre_processor=preprocessor,
            post_processor=nn.Identity(),
            name=f"{model_name}_{encoder_name}",
        )

    @staticmethod
    def get_preprocessing_pipeline(preproc_params: dict[str, Any]) -> kornia.augmentation.container.ImageSequential:
        """Translate encoder preprocessing parameters into a kornia module pipeline.

        Args:
            preproc_params: Mapping with the keys ``"input_space"`` ("RGB" or "BGR") and ``"input_range"``
                (a sequence whose second element is 1 or 255), plus optional ``"mean"`` and ``"std"``
                entries. A missing or ``None`` mean/std falls back to 0.0 / 1.0 respectively.

        Returns:
            An ``ImageSequential`` applying, in order: an optional channel flip, an optional rescale by 255,
            and a mean/std normalization.

        Raises:
            ValueError: If the input space is neither "RGB" nor "BGR", or the upper bound of the input
                range is neither 1 nor 255.
        """
        proc_sequence: list[Module] = []

        # Color space: expressed as a kornia module so the transformation stays ONNX-friendly.
        input_space = preproc_params["input_space"]
        if input_space == "BGR":
            proc_sequence.append(kornia.color.BgrToRgb())
        elif input_space != "RGB":
            raise ValueError(f"Unsupported input space: {input_space}")

        # Value range: Normalize computes (x - mean) / std, so std = 1 / 255 multiplies the input by 255.
        # NOTE(review): this presumably assumes incoming images are in [0, 1] - confirm against callers.
        input_range = preproc_params["input_range"]
        if input_range[1] == 255:
            proc_sequence.append(kornia.enhance.Normalize(mean=0.0, std=1 / 255.0))
        elif input_range[1] != 1:
            raise ValueError(f"Unsupported input range: {input_range}")

        # Statistics: wrap the per-channel values in an extra leading dimension; default to a no-op.
        raw_mean = preproc_params.get("mean")
        raw_std = preproc_params.get("std")
        mean = tensor(0.0) if raw_mean is None else tensor([raw_mean])
        std = tensor(1.0) if raw_std is None else tensor([raw_std])
        proc_sequence.append(kornia.enhance.Normalize(mean=mean, std=std))

        return kornia.augmentation.container.ImageSequential(*proc_sequence)
|