deeplabv3.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. from collections.abc import Sequence
  2. from functools import partial
  3. from typing import Any, Optional
  4. import torch
  5. from torch import nn
  6. from torch.nn import functional as F
  7. from ...transforms._presets import SemanticSegmentation
  8. from .._api import register_model, Weights, WeightsEnum
  9. from .._meta import _VOC_CATEGORIES
  10. from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter
  11. from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights, MobileNetV3
  12. from ..resnet import ResNet, resnet101, ResNet101_Weights, resnet50, ResNet50_Weights
  13. from ._utils import _SimpleSegmentationModel
  14. from .fcn import FCNHead
# Names exported by ``from <this module> import *``: the model class, the three
# weights enums and the three builder functions defined in this file.
__all__ = [
    "DeepLabV3",
    "DeepLabV3_ResNet50_Weights",
    "DeepLabV3_ResNet101_Weights",
    "DeepLabV3_MobileNet_V3_Large_Weights",
    "deeplabv3_mobilenet_v3_large",
    "deeplabv3_resnet50",
    "deeplabv3_resnet101",
]
  24. class DeepLabV3(_SimpleSegmentationModel):
  25. """
  26. Implements DeepLabV3 model from
  27. `"Rethinking Atrous Convolution for Semantic Image Segmentation"
  28. <https://arxiv.org/abs/1706.05587>`_.
  29. Args:
  30. backbone (nn.Module): the network used to compute the features for the model.
  31. The backbone should return an OrderedDict[Tensor], with the key being
  32. "out" for the last feature map used, and "aux" if an auxiliary classifier
  33. is used.
  34. classifier (nn.Module): module that takes the "out" element returned from
  35. the backbone and returns a dense prediction.
  36. aux_classifier (nn.Module, optional): auxiliary classifier used during training
  37. """
  38. pass
  39. class DeepLabHead(nn.Sequential):
  40. def __init__(self, in_channels: int, num_classes: int, atrous_rates: Sequence[int] = (12, 24, 36)) -> None:
  41. super().__init__(
  42. ASPP(in_channels, atrous_rates),
  43. nn.Conv2d(256, 256, 3, padding=1, bias=False),
  44. nn.BatchNorm2d(256),
  45. nn.ReLU(),
  46. nn.Conv2d(256, num_classes, 1),
  47. )
  48. class ASPPConv(nn.Sequential):
  49. def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None:
  50. modules = [
  51. nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
  52. nn.BatchNorm2d(out_channels),
  53. nn.ReLU(),
  54. ]
  55. super().__init__(*modules)
  56. class ASPPPooling(nn.Sequential):
  57. def __init__(self, in_channels: int, out_channels: int) -> None:
  58. super().__init__(
  59. nn.AdaptiveAvgPool2d(1),
  60. nn.Conv2d(in_channels, out_channels, 1, bias=False),
  61. nn.BatchNorm2d(out_channels),
  62. nn.ReLU(),
  63. )
  64. def forward(self, x: torch.Tensor) -> torch.Tensor:
  65. size = x.shape[-2:]
  66. for mod in self:
  67. x = mod(x)
  68. return F.interpolate(x, size=size, mode="bilinear", align_corners=False)
  69. class ASPP(nn.Module):
  70. def __init__(self, in_channels: int, atrous_rates: Sequence[int], out_channels: int = 256) -> None:
  71. super().__init__()
  72. modules = []
  73. modules.append(
  74. nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU())
  75. )
  76. rates = tuple(atrous_rates)
  77. for rate in rates:
  78. modules.append(ASPPConv(in_channels, out_channels, rate))
  79. modules.append(ASPPPooling(in_channels, out_channels))
  80. self.convs = nn.ModuleList(modules)
  81. self.project = nn.Sequential(
  82. nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
  83. nn.BatchNorm2d(out_channels),
  84. nn.ReLU(),
  85. nn.Dropout(0.5),
  86. )
  87. def forward(self, x: torch.Tensor) -> torch.Tensor:
  88. _res = []
  89. for conv in self.convs:
  90. _res.append(conv(x))
  91. res = torch.cat(_res, dim=1)
  92. return self.project(res)
  93. def _deeplabv3_resnet(
  94. backbone: ResNet,
  95. num_classes: int,
  96. aux: Optional[bool],
  97. ) -> DeepLabV3:
  98. return_layers = {"layer4": "out"}
  99. if aux:
  100. return_layers["layer3"] = "aux"
  101. backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
  102. aux_classifier = FCNHead(1024, num_classes) if aux else None
  103. classifier = DeepLabHead(2048, num_classes)
  104. return DeepLabV3(backbone, classifier, aux_classifier)
