mask_rcnn.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590
  1. from collections import OrderedDict
  2. from typing import Any, Callable, Optional
  3. from torch import nn
  4. from torchvision.ops import MultiScaleRoIAlign
  5. from ...ops import misc as misc_nn_ops
  6. from ...transforms._presets import ObjectDetection
  7. from .._api import register_model, Weights, WeightsEnum
  8. from .._meta import _COCO_CATEGORIES
  9. from .._utils import _ovewrite_value_param, handle_legacy_interface
  10. from ..resnet import resnet50, ResNet50_Weights
  11. from ._utils import overwrite_eps
  12. from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
  13. from .faster_rcnn import _default_anchorgen, FasterRCNN, FastRCNNConvFCHead, RPNHead
# Names exported by ``from <this module> import *``: the model class, the two
# weights enums and the two model builder functions defined in this file.
__all__ = [
    "MaskRCNN",
    "MaskRCNN_ResNet50_FPN_Weights",
    "MaskRCNN_ResNet50_FPN_V2_Weights",
    "maskrcnn_resnet50_fpn",
    "maskrcnn_resnet50_fpn_v2",
]
  21. class MaskRCNN(FasterRCNN):
  22. """
  23. Implements Mask R-CNN.
  24. The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
  25. image, and should be in 0-1 range. Different images can have different sizes.
  26. The behavior of the model changes depending on if it is in training or evaluation mode.
  27. During training, the model expects both the input tensors and targets (list of dictionary),
  28. containing:
  29. - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
  30. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  31. - labels (Int64Tensor[N]): the class label for each ground-truth box
  32. - masks (UInt8Tensor[N, H, W]): the segmentation binary masks for each instance
  33. The model returns a Dict[Tensor] during training, containing the classification and regression
  34. losses for both the RPN and the R-CNN, and the mask loss.
  35. During inference, the model requires only the input tensors, and returns the post-processed
  36. predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
  37. follows:
  38. - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
  39. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  40. - labels (Int64Tensor[N]): the predicted labels for each image
  41. - scores (Tensor[N]): the scores or each prediction
  42. - masks (FloatTensor[N, 1, H, W]): the predicted masks for each instance, in 0-1 range. In order to
  43. obtain the final segmentation masks, the soft masks can be thresholded, generally
  44. with a value of 0.5 (mask >= 0.5)
  45. Args:
  46. backbone (nn.Module): the network used to compute the features for the model.
  47. It should contain an out_channels attribute, which indicates the number of output
  48. channels that each feature map has (and it should be the same for all feature maps).
  49. The backbone should return a single Tensor or and OrderedDict[Tensor].
  50. num_classes (int): number of output classes of the model (including the background).
  51. If box_predictor is specified, num_classes should be None.
  52. min_size (int): Images are rescaled before feeding them to the backbone:
  53. we attempt to preserve the aspect ratio and scale the shorter edge
  54. to ``min_size``. If the resulting longer edge exceeds ``max_size``,
  55. then downscale so that the longer edge does not exceed ``max_size``.
  56. This may result in the shorter edge beeing lower than ``min_size``.
  57. max_size (int): See ``min_size``.
  58. image_mean (Tuple[float, float, float]): mean values used for input normalization.
  59. They are generally the mean values of the dataset on which the backbone has been trained
  60. on
  61. image_std (Tuple[float, float, float]): std values used for input normalization.
  62. They are generally the std values of the dataset on which the backbone has been trained on
  63. rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
  64. maps.
  65. rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
  66. rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
  67. rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
  68. rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
  69. rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
  70. rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
  71. rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
  72. considered as positive during training of the RPN.
  73. rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
  74. considered as negative during training of the RPN.
  75. rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
  76. for computing the loss
  77. rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
  78. of the RPN
  79. rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh
  80. box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
  81. the locations indicated by the bounding boxes
  82. box_head (nn.Module): module that takes the cropped feature maps as input
  83. box_predictor (nn.Module): module that takes the output of box_head and returns the
  84. classification logits and box regression deltas.
  85. box_score_thresh (float): during inference, only return proposals with a classification score
  86. greater than box_score_thresh
  87. box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
  88. box_detections_per_img (int): maximum number of detections per image, for all classes.
  89. box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
  90. considered as positive during training of the classification head
  91. box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
  92. considered as negative during training of the classification head
  93. box_batch_size_per_image (int): number of proposals that are sampled during training of the
  94. classification head
  95. box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
  96. of the classification head
  97. bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
  98. bounding boxes
  99. mask_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
  100. the locations indicated by the bounding boxes, which will be used for the mask head.
  101. mask_head (nn.Module): module that takes the cropped feature maps as input
  102. mask_predictor (nn.Module): module that takes the output of the mask_head and returns the
  103. segmentation mask logits
  104. Example::
  105. >>> import torch
  106. >>> import torchvision
  107. >>> from torchvision.models.detection import MaskRCNN
  108. >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
  109. >>>
  110. >>> # load a pre-trained model for classification and return
  111. >>> # only the features
  112. >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
  113. >>> # MaskRCNN needs to know the number of
  114. >>> # output channels in a backbone. For mobilenet_v2, it's 1280
  115. >>> # so we need to add it here,
  116. >>> backbone.out_channels = 1280
  117. >>>
  118. >>> # let's make the RPN generate 5 x 3 anchors per spatial
  119. >>> # location, with 5 different sizes and 3 different aspect
  120. >>> # ratios. We have a Tuple[Tuple[int]] because each feature
  121. >>> # map could potentially have different sizes and
  122. >>> # aspect ratios
  123. >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
  124. >>> aspect_ratios=((0.5, 1.0, 2.0),))
  125. >>>
  126. >>> # let's define what are the feature maps that we will
  127. >>> # use to perform the region of interest cropping, as well as
  128. >>> # the size of the crop after rescaling.
  129. >>> # if your backbone returns a Tensor, featmap_names is expected to
  130. >>> # be ['0']. More generally, the backbone should return an
  131. >>> # OrderedDict[Tensor], and in featmap_names you can choose which
  132. >>> # feature maps to use.
  133. >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
  134. >>> output_size=7,
  135. >>> sampling_ratio=2)
  136. >>>
  137. >>> mask_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
  138. >>> output_size=14,
  139. >>> sampling_ratio=2)
  140. >>> # put the pieces together inside a MaskRCNN model
  141. >>> model = MaskRCNN(backbone,
  142. >>> num_classes=2,
  143. >>> rpn_anchor_generator=anchor_generator,
  144. >>> box_roi_pool=roi_pooler,
  145. >>> mask_roi_pool=mask_roi_pooler)
  146. >>> model.eval()
  147. >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
  148. >>> predictions = model(x)
  149. """
  150. def __init__(
  151. self,
  152. backbone,
  153. num_classes=None,
  154. # transform parameters
  155. min_size=800,
  156. max_size=1333,
  157. image_mean=None,
  158. image_std=None,
  159. # RPN parameters
  160. rpn_anchor_generator=None,
  161. rpn_head=None,
  162. rpn_pre_nms_top_n_train=2000,
  163. rpn_pre_nms_top_n_test=1000,
  164. rpn_post_nms_top_n_train=2000,
  165. rpn_post_nms_top_n_test=1000,
  166. rpn_nms_thresh=0.7,
  167. rpn_fg_iou_thresh=0.7,
  168. rpn_bg_iou_thresh=0.3,
  169. rpn_batch_size_per_image=256,
  170. rpn_positive_fraction=0.5,
  171. rpn_score_thresh=0.0,
  172. # Box parameters
  173. box_roi_pool=None,
  174. box_head=None,
  175. box_predictor=None,
  176. box_score_thresh=0.05,
  177. box_nms_thresh=0.5,
  178. box_detections_per_img=100,
  179. box_fg_iou_thresh=0.5,
  180. box_bg_iou_thresh=0.5,
  181. box_batch_size_per_image=512,
  182. box_positive_fraction=0.25,
  183. bbox_reg_weights=None,
  184. # Mask parameters
  185. mask_roi_pool=None,
  186. mask_head=None,
  187. mask_predictor=None,
  188. **kwargs,
  189. ):
  190. if not isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None))):
  191. raise TypeError(
  192. f"mask_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(mask_roi_pool)}"
  193. )
  194. if num_classes is not None:
  195. if mask_predictor is not None:
  196. raise ValueError("num_classes should be None when mask_predictor is specified")
  197. out_channels = backbone.out_channels
  198. if mask_roi_pool is None:
  199. mask_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
  200. if mask_head is None:
  201. mask_layers = (256, 256, 256, 256)
  202. mask_dilation = 1
  203. mask_head = MaskRCNNHeads(out_channels, mask_layers, mask_dilation)
  204. if mask_predictor is None:
  205. mask_predictor_in_channels = 256 # == mask_layers[-1]
  206. mask_dim_reduced = 256
  207. mask_predictor = MaskRCNNPredictor(mask_predictor_in_channels, mask_dim_reduced, num_classes)
  208. super().__init__(
  209. backbone,
  210. num_classes,
  211. # transform parameters
  212. min_size,
  213. max_size,
  214. image_mean,
  215. image_std,
  216. # RPN-specific parameters
  217. rpn_anchor_generator,
  218. rpn_head,
  219. rpn_pre_nms_top_n_train,
  220. rpn_pre_nms_top_n_test,
  221. rpn_post_nms_top_n_train,
  222. rpn_post_nms_top_n_test,
  223. rpn_nms_thresh,
  224. rpn_fg_iou_thresh,
  225. rpn_bg_iou_thresh,
  226. rpn_batch_size_per_image,
  227. rpn_positive_fraction,
  228. rpn_score_thresh,
  229. # Box parameters
  230. box_roi_pool,
  231. box_head,
  232. box_predictor,
  233. box_score_thresh,
  234. box_nms_thresh,
  235. box_detections_per_img,
  236. box_fg_iou_thresh,
  237. box_bg_iou_thresh,
  238. box_batch_size_per_image,
  239. box_positive_fraction,
  240. bbox_reg_weights,
  241. **kwargs,
  242. )
  243. self.roi_heads.mask_roi_pool = mask_roi_pool
  244. self.roi_heads.mask_head = mask_head
  245. self.roi_heads.mask_predictor = mask_predictor
  246. class MaskRCNNHeads(nn.Sequential):
  247. _version = 2
  248. def __init__(self, in_channels, layers, dilation, norm_layer: Optional[Callable[..., nn.Module]] = None):
  249. """
  250. Args:
  251. in_channels (int): number of input channels
  252. layers (list): feature dimensions of each FCN layer
  253. dilation (int): dilation rate of kernel
  254. norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
  255. """
  256. blocks = []
  257. next_feature = in_channels
  258. for layer_features in layers:
  259. blocks.append(
  260. misc_nn_ops.Conv2dNormActivation(
  261. next_feature,
  262. layer_features,
  263. kernel_size=3,
  264. stride=1,
  265. padding=dilation,
  266. dilation=dilation,
  267. norm_layer=norm_layer,
  268. )
  269. )
  270. next_feature = layer_features
  271. super().__init__(*blocks)
  272. for layer in self.modules():
  273. if isinstance(layer, nn.Conv2d):
  274. nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
  275. if layer.bias is not None:
  276. nn.init.zeros_(layer.bias)
  277. def _load_from_state_dict(
  278. self,
  279. state_dict,
  280. prefix,
  281. local_metadata,
  282. strict,
  283. missing_keys,
  284. unexpected_keys,
  285. error_msgs,
  286. ):
  287. version = local_metadata.get("version", None)
  288. if version is None or version < 2:
  289. num_blocks = len(self)
  290. for i in range(num_blocks):
  291. for type in ["weight", "bias"]:
  292. old_key = f"{prefix}mask_fcn{i+1}.{type}"
  293. new_key = f"{prefix}{i}.0.{type}"
  294. if old_key in state_dict:
  295. state_dict[new_key] = state_dict.pop(old_key)
  296. super()._load_from_state_dict(
  297. state_dict,
  298. prefix,
  299. local_metadata,
  300. strict,
  301. missing_keys,
  302. unexpected_keys,
  303. error_msgs,
  304. )
  305. class MaskRCNNPredictor(nn.Sequential):
  306. def __init__(self, in_channels, dim_reduced, num_classes):
  307. super().__init__(
  308. OrderedDict(
  309. [
  310. ("conv5_mask", nn.ConvTranspose2d(in_channels, dim_reduced, 2, 2, 0)),
  311. ("relu", nn.ReLU(inplace=True)),
  312. ("mask_fcn_logits", nn.Conv2d(dim_reduced, num_classes, 1, 1, 0)),
  313. ]
  314. )
  315. )
  316. for name, param in self.named_parameters():
  317. if "weight" in name:
  318. nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")
  319. # elif "bias" in name:
  320. # nn.init.constant_(param, 0)
