segmentation_models.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. from typing import Any, Optional
  19. from torch import nn
  20. import kornia
  21. from kornia.core import Module, tensor
  22. from kornia.core.external import segmentation_models_pytorch as smp
  23. from .base import SemanticSegmentation
  24. __all__ = ["SegmentationModelsBuilder"]
  25. class SegmentationModelsBuilder:
  26. @staticmethod
  27. def build(
  28. model_name: str = "Unet",
  29. encoder_name: str = "resnet34",
  30. encoder_weights: Optional[str] = "imagenet",
  31. in_channels: int = 3,
  32. classes: int = 1,
  33. activation: str = "softmax",
  34. **kwargs: Any,
  35. ) -> SemanticSegmentation:
  36. """SegmentationModel is a module that wraps a segmentation model.
  37. This module uses SegmentationModel library for segmentation.
  38. Args:
  39. model_name: Name of the model to use. Valid options are:
  40. "Unet", "UnetPlusPlus", "MAnet", "LinkNet", "FPN", "PSPNet", "PAN", "DeepLabV3", "DeepLabV3Plus".
  41. encoder_name: Name of the encoder to use.
  42. encoder_depth: Depth of the encoder.
  43. encoder_weights: Weights of the encoder.
  44. decoder_channels: Number of channels in the decoder.
  45. in_channels: Number of channels in the input.
  46. classes: Number of classes to predict.
  47. activation: Type of activation layer.
  48. **kwargs: Additional arguments to pass to the model. Detailed arguments can be found at:
  49. https://github.com/qubvel-org/segmentation_models.pytorch/tree/main/segmentation_models_pytorch/decoders
  50. Note:
  51. Only encoder weights are available.
  52. Pretrained weights for the whole model are not available.
  53. """
  54. preproc_params = smp.encoders.get_preprocessing_params(encoder_name) # type: ignore
  55. preprocessor = SegmentationModelsBuilder.get_preprocessing_pipeline(preproc_params)
  56. segmentation_model = getattr(smp, model_name)(
  57. encoder_name=encoder_name,
  58. encoder_weights=encoder_weights,
  59. in_channels=in_channels,
  60. classes=classes,
  61. activation=activation,
  62. **kwargs,
  63. )
  64. return SemanticSegmentation(
  65. model=segmentation_model,
  66. pre_processor=preprocessor,
  67. post_processor=nn.Identity(),
  68. name=f"{model_name}_{encoder_name}",
  69. )
  70. @staticmethod
  71. def get_preprocessing_pipeline(preproc_params: dict[str, Any]) -> kornia.augmentation.container.ImageSequential:
  72. # Ensure the color space transformation is ONNX-friendly
  73. proc_sequence: list[Module] = []
  74. input_space = preproc_params["input_space"]
  75. if input_space == "BGR":
  76. proc_sequence.append(kornia.color.BgrToRgb())
  77. elif input_space == "RGB":
  78. pass
  79. else:
  80. raise ValueError(f"Unsupported input space: {input_space}")
  81. # Normalize input range if needed
  82. input_range = preproc_params["input_range"]
  83. if input_range[1] == 255:
  84. proc_sequence.append(kornia.enhance.Normalize(mean=0.0, std=1 / 255.0))
  85. elif input_range[1] == 1:
  86. pass
  87. else:
  88. raise ValueError(f"Unsupported input range: {input_range}")
  89. # Handle mean and std normalization
  90. if preproc_params["mean"] is not None:
  91. mean = tensor([preproc_params["mean"]])
  92. else:
  93. mean = tensor(0.0)
  94. if preproc_params["std"] is not None:
  95. std = tensor([preproc_params["std"]])
  96. else:
  97. std = tensor(1.0)
  98. proc_sequence.append(kornia.enhance.Normalize(mean=mean, std=std))
  99. return kornia.augmentation.container.ImageSequential(*proc_sequence)