# Metadata entries shared by the weights enums below; each enum merges them into
# its own ``meta`` dict via ``**_COMMON_META``.
_COMMON_META = {
    "categories": _VOC_CATEGORIES,
    "min_size": (1, 1),
    "_docs": """
        These weights were trained on a subset of COCO, using only the 20 categories that are present in the Pascal VOC
        dataset.
    """,
}
class DeepLabV3_ResNet50_Weights(WeightsEnum):
    # Pretrained checkpoints available for the DeepLabV3 model with a ResNet-50 backbone.
    COCO_WITH_VOC_LABELS_V1 = Weights(
        url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
        # Preprocessing preset with ``resize_size`` bound to 520.
        transforms=partial(SemanticSegmentation, resize_size=520),
        meta={
            # Shared "categories", "min_size" and "_docs" entries.
            **_COMMON_META,
            "num_params": 42004074,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50",
            "_metrics": {
                "COCO-val2017-VOC-labels": {
                    "miou": 66.4,
                    "pixel_acc": 92.4,
                }
            },
            "_ops": 178.722,
            "_file_size": 160.515,
        },
    )
    # ``DEFAULT`` is bound to the same value as ``COCO_WITH_VOC_LABELS_V1``.
    DEFAULT = COCO_WITH_VOC_LABELS_V1
  132. class DeepLabV3_ResNet101_Weights(WeightsEnum):
  133. COCO_WITH_VOC_LABELS_V1 = Weights(
  134. url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
  135. transforms=partial(SemanticSegmentation, resize_size=520),
  136. meta={
  137. **_COMMON_META,
  138. "num_params": 60996202,
  139. "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101",
  140. "_metrics": {
  141. "COCO-val2017-VOC-labels": {
  142. "miou": 67.4,
  143. "pixel_acc": 92.4,
  144. }
  145. },
  146. "_ops": 258.743,
  147. "_file_size": 233.217,
  148. },
  149. )
  150. DEFAULT = COCO_WITH_VOC_LABELS_V1