# Metadata entries shared by every ``Weights`` definition in this module; each
# weights entry unpacks this dict into its own ``meta`` mapping.
_COMMON_META = {
    "categories": _COCO_CATEGORIES,
    "min_size": (1, 1),
}
# Pretrained weights available for ``maskrcnn_resnet50_fpn``.
class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
    # Checkpoint location, preprocessing preset and descriptive metadata.
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 44401393,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 37.9,
                    "mask_map": 34.6,
                }
            },
            "_ops": 134.38,
            "_file_size": 169.84,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    # Alias used when callers ask for the default weights.
    DEFAULT = COCO_V1
# Pretrained weights available for ``maskrcnn_resnet50_fpn_v2``.
class MaskRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
    # Checkpoint location, preprocessing preset and descriptive metadata.
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_v2_coco-73cbd019.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 46359409,
            "recipe": "https://github.com/pytorch/vision/pull/5773",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 47.4,
                    "mask_map": 41.8,
                }
            },
            "_ops": 333.577,
            "_file_size": 177.219,
            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
        },
    )
    # Alias used when callers ask for the default weights.
    DEFAULT = COCO_V1
  365. @register_model()
  366. @handle_legacy_interface(
  367. weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.COCO_V1),
  368. weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
  369. )
  370. def maskrcnn_resnet50_fpn(
  371. *,
  372. weights: Optional[MaskRCNN_ResNet50_FPN_Weights] = None,
  373. progress: bool = True,
  374. num_classes: Optional[int] = None,
  375. weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
  376. trainable_backbone_layers: Optional[int] = None,
  377. **kwargs: Any,
  378. ) -> MaskRCNN:
  379. """Mask R-CNN model with a ResNet-50-FPN backbone from the `Mask R-CNN
  380. <https://arxiv.org/abs/1703.06870>`_ paper.
  381. .. betastatus:: detection module
  382. The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
  383. image, and should be in ``0-1`` range. Different images can have different sizes.
  384. The behavior of the model changes depending on if it is in training or evaluation mode.
  385. During training, the model expects both the input tensors and targets (list of dictionary),
  386. containing:
  387. - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
  388. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  389. - labels (``Int64Tensor[N]``): the class label for each ground-truth box
  390. - masks (``UInt8Tensor[N, H, W]``): the segmentation binary masks for each instance
  391. The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
  392. losses for both the RPN and the R-CNN, and the mask loss.
  393. During inference, the model requires only the input tensors, and returns the post-processed
  394. predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
  395. follows, where ``N`` is the number of detected instances:
  396. - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
  397. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  398. - labels (``Int64Tensor[N]``): the predicted labels for each instance
  399. - scores (``Tensor[N]``): the scores or each instance
  400. - masks (``UInt8Tensor[N, 1, H, W]``): the predicted masks for each instance, in ``0-1`` range. In order to
  401. obtain the final segmentation masks, the soft masks can be thresholded, generally
  402. with a value of 0.5 (``mask >= 0.5``)
  403. For more details on the output and on how to plot the masks, you may refer to :ref:`instance_seg_output`.
  404. Mask R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
  405. Example::
  406. >>> model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT)
  407. >>> model.eval()
  408. >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
  409. >>> predictions = model(x)
  410. >>>
  411. >>> # optionally, if you want to export the model to ONNX:
  412. >>> torch.onnx.export(model, x, "mask_rcnn.onnx", opset_version = 11)
  413. Args:
  414. weights (:class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights`, optional): The
  415. pretrained weights to use. See
  416. :class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights` below for
  417. more details, and possible values. By default, no pre-trained
  418. weights are used.
  419. progress (bool, optional): If True, displays a progress bar of the
  420. download to stderr. Default is True.
  421. num_classes (int, optional): number of output classes of the model (including the background)
  422. weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
  423. pretrained weights for the backbone.
  424. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
  425. final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
  426. trainable. If ``None`` is passed (the default) this value is set to 3.
  427. **kwargs: parameters passed to the ``torchvision.models.detection.mask_rcnn.MaskRCNN``
  428. base class. Please refer to the `source code
  429. <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/mask_rcnn.py>`_
  430. for more details about this class.
  431. .. autoclass:: torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights
  432. :members:
  433. """
  434. weights = MaskRCNN_ResNet50_FPN_Weights.verify(weights)
  435. weights_backbone = ResNet50_Weights.verify(weights_backbone)
  436. if weights is not None:
  437. weights_backbone = None
  438. num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
  439. elif num_classes is None:
  440. num_classes = 91
  441. is_trained = weights is not None or weights_backbone is not None
  442. trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
  443. norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
  444. backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
  445. backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
  446. model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)
  447. if weights is not None:
  448. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  449. if weights == MaskRCNN_ResNet50_FPN_Weights.COCO_V1:
  450. overwrite_eps(model, 0.0)
  451. return model
  452. @register_model()
  453. @handle_legacy_interface(
  454. weights=("pretrained", MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1),
  455. weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
  456. )
  457. def maskrcnn_resnet50_fpn_v2(
  458. *,
  459. weights: Optional[MaskRCNN_ResNet50_FPN_V2_Weights] = None,
  460. progress: bool = True,
  461. num_classes: Optional[int] = None,
  462. weights_backbone: Optional[ResNet50_Weights] = None,
  463. trainable_backbone_layers: Optional[int] = None,
  464. **kwargs: Any,
  465. ) -> MaskRCNN:
  466. """Improved Mask R-CNN model with a ResNet-50-FPN backbone from the `Benchmarking Detection Transfer
  467. Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`_ paper.
  468. .. betastatus:: detection module
  469. :func:`~torchvision.models.detection.maskrcnn_resnet50_fpn` for more details.
  470. Args:
  471. weights (:class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights`, optional): The
  472. pretrained weights to use. See
  473. :class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights` below for
  474. more details, and possible values. By default, no pre-trained
  475. weights are used.
  476. progress (bool, optional): If True, displays a progress bar of the
  477. download to stderr. Default is True.
  478. num_classes (int, optional): number of output classes of the model (including the background)
  479. weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
  480. pretrained weights for the backbone.
  481. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
  482. final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
  483. trainable. If ``None`` is passed (the default) this value is set to 3.
  484. **kwargs: parameters passed to the ``torchvision.models.detection.mask_rcnn.MaskRCNN``
  485. base class. Please refer to the `source code
  486. <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/mask_rcnn.py>`_
  487. for more details about this class.
  488. .. autoclass:: torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights
  489. :members:
  490. """
  491. weights = MaskRCNN_ResNet50_FPN_V2_Weights.verify(weights)
  492. weights_backbone = ResNet50_Weights.verify(weights_backbone)
  493. if weights is not None:
  494. weights_backbone = None
  495. num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
  496. elif num_classes is None:
  497. num_classes = 91
  498. is_trained = weights is not None or weights_backbone is not None
  499. trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
  500. backbone = resnet50(weights=weights_backbone, progress=progress)
  501. backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
  502. rpn_anchor_generator = _default_anchorgen()
  503. rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
  504. box_head = FastRCNNConvFCHead(
  505. (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
  506. )
  507. mask_head = MaskRCNNHeads(backbone.out_channels, [256, 256, 256, 256], 1, norm_layer=nn.BatchNorm2d)
  508. model = MaskRCNN(
  509. backbone,
  510. num_classes=num_classes,
  511. rpn_anchor_generator=rpn_anchor_generator,
  512. rpn_head=rpn_head,
  513. box_head=box_head,
  514. mask_head=mask_head,
  515. **kwargs,
  516. )
  517. if weights is not None:
  518. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  519. return model