retinanet.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903
  1. import math
  2. import warnings
  3. from collections import OrderedDict
  4. from functools import partial
  5. from typing import Any, Callable, Optional
  6. import torch
  7. from torch import nn, Tensor
  8. from ...ops import boxes as box_ops, misc as misc_nn_ops, sigmoid_focal_loss
  9. from ...ops.feature_pyramid_network import LastLevelP6P7
  10. from ...transforms._presets import ObjectDetection
  11. from ...utils import _log_api_usage_once
  12. from .._api import register_model, Weights, WeightsEnum
  13. from .._meta import _COCO_CATEGORIES
  14. from .._utils import _ovewrite_value_param, handle_legacy_interface
  15. from ..resnet import resnet50, ResNet50_Weights
  16. from . import _utils as det_utils
  17. from ._utils import _box_loss, overwrite_eps
  18. from .anchor_utils import AnchorGenerator
  19. from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
  20. from .transform import GeneralizedRCNNTransform
  21. __all__ = [
  22. "RetinaNet",
  23. "RetinaNet_ResNet50_FPN_Weights",
  24. "RetinaNet_ResNet50_FPN_V2_Weights",
  25. "retinanet_resnet50_fpn",
  26. "retinanet_resnet50_fpn_v2",
  27. ]
  28. def _sum(x: list[Tensor]) -> Tensor:
  29. res = x[0]
  30. for i in x[1:]:
  31. res = res + i
  32. return res
  33. def _v1_to_v2_weights(state_dict, prefix):
  34. for i in range(4):
  35. for type in ["weight", "bias"]:
  36. old_key = f"{prefix}conv.{2*i}.{type}"
  37. new_key = f"{prefix}conv.{i}.0.{type}"
  38. if old_key in state_dict:
  39. state_dict[new_key] = state_dict.pop(old_key)
  40. def _default_anchorgen():
  41. anchor_sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512])
  42. aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
  43. anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
  44. return anchor_generator
  45. class RetinaNetHead(nn.Module):
  46. """
  47. A regression and classification head for use in RetinaNet.
  48. Args:
  49. in_channels (int): number of channels of the input feature
  50. num_anchors (int): number of anchors to be predicted
  51. num_classes (int): number of classes to be predicted
  52. norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
  53. """
  54. def __init__(self, in_channels, num_anchors, num_classes, norm_layer: Optional[Callable[..., nn.Module]] = None):
  55. super().__init__()
  56. self.classification_head = RetinaNetClassificationHead(
  57. in_channels, num_anchors, num_classes, norm_layer=norm_layer
  58. )
  59. self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors, norm_layer=norm_layer)
  60. def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
  61. # type: (list[dict[str, Tensor]], dict[str, Tensor], list[Tensor], list[Tensor]) -> dict[str, Tensor]
  62. return {
  63. "classification": self.classification_head.compute_loss(targets, head_outputs, matched_idxs),
  64. "bbox_regression": self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
  65. }
  66. def forward(self, x):
  67. # type: (list[Tensor]) -> dict[str, Tensor]
  68. return {"cls_logits": self.classification_head(x), "bbox_regression": self.regression_head(x)}
  69. class RetinaNetClassificationHead(nn.Module):
  70. """
  71. A classification head for use in RetinaNet.
  72. Args:
  73. in_channels (int): number of channels of the input feature
  74. num_anchors (int): number of anchors to be predicted
  75. num_classes (int): number of classes to be predicted
  76. norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
  77. """
  78. _version = 2
  79. def __init__(
  80. self,
  81. in_channels,
  82. num_anchors,
  83. num_classes,
  84. prior_probability=0.01,
  85. norm_layer: Optional[Callable[..., nn.Module]] = None,
  86. ):
  87. super().__init__()
  88. conv = []
  89. for _ in range(4):
  90. conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer))
  91. self.conv = nn.Sequential(*conv)
  92. for layer in self.conv.modules():
  93. if isinstance(layer, nn.Conv2d):
  94. torch.nn.init.normal_(layer.weight, std=0.01)
  95. if layer.bias is not None:
  96. torch.nn.init.constant_(layer.bias, 0)
  97. self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
  98. torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
  99. torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))
  100. self.num_classes = num_classes
  101. self.num_anchors = num_anchors
  102. # This is to fix using det_utils.Matcher.BETWEEN_THRESHOLDS in TorchScript.
  103. # TorchScript doesn't support class attributes.
  104. # https://github.com/pytorch/vision/pull/1697#issuecomment-630255584
  105. self.BETWEEN_THRESHOLDS = det_utils.Matcher.BETWEEN_THRESHOLDS
  106. def _load_from_state_dict(
  107. self,
  108. state_dict,
  109. prefix,
  110. local_metadata,
  111. strict,
  112. missing_keys,
  113. unexpected_keys,
  114. error_msgs,
  115. ):
  116. version = local_metadata.get("version", None)
  117. if version is None or version < 2:
  118. _v1_to_v2_weights(state_dict, prefix)
  119. super()._load_from_state_dict(
  120. state_dict,
  121. prefix,
  122. local_metadata,
  123. strict,
  124. missing_keys,
  125. unexpected_keys,
  126. error_msgs,
  127. )
  128. def compute_loss(self, targets, head_outputs, matched_idxs):
  129. # type: (list[dict[str, Tensor]], dict[str, Tensor], list[Tensor]) -> Tensor
  130. losses = []
  131. cls_logits = head_outputs["cls_logits"]
  132. for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs):
  133. # determine only the foreground
  134. foreground_idxs_per_image = matched_idxs_per_image >= 0
  135. num_foreground = foreground_idxs_per_image.sum()
  136. # create the target classification
  137. gt_classes_target = torch.zeros_like(cls_logits_per_image)
  138. gt_classes_target[
  139. foreground_idxs_per_image,
  140. targets_per_image["labels"][matched_idxs_per_image[foreground_idxs_per_image]],
  141. ] = 1.0
  142. # find indices for which anchors should be ignored
  143. valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS
  144. # compute the classification loss
  145. losses.append(
  146. sigmoid_focal_loss(
  147. cls_logits_per_image[valid_idxs_per_image],
  148. gt_classes_target[valid_idxs_per_image],
  149. reduction="sum",
  150. )
  151. / max(1, num_foreground)
  152. )
  153. return _sum(losses) / len(targets)
  154. def forward(self, x):
  155. # type: (list[Tensor]) -> Tensor
  156. all_cls_logits = []
  157. for features in x:
  158. cls_logits = self.conv(features)
  159. cls_logits = self.cls_logits(cls_logits)
  160. # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
  161. N, _, H, W = cls_logits.shape
  162. cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
  163. cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
  164. cls_logits = cls_logits.reshape(N, -1, self.num_classes) # Size=(N, HWA, 4)
  165. all_cls_logits.append(cls_logits)
  166. return torch.cat(all_cls_logits, dim=1)
  167. class RetinaNetRegressionHead(nn.Module):
  168. """
  169. A regression head for use in RetinaNet.
  170. Args:
  171. in_channels (int): number of channels of the input feature
  172. num_anchors (int): number of anchors to be predicted
  173. norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
  174. """
  175. _version = 2
  176. __annotations__ = {
  177. "box_coder": det_utils.BoxCoder,
  178. }
  179. def __init__(self, in_channels, num_anchors, norm_layer: Optional[Callable[..., nn.Module]] = None):
  180. super().__init__()
  181. conv = []
  182. for _ in range(4):
  183. conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer))
  184. self.conv = nn.Sequential(*conv)
  185. self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
  186. torch.nn.init.normal_(self.bbox_reg.weight, std=0.01)
  187. torch.nn.init.zeros_(self.bbox_reg.bias)
  188. for layer in self.conv.modules():
  189. if isinstance(layer, nn.Conv2d):
  190. torch.nn.init.normal_(layer.weight, std=0.01)
  191. if layer.bias is not None:
  192. torch.nn.init.zeros_(layer.bias)
  193. self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
  194. self._loss_type = "l1"
  195. def _load_from_state_dict(
  196. self,
  197. state_dict,
  198. prefix,
  199. local_metadata,
  200. strict,
  201. missing_keys,
  202. unexpected_keys,
  203. error_msgs,
  204. ):
  205. version = local_metadata.get("version", None)
  206. if version is None or version < 2:
  207. _v1_to_v2_weights(state_dict, prefix)
  208. super()._load_from_state_dict(
  209. state_dict,
  210. prefix,
  211. local_metadata,
  212. strict,
  213. missing_keys,
  214. unexpected_keys,
  215. error_msgs,
  216. )
  217. def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
  218. # type: (list[dict[str, Tensor]], dict[str, Tensor], list[Tensor], list[Tensor]) -> Tensor
  219. losses = []
  220. bbox_regression = head_outputs["bbox_regression"]
  221. for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(
  222. targets, bbox_regression, anchors, matched_idxs
  223. ):
  224. # determine only the foreground indices, ignore the rest
  225. foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
  226. num_foreground = foreground_idxs_per_image.numel()
  227. # select only the foreground boxes
  228. matched_gt_boxes_per_image = targets_per_image["boxes"][matched_idxs_per_image[foreground_idxs_per_image]]
  229. bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
  230. anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
  231. # compute the loss
  232. losses.append(
  233. _box_loss(
  234. self._loss_type,
  235. self.box_coder,
  236. anchors_per_image,
  237. matched_gt_boxes_per_image,
  238. bbox_regression_per_image,
  239. )
  240. / max(1, num_foreground)
  241. )
  242. return _sum(losses) / max(1, len(targets))
  243. def forward(self, x):
  244. # type: (list[Tensor]) -> Tensor
  245. all_bbox_regression = []
  246. for features in x:
  247. bbox_regression = self.conv(features)
  248. bbox_regression = self.bbox_reg(bbox_regression)
  249. # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
  250. N, _, H, W = bbox_regression.shape
  251. bbox_regression = bbox_regression.view(N, -1, 4, H, W)
  252. bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
  253. bbox_regression = bbox_regression.reshape(N, -1, 4) # Size=(N, HWA, 4)
  254. all_bbox_regression.append(bbox_regression)
  255. return torch.cat(all_bbox_regression, dim=1)
  256. class RetinaNet(nn.Module):
  257. """
  258. Implements RetinaNet.
  259. The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
  260. image, and should be in 0-1 range. Different images can have different sizes.
  261. The behavior of the model changes depending on if it is in training or evaluation mode.
  262. During training, the model expects both the input tensors and targets (list of dictionary),
  263. containing:
  264. - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
  265. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  266. - labels (Int64Tensor[N]): the class label for each ground-truth box
  267. The model returns a Dict[Tensor] during training, containing the classification and regression
  268. losses.
  269. During inference, the model requires only the input tensors, and returns the post-processed
  270. predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
  271. follows:
  272. - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
  273. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  274. - labels (Int64Tensor[N]): the predicted labels for each image
  275. - scores (Tensor[N]): the scores for each prediction
  276. Args:
  277. backbone (nn.Module): the network used to compute the features for the model.
  278. It should contain an out_channels attribute, which indicates the number of output
  279. channels that each feature map has (and it should be the same for all feature maps).
  280. The backbone should return a single Tensor or an OrderedDict[Tensor].
  281. num_classes (int): number of output classes of the model (including the background).
  282. min_size (int): Images are rescaled before feeding them to the backbone:
  283. we attempt to preserve the aspect ratio and scale the shorter edge
  284. to ``min_size``. If the resulting longer edge exceeds ``max_size``,
  285. then downscale so that the longer edge does not exceed ``max_size``.
  286. This may result in the shorter edge beeing lower than ``min_size``.
  287. max_size (int): See ``min_size``.
  288. image_mean (Tuple[float, float, float]): mean values used for input normalization.
  289. They are generally the mean values of the dataset on which the backbone has been trained
  290. on
  291. image_std (Tuple[float, float, float]): std values used for input normalization.
  292. They are generally the std values of the dataset on which the backbone has been trained on
  293. anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
  294. maps.
  295. head (nn.Module): Module run on top of the feature pyramid.
  296. Defaults to a module containing a classification and regression module.
  297. score_thresh (float): Score threshold used for postprocessing the detections.
  298. nms_thresh (float): NMS threshold used for postprocessing the detections.
  299. detections_per_img (int): Number of best detections to keep after NMS.
  300. fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
  301. considered as positive during training.
  302. bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
  303. considered as negative during training.
  304. topk_candidates (int): Number of best detections to keep before NMS.
  305. Example:
  306. >>> import torch
  307. >>> import torchvision
  308. >>> from torchvision.models.detection import RetinaNet
  309. >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
  310. >>> # load a pre-trained model for classification and return
  311. >>> # only the features
  312. >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
  313. >>> # RetinaNet needs to know the number of
  314. >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
  315. >>> # so we need to add it here
  316. >>> backbone.out_channels = 1280
  317. >>>
  318. >>> # let's make the network generate 5 x 3 anchors per spatial
  319. >>> # location, with 5 different sizes and 3 different aspect
  320. >>> # ratios. We have a Tuple[Tuple[int]] because each feature
  321. >>> # map could potentially have different sizes and
  322. >>> # aspect ratios
  323. >>> anchor_generator = AnchorGenerator(
  324. >>> sizes=((32, 64, 128, 256, 512),),
  325. >>> aspect_ratios=((0.5, 1.0, 2.0),)
  326. >>> )
  327. >>>
  328. >>> # put the pieces together inside a RetinaNet model
  329. >>> model = RetinaNet(backbone,
  330. >>> num_classes=2,
  331. >>> anchor_generator=anchor_generator)
  332. >>> model.eval()
  333. >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
  334. >>> predictions = model(x)
  335. """
  336. __annotations__ = {
  337. "box_coder": det_utils.BoxCoder,
  338. "proposal_matcher": det_utils.Matcher,
  339. }
  340. def __init__(
  341. self,
  342. backbone,
  343. num_classes,
  344. # transform parameters
  345. min_size=800,
  346. max_size=1333,
  347. image_mean=None,
  348. image_std=None,
  349. # Anchor parameters
  350. anchor_generator=None,
  351. head=None,
  352. proposal_matcher=None,
  353. score_thresh=0.05,
  354. nms_thresh=0.5,
  355. detections_per_img=300,
  356. fg_iou_thresh=0.5,
  357. bg_iou_thresh=0.4,
  358. topk_candidates=1000,
  359. **kwargs,
  360. ):
  361. super().__init__()
  362. _log_api_usage_once(self)
  363. if not hasattr(backbone, "out_channels"):
  364. raise ValueError(
  365. "backbone should contain an attribute out_channels "
  366. "specifying the number of output channels (assumed to be the "
  367. "same for all the levels)"
  368. )
  369. self.backbone = backbone
  370. if not isinstance(anchor_generator, (AnchorGenerator, type(None))):
  371. raise TypeError(
  372. f"anchor_generator should be of type AnchorGenerator or None instead of {type(anchor_generator)}"
  373. )
  374. if anchor_generator is None:
  375. anchor_generator = _default_anchorgen()
  376. self.anchor_generator = anchor_generator
  377. if head is None:
  378. head = RetinaNetHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
  379. self.head = head
  380. if proposal_matcher is None:
  381. proposal_matcher = det_utils.Matcher(
  382. fg_iou_thresh,
  383. bg_iou_thresh,
  384. allow_low_quality_matches=True,
  385. )
  386. self.proposal_matcher = proposal_matcher
  387. self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
  388. if image_mean is None:
  389. image_mean = [0.485, 0.456, 0.406]
  390. if image_std is None:
  391. image_std = [0.229, 0.224, 0.225]
  392. self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
  393. self.score_thresh = score_thresh
  394. self.nms_thresh = nms_thresh
  395. self.detections_per_img = detections_per_img
  396. self.topk_candidates = topk_candidates
  397. # used only on torchscript mode
  398. self._has_warned = False
  399. @torch.jit.unused
  400. def eager_outputs(self, losses, detections):
  401. # type: (dict[str, Tensor], list[dict[str, Tensor]]) -> tuple[dict[str, Tensor], list[dict[str, Tensor]]]
  402. if self.training:
  403. return losses
  404. return detections
  405. def compute_loss(self, targets, head_outputs, anchors):
  406. # type: (list[dict[str, Tensor]], dict[str, Tensor], list[Tensor]) -> dict[str, Tensor]
  407. matched_idxs = []
  408. for anchors_per_image, targets_per_image in zip(anchors, targets):
  409. if targets_per_image["boxes"].numel() == 0:
  410. matched_idxs.append(
  411. torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
  412. )
  413. continue
  414. match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image)
  415. matched_idxs.append(self.proposal_matcher(match_quality_matrix))
  416. return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)
  417. def postprocess_detections(self, head_outputs, anchors, image_shapes):
  418. # type: (dict[str, list[Tensor]], list[list[Tensor]], list[tuple[int, int]]) -> list[dict[str, Tensor]]
  419. class_logits = head_outputs["cls_logits"]
  420. box_regression = head_outputs["bbox_regression"]
  421. num_images = len(image_shapes)
  422. detections: list[dict[str, Tensor]] = []
  423. for index in range(num_images):
  424. box_regression_per_image = [br[index] for br in box_regression]
  425. logits_per_image = [cl[index] for cl in class_logits]
  426. anchors_per_image, image_shape = anchors[index], image_shapes[index]
  427. image_boxes = []
  428. image_scores = []
  429. image_labels = []
  430. for box_regression_per_level, logits_per_level, anchors_per_level in zip(
  431. box_regression_per_image, logits_per_image, anchors_per_image
  432. ):
  433. num_classes = logits_per_level.shape[-1]
  434. # remove low scoring boxes
  435. scores_per_level = torch.sigmoid(logits_per_level).flatten()
  436. keep_idxs = scores_per_level > self.score_thresh
  437. scores_per_level = scores_per_level[keep_idxs]
  438. topk_idxs = torch.where(keep_idxs)[0]
  439. # keep only topk scoring predictions
  440. num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
  441. scores_per_level, idxs = scores_per_level.topk(num_topk)
  442. topk_idxs = topk_idxs[idxs]
  443. anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
  444. labels_per_level = topk_idxs % num_classes
  445. boxes_per_level = self.box_coder.decode_single(
  446. box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
  447. )
  448. boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)
  449. image_boxes.append(boxes_per_level)
  450. image_scores.append(scores_per_level)
  451. image_labels.append(labels_per_level)
  452. image_boxes = torch.cat(image_boxes, dim=0)
  453. image_scores = torch.cat(image_scores, dim=0)
  454. image_labels = torch.cat(image_labels, dim=0)
  455. # non-maximum suppression
  456. keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
  457. keep = keep[: self.detections_per_img]
  458. detections.append(
  459. {
  460. "boxes": image_boxes[keep],
  461. "scores": image_scores[keep],
  462. "labels": image_labels[keep],
  463. }
  464. )
  465. return detections
  466. def forward(self, images, targets=None):
  467. # type: (list[Tensor], Optional[list[dict[str, Tensor]]]) -> tuple[dict[str, Tensor], list[dict[str, Tensor]]]
  468. """
  469. Args:
  470. images (list[Tensor]): images to be processed
  471. targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
  472. Returns:
  473. result (list[BoxList] or dict[Tensor]): the output from the model.
  474. During training, it returns a dict[Tensor] which contains the losses.
  475. During testing, it returns list[BoxList] contains additional fields
  476. like `scores`, `labels` and `mask` (for Mask R-CNN models).
  477. """
  478. if self.training:
  479. if targets is None:
  480. torch._assert(False, "targets should not be none when in training mode")
  481. else:
  482. for target in targets:
  483. boxes = target["boxes"]
  484. torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.")
  485. torch._assert(
  486. len(boxes.shape) == 2 and boxes.shape[-1] == 4,
  487. "Expected target boxes to be a tensor of shape [N, 4].",
  488. )
  489. # get the original image sizes
  490. original_image_sizes: list[tuple[int, int]] = []
  491. for img in images:
  492. val = img.shape[-2:]
  493. torch._assert(
  494. len(val) == 2,
  495. f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
  496. )
  497. original_image_sizes.append((val[0], val[1]))
  498. # transform the input
  499. images, targets = self.transform(images, targets)
  500. # Check for degenerate boxes
  501. # TODO: Move this to a function
  502. if targets is not None:
  503. for target_idx, target in enumerate(targets):
  504. boxes = target["boxes"]
  505. degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
  506. if degenerate_boxes.any():
  507. # print the first degenerate box
  508. bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
  509. degen_bb: list[float] = boxes[bb_idx].tolist()
  510. torch._assert(
  511. False,
  512. "All bounding boxes should have positive height and width."
  513. f" Found invalid box {degen_bb} for target at index {target_idx}.",
  514. )
  515. # get the features from the backbone
  516. features = self.backbone(images.tensors)
  517. if isinstance(features, torch.Tensor):
  518. features = OrderedDict([("0", features)])
  519. # TODO: Do we want a list or a dict?
  520. features = list(features.values())
  521. # compute the retinanet heads outputs using the features
  522. head_outputs = self.head(features)
  523. # create the set of anchors
  524. anchors = self.anchor_generator(images, features)
  525. losses = {}
  526. detections: list[dict[str, Tensor]] = []
  527. if self.training:
  528. if targets is None:
  529. torch._assert(False, "targets should not be none when in training mode")
  530. else:
  531. # compute the losses
  532. losses = self.compute_loss(targets, head_outputs, anchors)
  533. else:
  534. # recover level sizes
  535. num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
  536. HW = 0
  537. for v in num_anchors_per_level:
  538. HW += v
  539. HWA = head_outputs["cls_logits"].size(1)
  540. A = HWA // HW
  541. num_anchors_per_level = [hw * A for hw in num_anchors_per_level]
  542. # split outputs per level
  543. split_head_outputs: dict[str, list[Tensor]] = {}
  544. for k in head_outputs:
  545. split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
  546. split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]
  547. # compute the detections
  548. detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
  549. detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
  550. if torch.jit.is_scripting():
  551. if not self._has_warned:
  552. warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting")
  553. self._has_warned = True
  554. return losses, detections
  555. return self.eager_outputs(losses, detections)
  556. _COMMON_META = {
  557. "categories": _COCO_CATEGORIES,
  558. "min_size": (1, 1),
  559. }
  560. class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
  561. COCO_V1 = Weights(
  562. url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
  563. transforms=ObjectDetection,
  564. meta={
  565. **_COMMON_META,
  566. "num_params": 34014999,
  567. "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
  568. "_metrics": {
  569. "COCO-val2017": {
  570. "box_map": 36.4,
  571. }
  572. },
  573. "_ops": 151.54,
  574. "_file_size": 130.267,
  575. "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
  576. },
  577. )
  578. DEFAULT = COCO_V1
  579. class RetinaNet_ResNet50_FPN_V2_Weights(WeightsEnum):
  580. COCO_V1 = Weights(
  581. url="https://download.pytorch.org/models/retinanet_resnet50_fpn_v2_coco-5905b1c5.pth",
  582. transforms=ObjectDetection,
  583. meta={
  584. **_COMMON_META,
  585. "num_params": 38198935,
  586. "recipe": "https://github.com/pytorch/vision/pull/5756",
  587. "_metrics": {
  588. "COCO-val2017": {
  589. "box_map": 41.5,
  590. }
  591. },
  592. "_ops": 152.238,
  593. "_file_size": 146.037,
  594. "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
  595. },
  596. )
  597. DEFAULT = COCO_V1
  598. @register_model()
  599. @handle_legacy_interface(
  600. weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1),
  601. weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
  602. )
  603. def retinanet_resnet50_fpn(
  604. *,
  605. weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None,
  606. progress: bool = True,
  607. num_classes: Optional[int] = None,
  608. weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
  609. trainable_backbone_layers: Optional[int] = None,
  610. **kwargs: Any,
  611. ) -> RetinaNet:
  612. """
  613. Constructs a RetinaNet model with a ResNet-50-FPN backbone.
  614. .. betastatus:: detection module
  615. Reference: `Focal Loss for Dense Object Detection <https://arxiv.org/abs/1708.02002>`_.
  616. The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
  617. image, and should be in ``0-1`` range. Different images can have different sizes.
  618. The behavior of the model changes depending on if it is in training or evaluation mode.
  619. During training, the model expects both the input tensors and targets (list of dictionary),
  620. containing:
  621. - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
  622. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  623. - labels (``Int64Tensor[N]``): the class label for each ground-truth box
  624. The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
  625. losses.
  626. During inference, the model requires only the input tensors, and returns the post-processed
  627. predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
  628. follows, where ``N`` is the number of detections:
  629. - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
  630. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  631. - labels (``Int64Tensor[N]``): the predicted labels for each detection
  632. - scores (``Tensor[N]``): the scores of each detection
  633. For more details on the output, you may refer to :ref:`instance_seg_output`.
  634. Example::
  635. >>> model = torchvision.models.detection.retinanet_resnet50_fpn(weights=RetinaNet_ResNet50_FPN_Weights.DEFAULT)
  636. >>> model.eval()
  637. >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
  638. >>> predictions = model(x)
  639. Args:
  640. weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`, optional): The
  641. pretrained weights to use. See
  642. :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`
  643. below for more details, and possible values. By default, no
  644. pre-trained weights are used.
  645. progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
  646. num_classes (int, optional): number of output classes of the model (including the background)
  647. weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
  648. the backbone.
  649. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
  650. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
  651. passed (the default) this value is set to 3.
  652. **kwargs: parameters passed to the ``torchvision.models.detection.RetinaNet``
  653. base class. Please refer to the `source code
  654. <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_
  655. for more details about this class.
  656. .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights
  657. :members:
  658. """
  659. weights = RetinaNet_ResNet50_FPN_Weights.verify(weights)
  660. weights_backbone = ResNet50_Weights.verify(weights_backbone)
  661. if weights is not None:
  662. weights_backbone = None
  663. num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
  664. elif num_classes is None:
  665. num_classes = 91
  666. is_trained = weights is not None or weights_backbone is not None
  667. trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
  668. norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
  669. backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
  670. # skip P2 because it generates too many anchors (according to their paper)
  671. backbone = _resnet_fpn_extractor(
  672. backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
  673. )
  674. model = RetinaNet(backbone, num_classes, **kwargs)
  675. if weights is not None:
  676. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  677. if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1:
  678. overwrite_eps(model, 0.0)
  679. return model
  680. @register_model()
  681. @handle_legacy_interface(
  682. weights=("pretrained", RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1),
  683. weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
  684. )
  685. def retinanet_resnet50_fpn_v2(
  686. *,
  687. weights: Optional[RetinaNet_ResNet50_FPN_V2_Weights] = None,
  688. progress: bool = True,
  689. num_classes: Optional[int] = None,
  690. weights_backbone: Optional[ResNet50_Weights] = None,
  691. trainable_backbone_layers: Optional[int] = None,
  692. **kwargs: Any,
  693. ) -> RetinaNet:
  694. """
  695. Constructs an improved RetinaNet model with a ResNet-50-FPN backbone.
  696. .. betastatus:: detection module
  697. Reference: `Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection
  698. <https://arxiv.org/abs/1912.02424>`_.
  699. :func:`~torchvision.models.detection.retinanet_resnet50_fpn` for more details.
  700. Args:
  701. weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`, optional): The
  702. pretrained weights to use. See
  703. :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`
  704. below for more details, and possible values. By default, no
  705. pre-trained weights are used.
  706. progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
  707. num_classes (int, optional): number of output classes of the model (including the background)
  708. weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
  709. the backbone.
  710. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
  711. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
  712. passed (the default) this value is set to 3.
  713. **kwargs: parameters passed to the ``torchvision.models.detection.RetinaNet``
  714. base class. Please refer to the `source code
  715. <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_
  716. for more details about this class.
  717. .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights
  718. :members:
  719. """
  720. weights = RetinaNet_ResNet50_FPN_V2_Weights.verify(weights)
  721. weights_backbone = ResNet50_Weights.verify(weights_backbone)
  722. if weights is not None:
  723. weights_backbone = None
  724. num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
  725. elif num_classes is None:
  726. num_classes = 91
  727. is_trained = weights is not None or weights_backbone is not None
  728. trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
  729. backbone = resnet50(weights=weights_backbone, progress=progress)
  730. backbone = _resnet_fpn_extractor(
  731. backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(2048, 256)
  732. )
  733. anchor_generator = _default_anchorgen()
  734. head = RetinaNetHead(
  735. backbone.out_channels,
  736. anchor_generator.num_anchors_per_location()[0],
  737. num_classes,
  738. norm_layer=partial(nn.GroupNorm, 32),
  739. )
  740. head.regression_head._loss_type = "giou"
  741. model = RetinaNet(backbone, num_classes, anchor_generator=anchor_generator, head=head, **kwargs)
  742. if weights is not None:
  743. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  744. return model