class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
    # Pretrained checkpoints available for the DeepLabV3 model with a MobileNetV3-Large backbone.
    COCO_WITH_VOC_LABELS_V1 = Weights(
        url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
        # Preprocessing preset with ``resize_size`` bound to 520.
        transforms=partial(SemanticSegmentation, resize_size=520),
        meta={
            # Shared "categories", "min_size" and "_docs" entries.
            **_COMMON_META,
            "num_params": 11029328,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large",
            "_metrics": {
                "COCO-val2017-VOC-labels": {
                    "miou": 60.3,
                    "pixel_acc": 91.2,
                }
            },
            "_ops": 10.452,
            "_file_size": 42.301,
        },
    )
    # ``DEFAULT`` is bound to the same value as ``COCO_WITH_VOC_LABELS_V1``.
    DEFAULT = COCO_WITH_VOC_LABELS_V1
  170. def _deeplabv3_mobilenetv3(
  171. backbone: MobileNetV3,
  172. num_classes: int,
  173. aux: Optional[bool],
  174. ) -> DeepLabV3:
  175. backbone = backbone.features
  176. # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
  177. # The first and last blocks are always included because they are the C0 (conv1) and Cn.
  178. stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
  179. out_pos = stage_indices[-1] # use C5 which has output_stride = 16
  180. out_inplanes = backbone[out_pos].out_channels
  181. aux_pos = stage_indices[-4] # use C2 here which has output_stride = 8
  182. aux_inplanes = backbone[aux_pos].out_channels
  183. return_layers = {str(out_pos): "out"}
  184. if aux:
  185. return_layers[str(aux_pos)] = "aux"
  186. backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
  187. aux_classifier = FCNHead(aux_inplanes, num_classes) if aux else None
  188. classifier = DeepLabHead(out_inplanes, num_classes)
  189. return DeepLabV3(backbone, classifier, aux_classifier)
  190. @register_model()
  191. @handle_legacy_interface(
  192. weights=("pretrained", DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1),
  193. weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
  194. )
  195. def deeplabv3_resnet50(
  196. *,
  197. weights: Optional[DeepLabV3_ResNet50_Weights] = None,
  198. progress: bool = True,
  199. num_classes: Optional[int] = None,
  200. aux_loss: Optional[bool] = None,
  201. weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
  202. **kwargs: Any,
  203. ) -> DeepLabV3:
  204. """Constructs a DeepLabV3 model with a ResNet-50 backbone.
  205. .. betastatus:: segmentation module
  206. Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation <https://arxiv.org/abs/1706.05587>`__.
  207. Args:
  208. weights (:class:`~torchvision.models.segmentation.DeepLabV3_ResNet50_Weights`, optional): The
  209. pretrained weights to use. See
  210. :class:`~torchvision.models.segmentation.DeepLabV3_ResNet50_Weights` below for
  211. more details, and possible values. By default, no pre-trained
  212. weights are used.
  213. progress (bool, optional): If True, displays a progress bar of the
  214. download to stderr. Default is True.
  215. num_classes (int, optional): number of output classes of the model (including the background)
  216. aux_loss (bool, optional): If True, it uses an auxiliary loss
  217. weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for the
  218. backbone
  219. **kwargs: unused
  220. .. autoclass:: torchvision.models.segmentation.DeepLabV3_ResNet50_Weights
  221. :members:
  222. """
  223. weights = DeepLabV3_ResNet50_Weights.verify(weights)
  224. weights_backbone = ResNet50_Weights.verify(weights_backbone)
  225. if weights is not None:
  226. weights_backbone = None
  227. num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
  228. aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True)
  229. elif num_classes is None:
  230. num_classes = 21
  231. backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
  232. model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
  233. if weights is not None:
  234. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  235. return model
  236. @register_model()
  237. @handle_legacy_interface(
  238. weights=("pretrained", DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1),
  239. weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1),
  240. )
  241. def deeplabv3_resnet101(
  242. *,
  243. weights: Optional[DeepLabV3_ResNet101_Weights] = None,
  244. progress: bool = True,
  245. num_classes: Optional[int] = None,
  246. aux_loss: Optional[bool] = None,
  247. weights_backbone: Optional[ResNet101_Weights] = ResNet101_Weights.IMAGENET1K_V1,
  248. **kwargs: Any,
  249. ) -> DeepLabV3:
  250. """Constructs a DeepLabV3 model with a ResNet-101 backbone.
  251. .. betastatus:: segmentation module
  252. Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation <https://arxiv.org/abs/1706.05587>`__.
  253. Args:
  254. weights (:class:`~torchvision.models.segmentation.DeepLabV3_ResNet101_Weights`, optional): The
  255. pretrained weights to use. See
  256. :class:`~torchvision.models.segmentation.DeepLabV3_ResNet101_Weights` below for
  257. more details, and possible values. By default, no pre-trained
  258. weights are used.
  259. progress (bool, optional): If True, displays a progress bar of the
  260. download to stderr. Default is True.
  261. num_classes (int, optional): number of output classes of the model (including the background)
  262. aux_loss (bool, optional): If True, it uses an auxiliary loss
  263. weights_backbone (:class:`~torchvision.models.ResNet101_Weights`, optional): The pretrained weights for the
  264. backbone
  265. **kwargs: unused
  266. .. autoclass:: torchvision.models.segmentation.DeepLabV3_ResNet101_Weights
  267. :members:
  268. """
  269. weights = DeepLabV3_ResNet101_Weights.verify(weights)
  270. weights_backbone = ResNet101_Weights.verify(weights_backbone)
  271. if weights is not None:
  272. weights_backbone = None
  273. num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
  274. aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True)
  275. elif num_classes is None:
  276. num_classes = 21
  277. backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
  278. model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
  279. if weights is not None:
  280. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  281. return model
  282. @register_model()
  283. @handle_legacy_interface(
  284. weights=("pretrained", DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1),
  285. weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
  286. )
  287. def deeplabv3_mobilenet_v3_large(
  288. *,
  289. weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None,
  290. progress: bool = True,
  291. num_classes: Optional[int] = None,
  292. aux_loss: Optional[bool] = None,
  293. weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
  294. **kwargs: Any,
  295. ) -> DeepLabV3:
  296. """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.
  297. Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation <https://arxiv.org/abs/1706.05587>`__.
  298. Args:
  299. weights (:class:`~torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights`, optional): The
  300. pretrained weights to use. See
  301. :class:`~torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights` below for
  302. more details, and possible values. By default, no pre-trained
  303. weights are used.
  304. progress (bool, optional): If True, displays a progress bar of the
  305. download to stderr. Default is True.
  306. num_classes (int, optional): number of output classes of the model (including the background)
  307. aux_loss (bool, optional): If True, it uses an auxiliary loss
  308. weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The pretrained weights
  309. for the backbone
  310. **kwargs: unused
  311. .. autoclass:: torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights
  312. :members:
  313. """
  314. weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights)
  315. weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
  316. if weights is not None:
  317. weights_backbone = None
  318. num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
  319. aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True)
  320. elif num_classes is None:
  321. num_classes = 21
  322. backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True)
  323. model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)
  324. if weights is not None:
  325. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  326. return model